From 40591a7c5067f8777eb5cdf660d01e1b7991bb2a Mon Sep 17 00:00:00 2001 From: 99 Date: Sat, 28 Mar 2026 05:05:32 +0800 Subject: [PATCH] refactor(api): use standalone graphon package (#34209) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .github/CODEOWNERS | 1 - api/.importlinter | 150 -- api/controllers/common/fields.py | 2 +- api/controllers/console/app/app.py | 4 +- api/controllers/console/app/audio.py | 2 +- api/controllers/console/app/completion.py | 2 +- api/controllers/console/app/generator.py | 2 +- api/controllers/console/app/message.py | 2 +- api/controllers/console/app/workflow.py | 8 +- .../console/app/workflow_app_log.py | 2 +- .../console/app/workflow_draft_variable.py | 8 +- api/controllers/console/app/workflow_run.py | 4 +- api/controllers/console/auth/oauth_server.py | 2 +- api/controllers/console/datasets/datasets.py | 2 +- .../console/datasets/datasets_document.py | 4 +- .../console/datasets/datasets_segments.py | 2 +- .../console/datasets/hit_testing_base.py | 2 +- .../datasets/rag_pipeline/datasource_auth.py | 4 +- .../rag_pipeline_draft_variable.py | 2 +- .../rag_pipeline/rag_pipeline_workflow.py | 2 +- api/controllers/console/explore/audio.py | 2 +- api/controllers/console/explore/completion.py | 2 +- api/controllers/console/explore/message.py | 2 +- api/controllers/console/explore/trial.py | 4 +- api/controllers/console/explore/workflow.py | 4 +- api/controllers/console/remote_files.py | 2 +- .../console/workspace/agent_providers.py | 2 +- api/controllers/console/workspace/endpoint.py | 2 +- .../workspace/load_balancing_config.py | 4 +- .../console/workspace/model_providers.py | 6 +- api/controllers/console/workspace/models.py | 6 +- api/controllers/console/workspace/plugin.py | 2 +- .../console/workspace/tool_providers.py | 2 +- .../console/workspace/trigger_providers.py | 2 +- api/controllers/inner_api/plugin/plugin.py | 2 +- api/controllers/mcp/mcp.py | 2 +- api/controllers/service_api/app/audio.py | 2 +- api/controllers/service_api/app/completion.py | 2 +- api/controllers/service_api/app/workflow.py | 6 +- .../service_api/dataset/dataset.py | 2 +- .../service_api/dataset/segment.py | 2 +- .../service_api/workspace/models.py | 2 +- api/controllers/web/audio.py | 2 +- api/controllers/web/completion.py | 2 +- api/controllers/web/message.py | 2 +- api/controllers/web/remote_files.py | 2 +- api/controllers/web/workflow.py | 4 +- api/core/agent/base_agent_runner.py | 28 +- api/core/agent/cot_agent_runner.py | 17 +- api/core/agent/cot_chat_agent_runner.py | 3 +- api/core/agent/cot_completion_agent_runner.py | 3 +- api/core/agent/fc_agent_runner.py | 15 +- .../agent/output_parser/cot_output_parser.py | 3 +- .../model_config/converter.py | 7 +- .../easy_ui_based_app/model_config/manager.py | 3 +- .../prompt_template/manager.py | 3 +- .../easy_ui_based_app/variables/manager.py | 3 +- api/core/app/app_config/entities.py | 6 +- .../features/file_upload/manager.py | 3 +- .../variables/manager.py | 3 +- .../app/apps/advanced_chat/app_generator.py | 9 +- api/core/app/apps/advanced_chat/app_runner.py | 12 +- .../advanced_chat/generate_task_pipeline.py | 12 +- api/core/app/apps/agent_chat/app_generator.py | 2 +- api/core/app/apps/agent_chat/app_runner.py | 6 +- .../base_app_generate_response_converter.py | 3 +- api/core/app/apps/base_app_generator.py | 6 +- api/core/app/apps/base_app_queue_manager.py | 2 +- api/core/app/apps/base_app_runner.py | 23 +- api/core/app/apps/chat/app_generator.py | 2 +- api/core/app/apps/chat/app_runner.py | 4 +- .../common/graph_runtime_state_support.py | 3 +- .../common/workflow_response_converter.py | 26 +- api/core/app/apps/completion/app_generator.py | 2 +- api/core/app/apps/completion/app_runner.py | 4 +- .../app/apps/pipeline/pipeline_generator.py | 4 +- api/core/app/apps/pipeline/pipeline_runner.py | 15 +- api/core/app/apps/workflow/app_generator.py | 8 +- api/core/app/apps/workflow/app_runner.py | 11 +- .../apps/workflow/generate_task_pipeline.py | 6 +- api/core/app/apps/workflow_app_runner.py | 68 +- api/core/app/entities/app_invoke_entities.py | 4 +- api/core/app/entities/queue_entities.py | 8 +- api/core/app/entities/task_entities.py | 8 +- .../hosting_moderation/hosting_moderation.py | 3 +- .../conversation_variable_persist_layer.py | 5 +- .../app/layers/pause_state_persist_layer.py | 5 +- api/core/app/layers/suspend_layer.py | 5 +- api/core/app/layers/timeslice_layer.py | 6 +- api/core/app/layers/trigger_post_layer.py | 5 +- api/core/app/llm/model_access.py | 9 +- api/core/app/llm/quota.py | 2 +- .../based_generate_task_pipeline.py | 2 +- .../easy_ui_based_generate_task_pipeline.py | 14 +- .../app/task_pipeline/message_file_utils.py | 5 +- api/core/app/workflow/file_runtime.py | 9 +- api/core/app/workflow/layers/llm_quota.py | 11 +- api/core/app/workflow/layers/observability.py | 8 +- api/core/app/workflow/layers/persistence.py | 17 +- .../base/tts/app_generator_tts_publisher.py | 5 +- api/core/datasource/datasource_manager.py | 8 +- api/core/datasource/entities/api_entities.py | 2 +- .../datasource/utils/message_transformer.py | 3 +- api/core/entities/execution_extra_content.py | 2 +- api/core/entities/mcp_provider.py | 2 +- api/core/entities/model_entities.py | 3 +- api/core/entities/provider_configuration.py | 20 +- api/core/entities/provider_entities.py | 2 +- .../helper/code_executor/code_executor.py | 2 +- api/core/helper/moderation.py | 7 +- api/core/hosting_configuration.py | 2 +- api/core/indexing_runner.py | 2 +- api/core/llm_generator/llm_generator.py | 10 +- .../output_parser/structured_output.py | 10 +- api/core/mcp/server/streamable_http.py | 3 +- api/core/mcp/utils.py | 2 +- api/core/memory/token_buffer_memory.py | 18 +- api/core/model_manager.py | 17 +- .../openai_moderation/openai_moderation.py | 3 +- api/core/ops/aliyun_trace/aliyun_trace.py | 4 +- api/core/ops/aliyun_trace/utils.py | 4 +- .../arize_phoenix_trace.py | 2 +- api/core/ops/langfuse_trace/langfuse_trace.py | 2 +- .../ops/langsmith_trace/langsmith_trace.py | 2 +- api/core/ops/mlflow_trace/mlflow_trace.py | 2 +- api/core/ops/opik_trace/opik_trace.py | 2 +- api/core/ops/tencent_trace/span_builder.py | 7 +- api/core/ops/tencent_trace/tencent_trace.py | 8 +- api/core/ops/weave_trace/weave_trace.py | 2 +- api/core/plugin/backwards_invocation/model.py | 27 +- api/core/plugin/backwards_invocation/node.py | 14 +- api/core/plugin/entities/marketplace.py | 2 +- api/core/plugin/entities/plugin.py | 2 +- api/core/plugin/entities/plugin_daemon.py | 4 +- api/core/plugin/entities/request.py | 19 +- api/core/plugin/impl/base.py | 16 +- api/core/plugin/impl/model.py | 13 +- api/core/plugin/impl/model_runtime.py | 14 +- api/core/plugin/impl/model_runtime_factory.py | 3 +- api/core/plugin/utils/converter.py | 3 +- api/core/prompt/advanced_prompt_transform.py | 18 +- .../prompt/agent_history_prompt_transform.py | 11 +- api/core/prompt/prompt_transform.py | 5 +- api/core/prompt/simple_prompt_transform.py | 15 +- api/core/prompt/utils/prompt_message_util.py | 3 +- api/core/provider_manager.py | 16 +- .../data_post_processor.py | 4 +- api/core/rag/datasource/retrieval_service.py | 2 +- api/core/rag/datasource/vdb/vector_factory.py | 2 +- api/core/rag/docstore/dataset_docstore.py | 2 +- api/core/rag/embedding/cached_embedding.py | 4 +- .../processor/paragraph_index_processor.py | 21 +- api/core/rag/models/document.py | 3 +- api/core/rag/rerank/rerank_model.py | 5 +- api/core/rag/rerank/weight_rerank.py | 2 +- api/core/rag/retrieval/dataset_retrieval.py | 10 +- .../multi_dataset_function_call_router.py | 5 +- .../router/multi_dataset_react_route.py | 7 +- api/core/rag/splitter/fixed_text_splitter.py | 3 +- .../celery_workflow_execution_repository.py | 2 +- ...lery_workflow_node_execution_repository.py | 2 +- api/core/repositories/factory.py | 2 +- .../repositories/human_input_repository.py | 4 +- ...qlalchemy_workflow_execution_repository.py | 6 +- ...hemy_workflow_node_execution_repository.py | 8 +- .../builtin_tool/providers/audio/tools/asr.py | 7 +- .../builtin_tool/providers/audio/tools/tts.py | 3 +- api/core/tools/builtin_tool/tool.py | 5 +- api/core/tools/custom_tool/tool.py | 2 +- api/core/tools/entities/api_entities.py | 2 +- api/core/tools/mcp_tool/tool.py | 3 +- api/core/tools/tool_engine.py | 3 +- api/core/tools/tool_file_manager.py | 2 +- api/core/tools/tool_manager.py | 5 +- .../dataset_multi_retriever_tool.py | 2 +- api/core/tools/utils/message_transformer.py | 2 +- .../tools/utils/model_invocation_utils.py | 7 +- .../utils/workflow_configuration_sync.py | 5 +- api/core/tools/workflow_as_tool/provider.py | 2 +- api/core/tools/workflow_as_tool/tool.py | 4 +- api/core/trigger/debug/event_selectors.py | 2 +- api/core/workflow/human_input_compat.py | 5 +- api/core/workflow/node_factory.py | 32 +- api/core/workflow/node_runtime.py | 57 +- api/core/workflow/nodes/agent/agent_node.py | 5 +- api/core/workflow/nodes/agent/entities.py | 4 +- .../nodes/agent/message_transformer.py | 16 +- .../workflow/nodes/agent/runtime_support.py | 4 +- .../nodes/datasource/datasource_node.py | 17 +- .../workflow/nodes/datasource/entities.py | 5 +- .../nodes/knowledge_index/entities.py | 4 +- .../knowledge_index/knowledge_index_node.py | 12 +- .../nodes/knowledge_retrieval/entities.py | 3 +- .../knowledge_retrieval_node.py | 13 +- .../nodes/knowledge_retrieval/retrieval.py | 4 +- .../workflow/nodes/trigger_plugin/entities.py | 4 +- .../trigger_plugin/trigger_event_node.py | 8 +- .../nodes/trigger_schedule/entities.py | 4 +- .../trigger_schedule/trigger_schedule_node.py | 8 +- .../nodes/trigger_webhook/entities.py | 6 +- .../workflow/nodes/trigger_webhook/node.py | 12 +- api/core/workflow/template_rendering.py | 3 +- api/core/workflow/workflow_entry.py | 28 +- api/enterprise/telemetry/draft_trace.py | 3 +- ...rameters_cache_when_sync_draft_workflow.py | 5 +- ...oin_when_app_published_workflow_updated.py | 2 +- api/extensions/ext_sentry.py | 3 +- ..._api_workflow_node_execution_repository.py | 2 +- .../logstore_api_workflow_run_repository.py | 2 +- .../logstore_workflow_execution_repository.py | 4 +- ...tore_workflow_node_execution_repository.py | 8 +- api/extensions/otel/parser/base.py | 10 +- api/extensions/otel/parser/llm.py | 4 +- api/extensions/otel/parser/retrieval.py | 6 +- api/extensions/otel/parser/tool.py | 8 +- api/factories/file_factory/builders.py | 3 +- api/factories/file_factory/message_files.py | 3 +- api/factories/file_factory/storage_keys.py | 2 +- api/factories/variable_factory.py | 11 +- api/fields/conversation_fields.py | 3 +- api/fields/member_fields.py | 3 +- api/fields/message_fields.py | 2 +- api/fields/raws.py | 1 - api/fields/workflow_fields.py | 2 +- api/graphon/README.md | 135 -- api/graphon/__init__.py | 0 api/graphon/entities/__init__.py | 11 - api/graphon/entities/base_node_data.py | 178 --- api/graphon/entities/exc.py | 10 - api/graphon/entities/graph_config.py | 23 - api/graphon/entities/graph_init_params.py | 24 - api/graphon/entities/pause_reason.py | 42 - api/graphon/entities/workflow_execution.py | 71 - .../entities/workflow_node_execution.py | 141 -- api/graphon/entities/workflow_start_reason.py | 8 - api/graphon/enums.py | 262 ---- api/graphon/errors.py | 16 - api/graphon/file/__init__.py | 22 - api/graphon/file/constants.py | 48 - api/graphon/file/enums.py | 57 - api/graphon/file/file_factory.py | 39 - api/graphon/file/file_manager.py | 129 -- api/graphon/file/helpers.py | 48 - api/graphon/file/models.py | 215 --- api/graphon/file/protocols.py | 56 - api/graphon/file/runtime.py | 71 - api/graphon/file/tool_file_parser.py | 9 - api/graphon/graph/__init__.py | 11 - api/graphon/graph/edge.py | 15 - api/graphon/graph/graph.py | 438 ------ api/graphon/graph/graph_template.py | 20 - api/graphon/graph/validation.py | 125 -- api/graphon/graph_engine/__init__.py | 4 - api/graphon/graph_engine/_engine_utils.py | 15 - .../graph_engine/command_channels/README.md | 33 - .../graph_engine/command_channels/__init__.py | 6 - .../command_channels/in_memory_channel.py | 53 - .../command_channels/redis_channel.py | 153 -- .../command_processing/__init__.py | 16 - .../command_processing/command_handlers.py | 56 - .../command_processing/command_processor.py | 79 - api/graphon/graph_engine/config.py | 16 - api/graphon/graph_engine/domain/__init__.py | 14 - .../graph_engine/domain/graph_execution.py | 242 --- .../graph_engine/domain/node_execution.py | 45 - api/graphon/graph_engine/entities/__init__.py | 0 api/graphon/graph_engine/entities/commands.py | 56 - api/graphon/graph_engine/error_handler.py | 213 --- .../graph_engine/event_management/__init__.py | 14 - .../event_management/event_handlers.py | 367 ----- .../event_management/event_manager.py | 186 --- api/graphon/graph_engine/graph_engine.py | 377 ----- .../graph_engine/graph_state_manager.py | 290 ---- .../graph_engine/graph_traversal/__init__.py | 14 - .../graph_traversal/edge_processor.py | 201 --- .../graph_traversal/skip_propagator.py | 96 -- api/graphon/graph_engine/layers/README.md | 55 - api/graphon/graph_engine/layers/__init__.py | 16 - api/graphon/graph_engine/layers/base.py | 128 -- .../graph_engine/layers/debug_logging.py | 247 --- .../graph_engine/layers/execution_limits.py | 150 -- api/graphon/graph_engine/manager.py | 79 - .../graph_engine/orchestration/__init__.py | 14 - .../graph_engine/orchestration/dispatcher.py | 143 -- .../orchestration/execution_coordinator.py | 104 -- .../graph_engine/protocols/command_channel.py | 41 - .../graph_engine/ready_queue/__init__.py | 12 - .../graph_engine/ready_queue/factory.py | 37 - .../graph_engine/ready_queue/in_memory.py | 140 -- .../graph_engine/ready_queue/protocol.py | 104 -- .../response_coordinator/__init__.py | 10 - .../response_coordinator/coordinator.py | 697 --------- .../graph_engine/response_coordinator/path.py | 35 - .../response_coordinator/session.py | 66 - api/graphon/graph_engine/worker.py | 204 --- .../worker_management/__init__.py | 12 - .../worker_management/worker_pool.py | 277 ---- api/graphon/graph_events/__init__.py | 84 - api/graphon/graph_events/agent.py | 17 - api/graphon/graph_events/base.py | 31 - api/graphon/graph_events/graph.py | 57 - api/graphon/graph_events/human_input.py | 0 api/graphon/graph_events/iteration.py | 40 - api/graphon/graph_events/loop.py | 40 - api/graphon/graph_events/node.py | 106 -- api/graphon/model_runtime/README.md | 51 - api/graphon/model_runtime/README_CN.md | 64 - api/graphon/model_runtime/__init__.py | 0 .../model_runtime/callbacks/__init__.py | 0 .../model_runtime/callbacks/base_callback.py | 159 -- .../callbacks/logging_callback.py | 180 --- .../model_runtime/entities/__init__.py | 43 - .../model_runtime/entities/common_entities.py | 16 - .../model_runtime/entities/defaults.py | 130 -- .../model_runtime/entities/llm_entities.py | 219 --- .../entities/message_entities.py | 279 ---- .../model_runtime/entities/model_entities.py | 242 --- .../entities/provider_entities.py | 179 --- .../model_runtime/entities/rerank_entities.py | 27 - .../entities/text_embedding_entities.py | 47 - api/graphon/model_runtime/errors/__init__.py | 0 api/graphon/model_runtime/errors/invoke.py | 41 - api/graphon/model_runtime/errors/validate.py | 6 - api/graphon/model_runtime/memory/__init__.py | 3 - .../memory/prompt_message_memory.py | 18 - .../model_providers/__base/__init__.py | 0 .../model_providers/__base/ai_model.py | 247 --- .../__base/large_language_model.py | 638 -------- .../__base/moderation_model.py | 33 - .../model_providers/__base/rerank_model.py | 76 - .../__base/speech2text_model.py | 31 - .../__base/text_embedding_model.py | 98 -- .../__base/tokenizers/gpt2_tokenizer.py | 53 - .../model_providers/__base/tts_model.py | 58 - .../model_runtime/model_providers/__init__.py | 0 .../model_providers/_position.yaml | 43 - .../model_providers/model_provider_factory.py | 173 --- api/graphon/model_runtime/runtime.py | 159 -- .../schema_validators/__init__.py | 0 .../schema_validators/common_validator.py | 92 -- .../model_credential_schema_validator.py | 27 - .../provider_credential_schema_validator.py | 19 - api/graphon/model_runtime/utils/__init__.py | 0 api/graphon/model_runtime/utils/encoders.py | 218 --- api/graphon/node_events/__init__.py | 48 - api/graphon/node_events/agent.py | 18 - api/graphon/node_events/base.py | 40 - api/graphon/node_events/iteration.py | 36 - api/graphon/node_events/loop.py | 36 - api/graphon/node_events/node.py | 72 - api/graphon/nodes/__init__.py | 3 - api/graphon/nodes/answer/__init__.py | 0 api/graphon/nodes/answer/answer_node.py | 70 - api/graphon/nodes/answer/entities.py | 67 - api/graphon/nodes/base/__init__.py | 10 - api/graphon/nodes/base/entities.py | 87 -- api/graphon/nodes/base/node.py | 787 ---------- api/graphon/nodes/base/template.py | 150 -- .../nodes/base/usage_tracking_mixin.py | 28 - .../nodes/base/variable_template_parser.py | 130 -- api/graphon/nodes/code/__init__.py | 3 - api/graphon/nodes/code/code_node.py | 493 ------ api/graphon/nodes/code/entities.py | 57 - api/graphon/nodes/code/exc.py | 16 - api/graphon/nodes/code/limits.py | 13 - .../nodes/document_extractor/__init__.py | 4 - .../nodes/document_extractor/entities.py | 16 - api/graphon/nodes/document_extractor/exc.py | 14 - api/graphon/nodes/document_extractor/node.py | 782 ---------- api/graphon/nodes/end/__init__.py | 0 api/graphon/nodes/end/end_node.py | 47 - api/graphon/nodes/end/entities.py | 27 - api/graphon/nodes/http_request/__init__.py | 22 - api/graphon/nodes/http_request/config.py | 33 - api/graphon/nodes/http_request/entities.py | 241 --- api/graphon/nodes/http_request/exc.py | 26 - api/graphon/nodes/http_request/executor.py | 488 ------ api/graphon/nodes/http_request/node.py | 261 ---- api/graphon/nodes/human_input/__init__.py | 3 - api/graphon/nodes/human_input/entities.py | 208 --- api/graphon/nodes/human_input/enums.py | 55 - .../nodes/human_input/human_input_node.py | 299 ---- api/graphon/nodes/if_else/__init__.py | 3 - api/graphon/nodes/if_else/entities.py | 29 - api/graphon/nodes/if_else/if_else_node.py | 124 -- api/graphon/nodes/iteration/__init__.py | 5 - api/graphon/nodes/iteration/entities.py | 67 - api/graphon/nodes/iteration/exc.py | 26 - api/graphon/nodes/iteration/iteration_node.py | 686 --------- .../nodes/iteration/iteration_start_node.py | 22 - api/graphon/nodes/list_operator/__init__.py | 3 - api/graphon/nodes/list_operator/entities.py | 71 - api/graphon/nodes/list_operator/exc.py | 16 - api/graphon/nodes/list_operator/node.py | 345 ----- api/graphon/nodes/llm/__init__.py | 17 - api/graphon/nodes/llm/entities.py | 100 -- api/graphon/nodes/llm/exc.py | 45 - api/graphon/nodes/llm/file_saver.py | 139 -- api/graphon/nodes/llm/llm_utils.py | 545 ------- api/graphon/nodes/llm/node.py | 1372 ----------------- api/graphon/nodes/llm/protocols.py | 21 - api/graphon/nodes/llm/runtime_protocols.py | 77 - api/graphon/nodes/loop/__init__.py | 6 - api/graphon/nodes/loop/entities.py | 107 -- api/graphon/nodes/loop/loop_end_node.py | 22 - api/graphon/nodes/loop/loop_node.py | 428 ----- api/graphon/nodes/loop/loop_start_node.py | 22 - .../nodes/parameter_extractor/__init__.py | 3 - .../nodes/parameter_extractor/entities.py | 131 -- api/graphon/nodes/parameter_extractor/exc.py | 75 - .../parameter_extractor_node.py | 846 ---------- .../nodes/parameter_extractor/prompts.py | 184 --- api/graphon/nodes/protocols.py | 46 - .../nodes/question_classifier/__init__.py | 4 - .../nodes/question_classifier/entities.py | 30 - api/graphon/nodes/question_classifier/exc.py | 6 - .../question_classifier_node.py | 395 ----- .../question_classifier/template_prompts.py | 76 - api/graphon/nodes/runtime.py | 106 -- api/graphon/nodes/start/__init__.py | 3 - api/graphon/nodes/start/entities.py | 16 - api/graphon/nodes/start/start_node.py | 57 - .../nodes/template_transform/__init__.py | 3 - .../nodes/template_transform/entities.py | 13 - .../template_transform_node.py | 119 -- api/graphon/nodes/tool/__init__.py | 3 - api/graphon/nodes/tool/entities.py | 101 -- api/graphon/nodes/tool/exc.py | 28 - api/graphon/nodes/tool/tool_node.py | 432 ------ api/graphon/nodes/tool_runtime_entities.py | 105 -- .../nodes/variable_aggregator/__init__.py | 3 - .../nodes/variable_aggregator/entities.py | 35 - .../variable_aggregator_node.py | 40 - .../nodes/variable_assigner/__init__.py | 0 .../variable_assigner/common/__init__.py | 0 .../nodes/variable_assigner/common/exc.py | 4 - .../nodes/variable_assigner/common/helpers.py | 55 - .../nodes/variable_assigner/v1/__init__.py | 3 - .../nodes/variable_assigner/v1/node.py | 106 -- .../nodes/variable_assigner/v1/node_data.py | 18 - .../nodes/variable_assigner/v2/__init__.py | 3 - .../nodes/variable_assigner/v2/entities.py | 28 - .../nodes/variable_assigner/v2/enums.py | 20 - api/graphon/nodes/variable_assigner/v2/exc.py | 36 - .../nodes/variable_assigner/v2/helpers.py | 98 -- .../nodes/variable_assigner/v2/node.py | 257 --- api/graphon/prompt_entities.py | 47 - api/graphon/runtime/__init__.py | 22 - api/graphon/runtime/graph_runtime_state.py | 704 --------- .../runtime/graph_runtime_state_protocol.py | 79 - api/graphon/runtime/read_only_wrappers.py | 82 - api/graphon/runtime/variable_pool.py | 279 ---- api/graphon/template_rendering.py | 18 - api/graphon/utils/__init__.py | 0 api/graphon/utils/condition/__init__.py | 0 api/graphon/utils/condition/entities.py | 49 - api/graphon/utils/condition/processor.py | 504 ------ api/graphon/utils/json_in_md_parser.py | 58 - api/graphon/variable_loader.py | 75 - api/graphon/variables/__init__.py | 82 - api/graphon/variables/consts.py | 7 - api/graphon/variables/exc.py | 2 - api/graphon/variables/factory.py | 202 --- api/graphon/variables/input_entities.py | 62 - api/graphon/variables/segment_group.py | 22 - api/graphon/variables/segments.py | 253 --- api/graphon/variables/types.py | 273 ---- api/graphon/variables/utils.py | 33 - api/graphon/variables/variables.py | 172 --- api/graphon/workflow_type_encoder.py | 49 - api/libs/helper.py | 4 +- api/models/human_input.py | 2 +- api/models/model.py | 6 +- api/models/utils/file_input_compat.py | 3 +- api/models/workflow.py | 31 +- api/pyproject.toml | 5 +- api/pyrefly-local-excludes.txt | 27 - .../api_workflow_run_repository.py | 4 +- ..._api_workflow_node_execution_repository.py | 2 +- .../sqlalchemy_api_workflow_run_repository.py | 6 +- ...hemy_execution_extra_content_repository.py | 6 +- api/services/app_dsl_service.py | 12 +- api/services/app_service.py | 4 +- api/services/app_task_service.py | 3 +- api/services/audio_service.py | 2 +- .../clear_free_plan_tenant_expired_logs.py | 2 +- api/services/conversation_service.py | 2 +- api/services/conversation_variable_updater.py | 2 +- api/services/dataset_service.py | 6 +- api/services/datasource_provider_service.py | 2 +- .../entities/model_provider_entities.py | 18 +- api/services/external_knowledge_service.py | 2 +- api/services/file_service.py | 2 +- api/services/hit_testing_service.py | 3 +- .../human_input_delivery_test_service.py | 2 +- api/services/human_input_service.py | 12 +- api/services/message_service.py | 2 +- api/services/model_load_balancing_service.py | 12 +- api/services/model_provider_service.py | 3 +- api/services/rag_pipeline/rag_pipeline.py | 22 +- .../rag_pipeline/rag_pipeline_dsl_service.py | 12 +- .../archive_paid_plan_workflow_run.py | 2 +- api/services/summary_index_service.py | 4 +- .../tools/api_tools_manage_service.py | 2 +- .../tools/workflow_tools_manage_service.py | 2 +- api/services/trigger/schedule_service.py | 2 +- api/services/trigger/trigger_service.py | 2 +- api/services/trigger/webhook_service.py | 6 +- api/services/variable_truncator.py | 5 +- api/services/vector_service.py | 3 +- api/services/workflow/workflow_converter.py | 10 +- api/services/workflow_app_service.py | 2 +- .../workflow_draft_variable_service.py | 26 +- .../workflow_event_snapshot_service.py | 8 +- api/services/workflow_service.py | 50 +- .../app_generate/workflow_execute_task.py | 2 +- api/tasks/async_workflow_tasks.py | 2 +- .../batch_create_segment_to_index_task.py | 2 +- api/tasks/human_input_timeout_tasks.py | 4 +- api/tasks/mail_human_input_delivery_task.py | 2 +- api/tasks/trigger_processing_tasks.py | 2 +- api/tasks/workflow_execution_tasks.py | 4 +- api/tasks/workflow_node_execution_tasks.py | 6 +- .../test_datasource_manager_integration.py | 3 +- .../test_datasource_node_integration.py | 5 +- .../factories/test_storage_key_loader.py | 2 +- .../model_runtime/__mock/plugin_model.py | 6 +- .../test_workflow_draft_variable_service.py | 8 +- .../test_remove_app_and_related_data_task.py | 4 +- .../workflow/nodes/__mock/model.py | 3 +- .../workflow/nodes/test_code.py | 12 +- .../workflow/nodes/test_http.py | 13 +- .../workflow/nodes/test_llm.py | 11 +- .../nodes/test_parameter_extractor.py | 11 +- .../workflow/nodes/test_template_transform.py | 7 +- .../workflow/nodes/test_tool.py | 11 +- ...test_chat_conversation_status_count_api.py | 2 +- .../app/test_workflow_draft_variable.py | 2 +- .../layers/test_pause_state_persist_layer.py | 19 +- .../test_human_input_form_repository_impl.py | 2 +- .../test_human_input_resume_node_execution.py | 24 +- .../factories/test_storage_key_loader.py | 2 +- .../helpers/execution_extra_content.py | 1 + ..._api_workflow_node_execution_repository.py | 2 +- ..._sqlalchemy_api_workflow_run_repository.py | 8 +- ...hemy_execution_extra_content_repository.py | 4 +- .../test_workflow_run_repository.py | 4 +- .../services/test_agent_service.py | 1 + .../test_conversation_variable_updater.py | 2 +- .../services/test_dataset_service.py | 2 +- .../test_dataset_service_update_dataset.py | 2 +- .../test_delete_archived_workflow_run.py | 2 +- .../test_human_input_delivery_test.py | 4 +- .../test_human_input_delivery_test_service.py | 2 +- .../services/test_messages_clean_service.py | 2 +- .../services/test_model_provider_service.py | 8 +- .../services/test_workflow_app_service.py | 2 +- .../test_workflow_draft_variable_service.py | 2 +- .../workflow/test_workflow_converter.py | 6 +- ...kflow_node_execution_service_repository.py | 2 +- .../test_mail_human_input_delivery_task.py | 6 +- .../test_remove_app_and_related_data_task.py | 4 +- .../test_workflow_pause_integration.py | 4 +- .../trigger/test_trigger_e2e.py | 2 +- .../controllers/console/app/test_audio.py | 2 +- .../controllers/console/app/test_workflow.py | 3 +- .../app/test_workflow_pause_details_api.py | 8 +- .../app/workflow_draft_variables_test.py | 8 +- .../rag_pipeline/test_datasource_auth.py | 2 +- .../test_rag_pipeline_draft_variable.py | 2 +- .../console/datasets/test_hit_testing_base.py | 2 +- .../controllers/console/explore/test_audio.py | 2 +- .../console/explore/test_message.py | 2 +- .../controllers/console/explore/test_trial.py | 2 +- .../workspace/test_load_balancing_config.py | 3 +- .../console/workspace/test_model_providers.py | 2 +- .../console/workspace/test_models.py | 4 +- .../controllers/service_api/app/test_audio.py | 2 +- .../service_api/app/test_completion.py | 2 +- .../service_api/app/test_workflow.py | 2 +- .../service_api/app/test_workflow_fields.py | 3 +- .../unit_tests/controllers/web/test_audio.py | 2 +- .../controllers/web/test_completion.py | 2 +- .../core/agent/test_cot_agent_runner.py | 2 +- .../core/agent/test_cot_chat_agent_runner.py | 2 +- .../agent/test_cot_completion_agent_runner.py | 4 +- .../core/agent/test_fc_agent_runner.py | 10 +- .../test_model_config_converter.py | 4 +- .../test_variables_manager.py | 2 +- .../features/file_upload/test_manager.py | 5 +- .../core/app/app_config/test_entities.py | 2 +- .../test_app_runner_conversation_variables.py | 2 +- .../test_generate_response_converter.py | 3 +- .../test_generate_task_pipeline.py | 4 +- .../test_generate_task_pipeline_core.py | 4 +- .../test_agent_chat_app_generator.py | 2 +- .../agent_chat/test_agent_chat_app_runner.py | 4 +- .../chat/test_app_generator_and_runner.py | 2 +- .../chat/test_base_app_runner_multimodal.py | 4 +- .../test_graph_runtime_state_support.py | 3 +- .../test_workflow_response_converter.py | 3 +- ...workflow_response_converter_human_input.py | 5 +- ..._workflow_response_converter_resumption.py | 5 +- ..._workflow_response_converter_truncation.py | 4 +- .../app/apps/completion/test_app_runner.py | 2 +- ...est_completion_completion_app_generator.py | 2 +- ...st_pipeline_generate_response_converter.py | 3 +- .../pipeline/test_pipeline_queue_manager.py | 2 +- .../app/apps/pipeline/test_pipeline_runner.py | 2 +- .../core/app/apps/test_base_app_generator.py | 5 +- .../core/app/apps/test_base_app_runner.py | 18 +- .../core/app/apps/test_pause_resume.py | 15 +- .../app/apps/test_workflow_app_runner_core.py | 34 +- .../test_workflow_app_runner_notifications.py | 4 +- .../test_workflow_app_runner_single_node.py | 4 +- .../app/apps/test_workflow_pause_events.py | 10 +- .../test_generate_response_converter.py | 3 +- .../workflow/test_generate_task_pipeline.py | 5 +- .../test_generate_task_pipeline_core.py | 4 +- .../core/app/entities/test_task_entities.py | 3 +- ...est_conversation_variable_persist_layer.py | 15 +- .../layers/test_pause_state_persist_layer.py | 22 +- .../core/app/layers/test_suspend_layer.py | 3 +- .../core/app/layers/test_timeslice_layer.py | 3 +- .../app/layers/test_trigger_post_layer.py | 5 +- .../test_based_generate_task_pipeline.py | 2 +- ...st_easy_ui_based_generate_task_pipeline.py | 4 +- ...sy_ui_based_generate_task_pipeline_core.py | 6 +- .../test_easy_ui_message_end_files.py | 2 +- .../app/test_easy_ui_model_config_manager.py | 3 +- .../app/workflow/layers/test_persistence.py | 4 +- .../core/app/workflow/test_file_runtime.py | 2 +- .../core/app/workflow/test_node_factory.py | 2 +- .../test_observability_layer_extra.py | 3 +- .../app/workflow/test_persistence_layer.py | 14 +- .../base/test_app_generator_tts_publisher.py | 6 +- .../datasource/test_datasource_manager.py | 7 +- .../utils/test_message_transformer.py | 3 +- .../test_entities_execution_extra_content.py | 5 +- .../entities/test_entities_model_entities.py | 6 +- .../test_entities_provider_configuration.py | 22 +- .../test_entities_provider_entities.py | 2 +- .../output_parser/test_structured_output.py | 28 +- .../core/llm_generator/test_llm_generator.py | 4 +- .../core/mcp/server/test_streamable_http.py | 2 +- .../core/memory/test_token_buffer_memory.py | 4 +- .../test_model_provider_factory.py | 1 - .../ops/aliyun_trace/test_aliyun_trace.py | 4 +- .../aliyun_trace/test_aliyun_trace_utils.py | 4 +- .../ops/langfuse_trace/test_langfuse_trace.py | 2 +- .../langsmith_trace/test_langsmith_trace.py | 2 +- .../ops/mlflow_trace/test_mlflow_trace.py | 2 +- .../core/ops/opik_trace/test_opik_trace.py | 2 +- .../ops/tencent_trace/test_span_builder.py | 4 +- .../ops/tencent_trace/test_tencent_trace.py | 4 +- .../core/ops/test_arize_phoenix_trace.py | 2 +- .../core/ops/weave_trace/test_weave_trace.py | 2 +- .../plugin/test_backwards_invocation_model.py | 3 +- .../core/plugin/test_model_runtime_adapter.py | 6 +- .../core/plugin/test_plugin_entities.py | 12 +- .../core/plugin/test_plugin_runtime.py | 16 +- .../core/plugin/utils/test_chunk_merger.py | 3 +- .../prompt/test_advanced_prompt_transform.py | 14 +- .../test_agent_history_prompt_transform.py | 13 +- .../core/prompt/test_prompt_message.py | 5 +- .../core/prompt/test_prompt_transform.py | 2 +- .../prompt/test_simple_prompt_transform.py | 12 +- .../test_data_post_processor.py | 5 +- .../rag/embedding/test_cached_embedding.py | 4 +- .../rag/embedding/test_embedding_service.py | 8 +- .../test_paragraph_index_processor.py | 6 +- .../core/rag/indexing/test_indexing_runner.py | 2 +- .../core/rag/rerank/test_reranker.py | 2 +- .../rag/retrieval/test_dataset_retrieval.py | 4 +- ...test_multi_dataset_function_call_router.py | 3 +- .../test_multi_dataset_react_route.py | 5 +- ...st_celery_workflow_execution_repository.py | 3 +- ...lery_workflow_node_execution_repository.py | 6 +- .../test_human_input_form_repository_impl.py | 10 +- .../test_human_input_repository.py | 4 +- ...qlalchemy_workflow_execution_repository.py | 3 +- ...hemy_workflow_node_execution_repository.py | 12 +- ...rkflow_node_execution_conflict_handling.py | 10 +- ...test_workflow_node_execution_truncation.py | 10 +- api/tests/unit_tests/core/test_file.py | 1 + .../unit_tests/core/test_model_manager.py | 2 +- .../core/test_provider_configuration.py | 18 +- .../unit_tests/core/test_provider_manager.py | 4 +- .../core/tools/test_builtin_tool_base.py | 2 +- .../core/tools/test_builtin_tools_extra.py | 4 +- .../core/tools/test_tool_file_manager.py | 2 +- .../utils/test_model_invocation_utils.py | 4 +- .../utils/test_workflow_configuration_sync.py | 2 +- .../tools/workflow_as_tool/test_provider.py | 2 +- .../core/tools/workflow_as_tool/test_tool.py | 2 +- .../debug/test_debug_event_selectors.py | 2 +- .../unit_tests/core/variables/test_segment.py | 10 +- .../core/variables/test_segment_type.py | 1 - .../variables/test_segment_type_validation.py | 4 +- .../core/variables/test_variables.py | 3 +- .../entities/test_graph_runtime_state.py | 307 ---- .../workflow/entities/test_pause_reason.py | 88 -- .../core/workflow/entities/test_template.py | 87 -- .../workflow/entities/test_variable_pool.py | 136 -- .../entities/test_workflow_node_execution.py | 225 --- .../core/workflow/graph/test_graph.py | 281 ---- .../core/workflow/graph/test_graph_builder.py | 59 - .../graph/test_graph_skip_validation.py | 118 -- .../workflow/graph/test_graph_validation.py | 219 --- .../core/workflow/graph_engine/README.md | 453 +----- .../command_channels/test_redis_channel.py | 315 ---- .../event_management/test_event_handlers.py | 119 -- .../event_management/test_event_manager.py | 39 - .../graph_engine/graph_traversal/__init__.py | 1 - .../graph_traversal/test_skip_propagator.py | 307 ---- .../graph_engine/human_input_test_utils.py | 131 -- .../workflow/graph_engine/layers/conftest.py | 10 +- .../layers/test_layer_initialization.py | 57 - .../graph_engine/layers/test_llm_quota.py | 11 +- .../graph_engine/layers/test_observability.py | 8 +- .../orchestration/test_dispatcher.py | 189 --- .../graph_engine/test_answer_end_with_text.py | 37 - .../test_answer_order_workflow.py | 28 - ...est_array_iteration_formatting_workflow.py | 24 - .../graph_engine/test_auto_mock_system.py | 392 ----- .../graph_engine/test_basic_chatflow.py | 41 - .../graph_engine/test_command_system.py | 266 ---- .../test_complex_branch_workflow.py | 134 -- ...ditional_streaming_vs_template_workflow.py | 220 --- .../graph_engine/test_database_utils.py | 46 - .../test_dispatcher_pause_drain.py | 72 - .../test_end_node_without_value_type.py | 60 - .../test_execution_coordinator.py | 62 - .../graph_engine/test_graph_engine.py | 770 --------- .../test_graph_execution_serialization.py | 196 --- .../graph_engine/test_graph_state_snapshot.py | 190 --- .../test_human_input_pause_multi_branch.py | 389 ----- .../test_human_input_pause_single_branch.py | 346 ----- .../graph_engine/test_if_else_streaming.py | 324 ---- .../test_iteration_flatten_output.py | 126 -- .../graph_engine/test_loop_contains_answer.py | 88 -- .../workflow/graph_engine/test_loop_node.py | 41 - .../graph_engine/test_loop_with_tool.py | 72 - .../graph_engine/test_mock_example.py | 281 ---- .../graph_engine/test_mock_factory.py | 3 +- .../test_mock_iteration_simple.py | 199 --- .../workflow/graph_engine/test_mock_nodes.py | 9 +- .../test_mock_nodes_template_code.py | 670 -------- .../workflow/graph_engine/test_mock_simple.py | 231 --- .../test_parallel_human_input_join_resume.py | 22 +- ...rallel_human_input_pause_missing_finish.py | 336 ---- .../test_parallel_streaming_workflow.py | 286 ---- .../test_pause_deferred_ready_nodes.py | 311 ---- .../graph_engine/test_pause_resume_state.py | 219 --- .../test_redis_stop_integration.py | 268 ---- .../graph_engine/test_response_session.py | 55 - .../test_streaming_conversation_variables.py | 79 - .../graph_engine/test_table_runner.py | 13 +- ..._update_conversation_variable_iteration.py | 41 - .../graph_engine/test_variable_aggregator.py | 58 - .../test_variable_update_events.py | 129 -- .../core/workflow/graph_engine/test_worker.py | 148 -- .../nodes/agent/test_message_transformer.py | 3 +- .../nodes/agent/test_runtime_support.py | 3 +- .../core/workflow/nodes/answer/test_answer.py | 9 +- .../workflow/nodes/base/test_base_node.py | 4 +- .../test_get_node_type_classes_mapping.py | 3 +- .../workflow/nodes/code/code_node_spec.py | 3 +- .../core/workflow/nodes/code/entities_spec.py | 352 ----- .../nodes/datasource/test_datasource_node.py | 5 +- .../nodes/http_request/test_config.py | 33 - .../nodes/http_request/test_entities.py | 233 --- .../test_http_request_executor.py | 8 +- .../http_request/test_http_request_node.py | 10 +- .../human_input/test_email_delivery_config.py | 3 +- .../nodes/human_input/test_entities.py | 131 +- .../test_human_input_form_filled_event.py | 9 +- .../workflow/nodes/iteration/entities_spec.py | 339 ---- .../nodes/iteration/iteration_node_spec.py | 438 ------ .../test_iteration_abort_propagation.py | 201 --- .../test_iteration_child_engine_errors.py | 4 +- .../test_parallel_iteration_duration.py | 67 - .../test_knowledge_index_node.py | 6 +- .../test_knowledge_retrieval_node.py | 8 +- .../workflow/nodes/list_operator/node_spec.py | 4 +- .../workflow/nodes/llm/test_file_saver.py | 170 -- .../core/workflow/nodes/llm/test_llm_utils.py | 7 +- .../core/workflow/nodes/llm/test_node.py | 26 +- .../core/workflow/nodes/llm/test_scenarios.py | 25 - .../parameter_extractor/test_entities.py | 27 - .../test_parameter_extractor_node.py | 4 +- .../nodes/template_transform/entities_spec.py | 225 --- .../template_transform_node_spec.py | 4 +- .../test_template_transform_node.py | 4 +- .../core/workflow/nodes/test_base_node.py | 8 +- .../nodes/test_document_extractor_node.py | 4 +- .../core/workflow/nodes/test_if_else.py | 10 +- .../core/workflow/nodes/test_list_operator.py | 4 +- .../core/workflow/nodes/test_loop_node.py | 150 -- .../nodes/test_question_classifier_node.py | 126 -- .../nodes/test_start_node_json_object.py | 8 +- .../workflow/nodes/tool/test_tool_node.py | 4 +- .../nodes/tool/test_tool_node_runtime.py | 10 +- .../trigger_plugin/test_trigger_event_node.py | 7 +- .../v1/test_variable_assigner_v1.py | 312 ---- .../nodes/variable_assigner/v2/__init__.py | 1 - .../variable_assigner/v2/test_helpers.py | 22 - .../v2/test_variable_assigner_v2.py | 430 ------ .../workflow/nodes/webhook/test_exceptions.py | 2 +- .../webhook/test_webhook_file_conversion.py | 8 +- .../nodes/webhook/test_webhook_node.py | 11 +- .../unit_tests/core/workflow/test_enums.py | 41 - .../core/workflow/test_human_input_compat.py | 2 +- .../core/workflow/test_node_factory.py | 8 +- .../core/workflow/test_node_runtime.py | 8 +- .../core/workflow/test_system_variable.py | 6 +- .../core/workflow/test_variable_pool.py | 18 +- .../core/workflow/test_workflow_entry.py | 13 +- .../workflow/test_workflow_entry_helpers.py | 15 +- .../test_workflow_entry_redis_channel.py | 5 +- .../core/workflow/utils/test_condition.py | 52 - .../utils/test_variable_template_parser.py | 48 - .../factories/test_build_from_mapping.py | 2 +- .../factories/test_variable_factory.py | 10 +- .../unit_tests/fields/test_file_fields.py | 2 +- .../graphon/file/test_file_factory.py | 18 - .../graphon/file/test_file_manager.py | 133 -- .../unit_tests/graphon/file/test_models.py | 54 - .../graphon/model_runtime/__base/__init__.py | 0 .../__base/test_increase_tool_call.py | 114 -- ...large_language_model_non_stream_parsing.py | 126 -- .../graphon/model_runtime/__init__.py | 0 .../callbacks/test_base_callback.py | 964 ------------ .../callbacks/test_logging_callback.py | 700 --------- .../entities/test_common_entities.py | 35 - .../entities/test_llm_entities.py | 148 -- .../entities/test_message_entities.py | 210 --- .../entities/test_model_entities.py | 220 --- .../model_runtime/errors/test_invoke.py | 63 - .../model_providers/__base/test_ai_model.py | 254 --- .../__base/test_large_language_model.py | 452 ------ .../__base/test_moderation_model.py | 56 - .../__base/test_rerank_model.py | 110 -- .../__base/test_runtime_user_forwarding.py | 170 -- .../__base/test_speech2text_model.py | 56 - .../__base/test_text_embedding_model.py | 146 -- .../model_providers/__base/test_tts_model.py | 83 - .../__base/tokenizers/test_gpt2_tokenizer.py | 96 -- .../test_common_validator.py | 201 --- .../test_model_credential_schema_validator.py | 233 --- ...st_provider_credential_schema_validator.py | 72 - .../model_runtime/utils/test_encoders.py | 231 --- .../graphon/node_events/test_base.py | 19 - .../graphon/utils/test_json_in_md_parser.py | 75 - .../unit_tests/libs/_human_input/support.py | 1 + .../libs/_human_input/test_form_service.py | 2 +- .../libs/_human_input/test_models.py | 2 +- .../models/test_conversation_variable.py | 3 +- api/tests/unit_tests/models/test_model.py | 2 +- api/tests/unit_tests/models/test_workflow.py | 8 +- .../unit_tests/models/test_workflow_models.py | 2 +- .../test_sqlalchemy_repository.py | 10 +- ...hemy_workflow_node_execution_repository.py | 4 +- .../services/dataset_service_test_helpers.py | 2 +- .../services/document_service_validation.py | 2 +- .../services/test_app_dsl_service.py | 2 +- .../test_datasource_provider_service.py | 2 +- .../services/test_human_input_service.py | 12 +- .../test_model_load_balancing_service.py | 6 +- ...est_model_provider_service_sanitization.py | 4 +- .../services/test_variable_truncator.py | 5 +- .../test_workflow_run_service_pause.py | 2 +- .../services/test_workflow_service.py | 2 +- .../workflow/test_draft_var_loader_simple.py | 7 +- .../test_workflow_draft_variable_service.py | 9 +- .../test_workflow_event_snapshot_service.py | 6 +- .../test_workflow_human_input_delivery.py | 6 +- .../workflow/test_workflow_service.py | 2 +- .../tasks/test_human_input_timeout_tasks.py | 2 +- api/tests/unit_tests/tools/test_mcp_tool.py | 2 +- .../test_structured_output_parser.py | 6 +- api/tests/workflow_test_utils.py | 7 +- api/uv.lock | 50 +- 883 files changed, 1779 insertions(+), 47377 deletions(-) delete mode 100644 api/graphon/README.md delete mode 100644 api/graphon/__init__.py delete mode 100644 api/graphon/entities/__init__.py delete mode 100644 api/graphon/entities/base_node_data.py delete mode 100644 api/graphon/entities/exc.py delete mode 100644 api/graphon/entities/graph_config.py delete mode 100644 api/graphon/entities/graph_init_params.py delete mode 100644 api/graphon/entities/pause_reason.py delete mode 100644 api/graphon/entities/workflow_execution.py delete mode 100644 api/graphon/entities/workflow_node_execution.py delete mode 100644 api/graphon/entities/workflow_start_reason.py delete mode 100644 api/graphon/enums.py delete mode 100644 api/graphon/errors.py delete mode 100644 api/graphon/file/__init__.py delete mode 100644 api/graphon/file/constants.py delete mode 100644 api/graphon/file/enums.py delete mode 100644 api/graphon/file/file_factory.py delete mode 100644 api/graphon/file/file_manager.py delete mode 100644 api/graphon/file/helpers.py delete mode 100644 api/graphon/file/models.py delete mode 100644 api/graphon/file/protocols.py delete mode 100644 api/graphon/file/runtime.py delete mode 100644 api/graphon/file/tool_file_parser.py delete mode 100644 api/graphon/graph/__init__.py delete mode 100644 api/graphon/graph/edge.py delete mode 100644 api/graphon/graph/graph.py delete mode 100644 api/graphon/graph/graph_template.py delete mode 100644 api/graphon/graph/validation.py delete mode 100644 api/graphon/graph_engine/__init__.py delete mode 100644 api/graphon/graph_engine/_engine_utils.py delete mode 100644 api/graphon/graph_engine/command_channels/README.md delete mode 100644 api/graphon/graph_engine/command_channels/__init__.py delete mode 100644 api/graphon/graph_engine/command_channels/in_memory_channel.py delete mode 100644 api/graphon/graph_engine/command_channels/redis_channel.py delete mode 100644 api/graphon/graph_engine/command_processing/__init__.py delete mode 100644 api/graphon/graph_engine/command_processing/command_handlers.py delete mode 100644 api/graphon/graph_engine/command_processing/command_processor.py delete mode 100644 api/graphon/graph_engine/config.py delete mode 100644 api/graphon/graph_engine/domain/__init__.py delete mode 100644 api/graphon/graph_engine/domain/graph_execution.py delete mode 100644 api/graphon/graph_engine/domain/node_execution.py delete mode 100644 api/graphon/graph_engine/entities/__init__.py delete mode 100644 api/graphon/graph_engine/entities/commands.py delete mode 100644 api/graphon/graph_engine/error_handler.py delete mode 100644 api/graphon/graph_engine/event_management/__init__.py delete mode 100644 api/graphon/graph_engine/event_management/event_handlers.py delete mode 100644 api/graphon/graph_engine/event_management/event_manager.py delete mode 100644 api/graphon/graph_engine/graph_engine.py delete mode 100644 api/graphon/graph_engine/graph_state_manager.py delete mode 100644 api/graphon/graph_engine/graph_traversal/__init__.py delete mode 100644 api/graphon/graph_engine/graph_traversal/edge_processor.py delete mode 100644 api/graphon/graph_engine/graph_traversal/skip_propagator.py delete mode 100644 api/graphon/graph_engine/layers/README.md delete mode 100644 api/graphon/graph_engine/layers/__init__.py delete mode 100644 api/graphon/graph_engine/layers/base.py delete mode 100644 api/graphon/graph_engine/layers/debug_logging.py delete mode 100644 api/graphon/graph_engine/layers/execution_limits.py delete mode 100644 api/graphon/graph_engine/manager.py delete mode 100644 api/graphon/graph_engine/orchestration/__init__.py delete mode 100644 api/graphon/graph_engine/orchestration/dispatcher.py delete mode 100644 api/graphon/graph_engine/orchestration/execution_coordinator.py delete mode 100644 api/graphon/graph_engine/protocols/command_channel.py delete mode 100644 api/graphon/graph_engine/ready_queue/__init__.py delete mode 100644 api/graphon/graph_engine/ready_queue/factory.py delete mode 100644 api/graphon/graph_engine/ready_queue/in_memory.py delete mode 100644 api/graphon/graph_engine/ready_queue/protocol.py delete mode 100644 api/graphon/graph_engine/response_coordinator/__init__.py delete mode 100644 api/graphon/graph_engine/response_coordinator/coordinator.py delete mode 100644 api/graphon/graph_engine/response_coordinator/path.py delete mode 100644 api/graphon/graph_engine/response_coordinator/session.py delete mode 100644 api/graphon/graph_engine/worker.py delete mode 100644 api/graphon/graph_engine/worker_management/__init__.py delete mode 100644 api/graphon/graph_engine/worker_management/worker_pool.py delete mode 100644 api/graphon/graph_events/__init__.py delete mode 100644 api/graphon/graph_events/agent.py delete mode 100644 api/graphon/graph_events/base.py delete mode 100644 api/graphon/graph_events/graph.py delete mode 100644 api/graphon/graph_events/human_input.py delete mode 100644 api/graphon/graph_events/iteration.py delete mode 100644 api/graphon/graph_events/loop.py delete mode 100644 api/graphon/graph_events/node.py delete mode 100644 api/graphon/model_runtime/README.md delete mode 100644 api/graphon/model_runtime/README_CN.md delete mode 100644 api/graphon/model_runtime/__init__.py delete mode 100644 api/graphon/model_runtime/callbacks/__init__.py delete mode 100644 api/graphon/model_runtime/callbacks/base_callback.py delete mode 100644 api/graphon/model_runtime/callbacks/logging_callback.py delete mode 100644 api/graphon/model_runtime/entities/__init__.py delete mode 100644 api/graphon/model_runtime/entities/common_entities.py delete mode 100644 api/graphon/model_runtime/entities/defaults.py delete mode 100644 api/graphon/model_runtime/entities/llm_entities.py delete mode 100644 api/graphon/model_runtime/entities/message_entities.py delete mode 100644 api/graphon/model_runtime/entities/model_entities.py delete mode 100644 api/graphon/model_runtime/entities/provider_entities.py delete mode 100644 api/graphon/model_runtime/entities/rerank_entities.py delete mode 100644 api/graphon/model_runtime/entities/text_embedding_entities.py delete mode 100644 api/graphon/model_runtime/errors/__init__.py delete mode 100644 api/graphon/model_runtime/errors/invoke.py delete mode 100644 api/graphon/model_runtime/errors/validate.py delete mode 100644 api/graphon/model_runtime/memory/__init__.py delete mode 100644 api/graphon/model_runtime/memory/prompt_message_memory.py delete mode 100644 api/graphon/model_runtime/model_providers/__base/__init__.py delete mode 100644 api/graphon/model_runtime/model_providers/__base/ai_model.py delete mode 100644 api/graphon/model_runtime/model_providers/__base/large_language_model.py delete mode 100644 api/graphon/model_runtime/model_providers/__base/moderation_model.py delete mode 100644 api/graphon/model_runtime/model_providers/__base/rerank_model.py delete mode 100644 api/graphon/model_runtime/model_providers/__base/speech2text_model.py delete mode 100644 api/graphon/model_runtime/model_providers/__base/text_embedding_model.py delete mode 100644 api/graphon/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py delete mode 100644 api/graphon/model_runtime/model_providers/__base/tts_model.py delete mode 100644 api/graphon/model_runtime/model_providers/__init__.py delete mode 100644 api/graphon/model_runtime/model_providers/_position.yaml delete mode 100644 api/graphon/model_runtime/model_providers/model_provider_factory.py delete mode 100644 api/graphon/model_runtime/runtime.py delete mode 100644 api/graphon/model_runtime/schema_validators/__init__.py delete mode 100644 api/graphon/model_runtime/schema_validators/common_validator.py delete mode 100644 api/graphon/model_runtime/schema_validators/model_credential_schema_validator.py delete mode 100644 api/graphon/model_runtime/schema_validators/provider_credential_schema_validator.py delete mode 100644 api/graphon/model_runtime/utils/__init__.py delete mode 100644 api/graphon/model_runtime/utils/encoders.py delete mode 100644 api/graphon/node_events/__init__.py delete mode 100644 api/graphon/node_events/agent.py delete mode 100644 api/graphon/node_events/base.py delete mode 100644 api/graphon/node_events/iteration.py delete mode 100644 api/graphon/node_events/loop.py delete mode 100644 api/graphon/node_events/node.py delete mode 100644 api/graphon/nodes/__init__.py delete mode 100644 api/graphon/nodes/answer/__init__.py delete mode 100644 api/graphon/nodes/answer/answer_node.py delete mode 100644 api/graphon/nodes/answer/entities.py delete mode 100644 api/graphon/nodes/base/__init__.py delete mode 100644 api/graphon/nodes/base/entities.py delete mode 100644 api/graphon/nodes/base/node.py delete mode 100644 api/graphon/nodes/base/template.py delete mode 100644 api/graphon/nodes/base/usage_tracking_mixin.py delete mode 100644 api/graphon/nodes/base/variable_template_parser.py delete mode 100644 api/graphon/nodes/code/__init__.py delete mode 100644 api/graphon/nodes/code/code_node.py delete mode 100644 api/graphon/nodes/code/entities.py delete mode 100644 api/graphon/nodes/code/exc.py delete mode 100644 api/graphon/nodes/code/limits.py delete mode 100644 api/graphon/nodes/document_extractor/__init__.py delete mode 100644 api/graphon/nodes/document_extractor/entities.py delete mode 100644 api/graphon/nodes/document_extractor/exc.py delete mode 100644 api/graphon/nodes/document_extractor/node.py delete mode 100644 api/graphon/nodes/end/__init__.py delete mode 100644 api/graphon/nodes/end/end_node.py delete mode 100644 api/graphon/nodes/end/entities.py delete mode 100644 api/graphon/nodes/http_request/__init__.py delete mode 100644 api/graphon/nodes/http_request/config.py delete mode 100644 api/graphon/nodes/http_request/entities.py delete mode 100644 api/graphon/nodes/http_request/exc.py delete mode 100644 api/graphon/nodes/http_request/executor.py delete mode 100644 api/graphon/nodes/http_request/node.py delete mode 100644 api/graphon/nodes/human_input/__init__.py delete mode 100644 api/graphon/nodes/human_input/entities.py delete mode 100644 api/graphon/nodes/human_input/enums.py delete mode 100644 api/graphon/nodes/human_input/human_input_node.py delete mode 100644 api/graphon/nodes/if_else/__init__.py delete mode 100644 api/graphon/nodes/if_else/entities.py delete mode 100644 api/graphon/nodes/if_else/if_else_node.py delete mode 100644 api/graphon/nodes/iteration/__init__.py delete mode 100644 api/graphon/nodes/iteration/entities.py delete mode 100644 api/graphon/nodes/iteration/exc.py delete mode 100644 api/graphon/nodes/iteration/iteration_node.py delete mode 100644 api/graphon/nodes/iteration/iteration_start_node.py delete mode 100644 api/graphon/nodes/list_operator/__init__.py delete mode 100644 api/graphon/nodes/list_operator/entities.py delete mode 100644 api/graphon/nodes/list_operator/exc.py delete mode 100644 api/graphon/nodes/list_operator/node.py delete mode 100644 api/graphon/nodes/llm/__init__.py delete mode 100644 api/graphon/nodes/llm/entities.py delete mode 100644 api/graphon/nodes/llm/exc.py delete mode 100644 api/graphon/nodes/llm/file_saver.py delete mode 100644 api/graphon/nodes/llm/llm_utils.py delete mode 100644 api/graphon/nodes/llm/node.py delete mode 100644 api/graphon/nodes/llm/protocols.py delete mode 100644 api/graphon/nodes/llm/runtime_protocols.py delete mode 100644 api/graphon/nodes/loop/__init__.py delete mode 100644 api/graphon/nodes/loop/entities.py delete mode 100644 api/graphon/nodes/loop/loop_end_node.py delete mode 100644 api/graphon/nodes/loop/loop_node.py delete mode 100644 api/graphon/nodes/loop/loop_start_node.py delete mode 100644 api/graphon/nodes/parameter_extractor/__init__.py delete mode 100644 api/graphon/nodes/parameter_extractor/entities.py delete mode 100644 api/graphon/nodes/parameter_extractor/exc.py delete mode 100644 api/graphon/nodes/parameter_extractor/parameter_extractor_node.py delete mode 100644 api/graphon/nodes/parameter_extractor/prompts.py delete mode 100644 api/graphon/nodes/protocols.py delete mode 100644 api/graphon/nodes/question_classifier/__init__.py delete mode 100644 api/graphon/nodes/question_classifier/entities.py delete mode 100644 api/graphon/nodes/question_classifier/exc.py delete mode 100644 api/graphon/nodes/question_classifier/question_classifier_node.py delete mode 100644 api/graphon/nodes/question_classifier/template_prompts.py delete mode 100644 api/graphon/nodes/runtime.py delete mode 100644 api/graphon/nodes/start/__init__.py delete mode 100644 api/graphon/nodes/start/entities.py delete mode 100644 api/graphon/nodes/start/start_node.py delete mode 100644 api/graphon/nodes/template_transform/__init__.py delete mode 100644 api/graphon/nodes/template_transform/entities.py delete mode 100644 api/graphon/nodes/template_transform/template_transform_node.py delete mode 100644 api/graphon/nodes/tool/__init__.py delete mode 100644 api/graphon/nodes/tool/entities.py delete mode 100644 api/graphon/nodes/tool/exc.py delete mode 100644 api/graphon/nodes/tool/tool_node.py delete mode 100644 api/graphon/nodes/tool_runtime_entities.py delete mode 100644 api/graphon/nodes/variable_aggregator/__init__.py delete mode 100644 api/graphon/nodes/variable_aggregator/entities.py delete mode 100644 api/graphon/nodes/variable_aggregator/variable_aggregator_node.py delete mode 100644 api/graphon/nodes/variable_assigner/__init__.py delete mode 100644 api/graphon/nodes/variable_assigner/common/__init__.py delete mode 100644 api/graphon/nodes/variable_assigner/common/exc.py delete mode 100644 api/graphon/nodes/variable_assigner/common/helpers.py delete mode 100644 api/graphon/nodes/variable_assigner/v1/__init__.py delete mode 100644 api/graphon/nodes/variable_assigner/v1/node.py delete mode 100644 api/graphon/nodes/variable_assigner/v1/node_data.py delete mode 100644 api/graphon/nodes/variable_assigner/v2/__init__.py delete mode 100644 api/graphon/nodes/variable_assigner/v2/entities.py delete mode 100644 api/graphon/nodes/variable_assigner/v2/enums.py delete mode 100644 api/graphon/nodes/variable_assigner/v2/exc.py delete mode 100644 api/graphon/nodes/variable_assigner/v2/helpers.py delete mode 100644 api/graphon/nodes/variable_assigner/v2/node.py delete mode 100644 api/graphon/prompt_entities.py delete mode 100644 api/graphon/runtime/__init__.py delete mode 100644 api/graphon/runtime/graph_runtime_state.py delete mode 100644 api/graphon/runtime/graph_runtime_state_protocol.py delete mode 100644 api/graphon/runtime/read_only_wrappers.py delete mode 100644 api/graphon/runtime/variable_pool.py delete mode 100644 api/graphon/template_rendering.py delete mode 100644 api/graphon/utils/__init__.py delete mode 100644 api/graphon/utils/condition/__init__.py delete mode 100644 api/graphon/utils/condition/entities.py delete mode 100644 api/graphon/utils/condition/processor.py delete mode 100644 api/graphon/utils/json_in_md_parser.py delete mode 100644 api/graphon/variable_loader.py delete mode 100644 api/graphon/variables/__init__.py delete mode 100644 api/graphon/variables/consts.py delete mode 100644 api/graphon/variables/exc.py delete mode 100644 api/graphon/variables/factory.py delete mode 100644 api/graphon/variables/input_entities.py delete mode 100644 api/graphon/variables/segment_group.py delete mode 100644 api/graphon/variables/segments.py delete mode 100644 api/graphon/variables/types.py delete mode 100644 api/graphon/variables/utils.py delete mode 100644 api/graphon/variables/variables.py delete mode 100644 api/graphon/workflow_type_encoder.py delete mode 100644 api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py delete mode 100644 api/tests/unit_tests/core/workflow/entities/test_pause_reason.py delete mode 100644 api/tests/unit_tests/core/workflow/entities/test_template.py delete mode 100644 api/tests/unit_tests/core/workflow/entities/test_variable_pool.py delete mode 100644 api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py delete mode 100644 api/tests/unit_tests/core/workflow/graph/test_graph.py delete mode 100644 api/tests/unit_tests/core/workflow/graph/test_graph_builder.py delete mode 100644 api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py delete mode 100644 api/tests/unit_tests/core/workflow/graph/test_graph_validation.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_array_iteration_formatting_workflow.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_database_utils.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_loop_node.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_update_conversation_variable_iteration.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py delete mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_worker.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/test_loop_node.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py delete mode 100644 api/tests/unit_tests/core/workflow/test_enums.py delete mode 100644 api/tests/unit_tests/core/workflow/utils/test_condition.py delete mode 100644 api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py delete mode 100644 api/tests/unit_tests/graphon/file/test_file_factory.py delete mode 100644 api/tests/unit_tests/graphon/file/test_file_manager.py delete mode 100644 api/tests/unit_tests/graphon/file/test_models.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/__base/__init__.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/__base/test_increase_tool_call.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/__base/test_large_language_model_non_stream_parsing.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/__init__.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/callbacks/test_base_callback.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/callbacks/test_logging_callback.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/entities/test_common_entities.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/entities/test_llm_entities.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/entities/test_message_entities.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/entities/test_model_entities.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/errors/test_invoke.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_ai_model.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_large_language_model.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_moderation_model.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_rerank_model.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_runtime_user_forwarding.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_speech2text_model.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_text_embedding_model.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_tts_model.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/schema_validators/test_common_validator.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/schema_validators/test_model_credential_schema_validator.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/schema_validators/test_provider_credential_schema_validator.py delete mode 100644 api/tests/unit_tests/graphon/model_runtime/utils/test_encoders.py delete mode 100644 api/tests/unit_tests/graphon/node_events/test_base.py delete mode 100644 api/tests/unit_tests/graphon/utils/test_json_in_md_parser.py diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 3f53811f85..94e857f93a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -36,7 +36,6 @@ /api/core/workflow/graph/ @laipz8200 @QuantumGhost /api/core/workflow/graph_events/ @laipz8200 @QuantumGhost /api/core/workflow/node_events/ @laipz8200 @QuantumGhost -/api/graphon/model_runtime/ @laipz8200 @WH-2099 # Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM) /api/core/workflow/nodes/agent/ @Nov1c444 diff --git a/api/.importlinter b/api/.importlinter index c2841f64d2..5e06947d94 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -3,7 +3,6 @@ root_packages = core constants context - graphon configs controllers extensions @@ -13,152 +12,3 @@ root_packages = tasks services include_external_packages = True - -[importlinter:contract:workflow] -name = Workflow -type=layers -layers = - graph_engine - graph_events - graph - nodes - node_events - runtime - entities -containers = - graphon -ignore_imports = - graphon.nodes.base.node -> graphon.graph_events - graphon.nodes.iteration.iteration_node -> graphon.graph_events - graphon.nodes.loop.loop_node -> graphon.graph_events - - graphon.nodes.iteration.iteration_node -> graphon.graph_engine - graphon.nodes.loop.loop_node -> graphon.graph_engine - # TODO(QuantumGhost): fix the import violation later - graphon.entities.pause_reason -> graphon.nodes.human_input.entities - -[importlinter:contract:workflow-external-imports] -name = Workflow External Imports -type = forbidden -source_modules = - graphon -forbidden_modules = - constants - configs - context - controllers - extensions - factories - libs - models - services - tasks - core.agent - core.app - core.base - core.callback_handler - core.datasource - core.db - core.entities - core.errors - core.extension - core.external_data_tool - core.file - core.helper - core.hosting_configuration - core.indexing_runner - core.llm_generator - core.logging - core.mcp - core.memory - core.moderation - core.ops - core.plugin - core.prompt - core.provider_manager - core.rag - core.repositories - core.schemas - core.tools - core.trigger - core.variables - -[importlinter:contract:workflow-third-party-imports] -name = Workflow Third-Party Imports -type = forbidden -source_modules = - graphon -forbidden_modules = - sqlalchemy - -[importlinter:contract:rsc] -name = RSC -type = layers -layers = - graph_engine - response_coordinator -containers = - graphon.graph_engine - -[importlinter:contract:worker] -name = Worker -type = layers -layers = - graph_engine - worker -containers = - graphon.graph_engine - -[importlinter:contract:graph-engine-architecture] -name = Graph Engine Architecture -type = layers -layers = - graph_engine - orchestration - command_processing - event_management - error_handler - graph_traversal - graph_state_manager - worker_management - domain -containers = - graphon.graph_engine - -[importlinter:contract:domain-isolation] -name = Domain Model Isolation -type = forbidden -source_modules = - graphon.graph_engine.domain -forbidden_modules = - graphon.graph_engine.worker_management - graphon.graph_engine.command_channels - graphon.graph_engine.layers - graphon.graph_engine.protocols - -[importlinter:contract:worker-management] -name = Worker Management -type = forbidden -source_modules = - graphon.graph_engine.worker_management -forbidden_modules = - graphon.graph_engine.orchestration - graphon.graph_engine.command_processing - graphon.graph_engine.event_management - - -[importlinter:contract:graph-traversal-components] -name = Graph Traversal Components -type = layers -layers = - edge_processor - skip_propagator -containers = - graphon.graph_engine.graph_traversal - -[importlinter:contract:command-channels] -name = Command Channels Independence -type = independence -modules = - graphon.graph_engine.command_channels.in_memory_channel - graphon.graph_engine.command_channels.redis_channel diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index 515a6a5125..7348ef62aa 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -2,9 +2,9 @@ from __future__ import annotations from typing import Any, TypeAlias +from graphon.file import helpers as file_helpers from pydantic import BaseModel, ConfigDict, computed_field -from graphon.file import helpers as file_helpers from models.model import IconType JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any] diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 357697ed30..738e77b371 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -5,6 +5,8 @@ from typing import Any, Literal, TypeAlias from flask import request from flask_restx import Resource +from graphon.enums import WorkflowExecutionStatus +from graphon.file import helpers as file_helpers from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session @@ -27,8 +29,6 @@ from core.ops.ops_trace_manager import OpsTraceManager from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.trigger.constants import TRIGGER_NODE_TYPES from extensions.ext_database import db -from graphon.enums import WorkflowExecutionStatus -from graphon.file import helpers as file_helpers from libs.login import current_account_with_tenant, login_required from models import App, DatasetPermissionEnum, Workflow from models.model import IconType diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 91fbe4a85a..78ddb904e1 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -2,6 +2,7 @@ import logging from flask import request from flask_restx import Resource, fields +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -22,7 +23,6 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from libs.login import login_required from models import App, AppMode from services.audio_service import AudioService diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index fe274e4c9a..d83925d173 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -3,6 +3,7 @@ from typing import Any, Literal from flask import request from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -26,7 +27,6 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user, login_required diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index c720a5e074..7101d5df7b 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,6 +1,7 @@ from collections.abc import Sequence from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from controllers.console import console_ns @@ -19,7 +20,6 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db -from graphon.model_runtime.errors.invoke import InvokeError from libs.login import current_account_with_tenant, login_required from models import App from services.workflow_service import WorkflowService diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index dc752939ae..2afe276742 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -3,6 +3,7 @@ from typing import Literal from flask import request from flask_restx import Resource, fields, marshal_with +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from sqlalchemy import exists, func, select from werkzeug.exceptions import InternalServerError, NotFound @@ -26,7 +27,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from extensions.ext_database import db from fields.raws import FilesContainedField -from graphon.model_runtime.errors.invoke import InvokeError from libs.helper import TimestampField, uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import current_account_with_tenant, login_required diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 2737dd1dfd..1f5a84c0b2 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -5,6 +5,10 @@ from typing import Any from flask import abort, request from flask_restx import Resource, fields, marshal_with +from graphon.enums import NodeType +from graphon.file import File +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound @@ -35,10 +39,6 @@ from extensions.ext_redis import redis_client from factories import file_factory, variable_factory from fields.member_fields import simple_account_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields -from graphon.enums import NodeType -from graphon.file.models import File -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.datetime_utils import naive_utc_now from libs.helper import TimestampField, uuid_value diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 8cf0004b09..f0e26c86a5 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -3,6 +3,7 @@ from datetime import datetime from dateutil.parser import isoparse from flask import request from flask_restx import Resource, marshal_with +from graphon.enums import WorkflowExecutionStatus from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session @@ -14,7 +15,6 @@ from fields.workflow_app_log_fields import ( build_workflow_app_log_pagination_model, build_workflow_archived_log_pagination_model, ) -from graphon.enums import WorkflowExecutionStatus from libs.login import login_required from models import App from models.model import AppMode diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 657b072490..4052897e9a 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -5,6 +5,10 @@ from typing import Any, NoReturn, ParamSpec, TypeVar from flask import Response, request from flask_restx import Resource, fields, marshal, marshal_with +from graphon.file import helpers as file_helpers +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment +from graphon.variables.types import SegmentType from pydantic import BaseModel, Field from sqlalchemy.orm import Session @@ -20,10 +24,6 @@ from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTE from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type -from graphon.file import helpers as file_helpers -from graphon.variables.segment_group import SegmentGroup -from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment -from graphon.variables.types import SegmentType from libs.login import current_user, login_required from models import App, AppMode from models.workflow import WorkflowDraftVariable diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index d1df722729..83e8bedc11 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -3,6 +3,8 @@ from typing import Literal, TypedDict, cast from flask import request from flask_restx import Resource, fields, marshal_with +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -26,8 +28,6 @@ from fields.workflow_run_fields import ( workflow_run_node_execution_list_fields, workflow_run_pagination_fields, ) -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage from libs.custom_inputs import time_duration from libs.helper import uuid_value diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 665a80802d..686b865871 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -4,11 +4,11 @@ from typing import Concatenate, ParamSpec, TypeVar from flask import jsonify, request from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel from werkzeug.exceptions import BadRequest, NotFound from controllers.console.wraps import account_initialization_required, setup_required -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models import Account from models.model import OAuthProviderApp diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 5d704b6224..f23c7eb431 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -2,6 +2,7 @@ from typing import Any, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, field_validator from sqlalchemy import func, select from werkzeug.exceptions import Forbidden, NotFound @@ -51,7 +52,6 @@ from fields.dataset_fields import ( weighted_score_fields, ) from fields.document_fields import document_status_fields -from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_account_with_tenant, login_required from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models.dataset import DatasetPermission, DatasetPermissionEnum diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index edb738aad8..ab367d8483 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -9,6 +9,8 @@ from uuid import UUID import sqlalchemy as sa from flask import request, send_file from flask_restx import Resource, fields, marshal, marshal_with +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import BaseModel, Field from sqlalchemy import asc, desc, func, select from werkzeug.exceptions import Forbidden, NotFound @@ -37,8 +39,6 @@ from fields.document_fields import ( document_status_fields, document_with_segments_fields, ) -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import DatasetProcessRule, Document, DocumentSegment, UploadFile diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 2fd84303d7..c5f4e3a6e2 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -2,6 +2,7 @@ import uuid from flask import request from flask_restx import Resource, marshal +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import String, cast, func, or_, select from sqlalchemy.dialects.postgresql import JSONB @@ -30,7 +31,6 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields -from graphon.model_runtime.entities.model_entities import ModelType from libs.helper import escape_like_pattern from libs.login import current_account_with_tenant, login_required from models.dataset import ChildChunk, DocumentSegment diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 699fa599c8..8fb3699849 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -2,6 +2,7 @@ import logging from typing import Any from flask_restx import marshal +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -20,7 +21,6 @@ from core.errors.error import ( QuotaExceededError, ) from fields.hit_testing_fields import hit_testing_record_fields -from graphon.model_runtime.errors.invoke import InvokeError from libs.login import current_user from models.account import Account from services.dataset_service import DatasetService diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 946fa599e6..1976a6bc8a 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -2,6 +2,8 @@ from typing import Any from flask import make_response, redirect, request from flask_restx import Resource +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, NotFound @@ -10,8 +12,6 @@ from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.plugin.impl.oauth import OAuthHandler -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.provider_ids import DatasourceProviderID from services.datasource_provider_service import DatasourceProviderService diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index 977ae93c03..f12cbd3495 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -3,6 +3,7 @@ from typing import Any, NoReturn from flask import Response, request from flask_restx import Resource, marshal, marshal_with +from graphon.variables.types import SegmentType from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -26,7 +27,6 @@ from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTE from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type -from graphon.variables.types import SegmentType from libs.login import current_user, login_required from models import Account from models.dataset import Pipeline diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 9079fbc29a..8e44bd6873 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -4,6 +4,7 @@ from typing import Any, Literal, cast from flask import abort, request from flask_restx import Resource, marshal_with # type: ignore +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound @@ -39,7 +40,6 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from factories import variable_factory -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.helper import TimestampField, UUIDStrOrEmpty from libs.login import current_account_with_tenant, current_user, login_required diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index bc78ee6d2d..b1b01b5f51 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -1,6 +1,7 @@ import logging from flask import request +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -19,7 +20,6 @@ from controllers.console.app.error import ( ) from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index ccdccceaa6..eacd7332fe 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -2,6 +2,7 @@ import logging from typing import Any, Literal from uuid import UUID +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -25,7 +26,6 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_database import db -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index a72cf6328a..fcbefcda33 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -2,6 +2,7 @@ import logging from typing import Literal from flask import request +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import InternalServerError, NotFound @@ -23,7 +24,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty from libs.login import current_account_with_tenant diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 26aa086aac..e432574434 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -3,6 +3,8 @@ from typing import Any, Literal, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel from sqlalchemy import select from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -59,8 +61,6 @@ from fields.workflow_fields import ( workflow_fields, workflow_partial_fields, ) -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 17dbbdd534..42cafc7193 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,6 +1,8 @@ import logging from typing import Any +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel from werkzeug.exceptions import InternalServerError @@ -22,8 +24,6 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_redis import redis_client -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.login import current_account_with_tenant from models.model import AppMode, InstalledApp diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 2a46d2250a..551c86fd82 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -2,6 +2,7 @@ import urllib.parse import httpx from flask_restx import Resource +from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field import services @@ -15,7 +16,6 @@ from controllers.console import console_ns from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo -from graphon.file import helpers as file_helpers from libs.login import current_account_with_tenant, login_required from services.file_service import FileService diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 764f488755..3fdcbc4710 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,8 +1,8 @@ from flask_restx import Resource, fields +from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.agent_service import AgentService diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index f45b72f390..b6b9deb1f9 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -2,13 +2,13 @@ from typing import Any from flask import request from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginPermissionDeniedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 2a6f37aec8..e4cfca9fa4 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,12 +1,12 @@ from flask_restx import Resource +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_account_with_tenant, login_required from models import TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index b22b91706e..8e0aefc9e3 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -3,13 +3,13 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 3c7b97d7fc..2ec1a9435a 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -3,14 +3,14 @@ from typing import Any, cast from flask import request from flask_restx import Resource +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index b3e344ccea..aa674a63b3 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -4,6 +4,7 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden @@ -14,7 +15,6 @@ from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginDaemonClientSideError -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 1273b85bc3..02eb0adc94 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -5,6 +5,7 @@ from urllib.parse import urlparse from flask import make_response, redirect, request, send_file from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -27,7 +28,6 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index feedf074b7..265b6ecd9a 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -3,6 +3,7 @@ from typing import Any from flask import make_response, redirect, request from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden @@ -15,7 +16,6 @@ from core.plugin.impl.oauth import OAuthHandler from core.trigger.entities.entities import SubscriptionBuilderUpdater from core.trigger.trigger_manager import TriggerManager from extensions.ext_database import db -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_user, login_required from models.account import Account from models.provider_ids import TriggerProviderID diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 72cab3de73..83c8fa02fe 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -1,4 +1,5 @@ from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console.wraps import setup_required from controllers.inner_api import inner_api_ns @@ -29,7 +30,6 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.signature import get_signed_file_url_for_plugin -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import length_prefixed_response from models import Account, Tenant from models.model import EndUser diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 869fb73cf5..3d00f77e79 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -2,6 +2,7 @@ from typing import Any, Union from flask import Response from flask_restx import Resource +from graphon.variables.input_entities import VariableEntity from pydantic import BaseModel, Field, ValidationError from sqlalchemy.orm import Session @@ -10,7 +11,6 @@ from controllers.mcp import mcp_ns from core.mcp import types as mcp_types from core.mcp.server.streamable_http import handle_mcp_request from extensions.ext_database import db -from graphon.variables.input_entities import VariableEntity from libs import helper from models.enums import AppMCPServerStatus from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 86d88ddafb..6228cfc25b 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -2,6 +2,7 @@ import logging from flask import request from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -21,7 +22,6 @@ from controllers.service_api.app.error import ( ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, EndUser from services.audio_service import AudioService from services.errors.audio import ( diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 31f2797d66..3142e5118e 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -4,6 +4,7 @@ from uuid import UUID from flask import request from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -28,7 +29,6 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 94afd47f7f..1759075139 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -4,6 +4,9 @@ from typing import Any, Literal from dateutil.parser import isoparse from flask import request from flask_restx import Namespace, Resource, fields +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -30,9 +33,6 @@ from core.helper.trace_id_helper import get_external_trace_id from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model -from graphon.enums import WorkflowExecutionStatus -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import OptionalTimestampField, TimestampField from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index dcf788f7a8..80205b283b 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -2,6 +2,7 @@ from typing import Any, Literal, cast from flask import request from flask_restx import marshal +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import Forbidden, NotFound @@ -18,7 +19,6 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage from core.rag.index_processor.constant.index_type import IndexTechniqueType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag -from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_user from models.account import Account from models.dataset import DatasetPermissionEnum diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 28fa915117..b4cc9874b6 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -2,6 +2,7 @@ from typing import Any from flask import request from flask_restx import marshal +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import select from werkzeug.exceptions import NotFound @@ -21,7 +22,6 @@ from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields -from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_account_with_tenant from models.dataset import Dataset from services.dataset_service import DatasetService, DocumentService, SegmentService diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 5ac65fc4e6..c0a6cb0a76 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -1,9 +1,9 @@ from flask_login import current_user from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_dataset_token -from graphon.model_runtime.utils.encoders import jsonable_encoder from services.model_provider_service import ModelProviderService diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 8081dee0bd..9ba1dc4a3a 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -2,6 +2,7 @@ import logging from flask import request from flask_restx import fields, marshal_with +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, field_validator from werkzeug.exceptions import InternalServerError @@ -20,7 +21,6 @@ from controllers.web.error import ( ) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 0528184d79..e37f9af5f0 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,6 +1,7 @@ import logging from typing import Any, Literal +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -25,7 +26,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from models.model import AppMode diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 4274b8c9ab..c5505dd60d 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -2,6 +2,7 @@ import logging from typing import Literal from flask import request +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -22,7 +23,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from fields.conversation_fields import ResultResponse from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from models.enums import FeedbackRating diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index fe31e9d4ac..38aeccc642 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,6 +1,7 @@ import urllib.parse import httpx +from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field, HttpUrl import services @@ -13,7 +14,6 @@ from controllers.common.errors import ( from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo -from graphon.file import helpers as file_helpers from services.file_service import FileService from ..common.schema import register_schema_models diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index ccef6e5b7f..7f5521f9f5 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,6 +1,8 @@ import logging from typing import Any +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -23,8 +25,6 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_redis import redis_client -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index a846cf4b0f..ff8f40407f 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -4,6 +4,20 @@ import uuid from decimal import Decimal from typing import Union, cast +from graphon.file import file_manager +from graphon.model_runtime.entities import ( + AssistantPromptMessage, + LLMUsage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.entities.model_entities import ModelFeature +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from core.agent.entities import AgentEntity, AgentToolEntity @@ -29,20 +43,6 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db from factories import file_factory -from graphon.file import file_manager -from graphon.model_runtime.entities import ( - AssistantPromptMessage, - LLMUsage, - PromptMessage, - PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from graphon.model_runtime.entities.model_entities import ModelFeature -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.enums import CreatorUserRole from models.model import Conversation, Message, MessageAgentThought, MessageFile diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 0a0fdfdd29..11e2aa062d 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -4,6 +4,15 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Mapping, Sequence from typing import Any +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + ToolPromptMessage, + UserPromptMessage, +) + from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentScratchpadUnit from core.agent.errors import AgentMaxIterationError @@ -15,14 +24,6 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageTool, - ToolPromptMessage, - UserPromptMessage, -) from models.model import Message logger = logging.getLogger(__name__) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index b3fc8d42e6..a4c438e929 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -1,6 +1,5 @@ import json -from core.agent.cot_agent_runner import CotAgentRunner from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -12,6 +11,8 @@ from graphon.model_runtime.entities import ( from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from graphon.model_runtime.utils.encoders import jsonable_encoder +from core.agent.cot_agent_runner import CotAgentRunner + class CotChatAgentRunner(CotAgentRunner): def _organize_system_prompt(self) -> SystemPromptMessage: diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 51a30998ae..d4c52a8eb1 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -1,6 +1,5 @@ import json -from core.agent.cot_agent_runner import CotAgentRunner from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -9,6 +8,8 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.utils.encoders import jsonable_encoder +from core.agent.cot_agent_runner import CotAgentRunner + class CotCompletionAgentRunner(CotAgentRunner): def _organize_instruction_prompt(self) -> str: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index d38d24d1e7..fdffde85d0 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -4,13 +4,6 @@ from collections.abc import Generator from copy import deepcopy from typing import Any, Union -from core.agent.base_agent_runner import BaseAgentRunner -from core.agent.errors import AgentMaxIterationError -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from core.tools.entities.tool_entities import ToolInvokeMeta -from core.tools.tool_engine import ToolEngine from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -26,6 +19,14 @@ from graphon.model_runtime.entities import ( UserPromptMessage, ) from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes + +from core.agent.base_agent_runner import BaseAgentRunner +from core.agent.errors import AgentMaxIterationError +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from core.tools.entities.tool_entities import ToolInvokeMeta +from core.tools.tool_engine import ToolEngine from models.model import Message logger = logging.getLogger(__name__) diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index c3e56fe011..46c1f1230d 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -3,9 +3,10 @@ import re from collections.abc import Generator from typing import Union -from core.agent.entities import AgentScratchpadUnit from graphon.model_runtime.entities.llm_entities import LLMResultChunk +from core.agent.entities import AgentScratchpadUnit + class CotAgentOutputParser: @classmethod diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index dbd7527fc6..b7dd55632e 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -1,13 +1,14 @@ from typing import cast +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel class ModelConfigConverter: diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index f279f769aa..5cc385c378 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -1,9 +1,10 @@ from collections.abc import Mapping from typing import Any +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType + from core.app.app_config.entities import ModelConfigEntity from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from models.model import AppModelConfigDict from models.provider_ids import ModelProviderID diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 7715a5330a..76196e7034 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -1,5 +1,7 @@ from typing import Any +from graphon.model_runtime.entities.message_entities import PromptMessageRole + from core.app.app_config.entities import ( AdvancedChatMessageEntity, AdvancedChatPromptTemplateEntity, @@ -7,7 +9,6 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.simple_prompt_transform import ModelMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole from models.model import AppMode, AppModelConfigDict diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 6d63ae04d3..f0b71c5801 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -1,9 +1,10 @@ import re from typing import cast +from graphon.variables.input_entities import VariableEntity, VariableEntityType + from core.app.app_config.entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory -from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import AppModelConfigDict _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index c67412cc29..536617edba 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -2,13 +2,13 @@ from collections.abc import Sequence from enum import StrEnum, auto from typing import Any, Literal -from pydantic import BaseModel, Field - -from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from graphon.file import FileUploadConfig from graphon.model_runtime.entities.llm_entities import LLMMode from graphon.model_runtime.entities.message_entities import PromptMessageRole from graphon.variables.input_entities import VariableEntity as WorkflowVariableEntity +from pydantic import BaseModel, Field + +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from models.model import AppMode diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 9092c1a17d..e96517c426 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,9 +1,10 @@ from collections.abc import Mapping from typing import Any -from constants import DEFAULT_FILE_NUMBER_LIMITS from graphon.file import FileUploadConfig +from constants import DEFAULT_FILE_NUMBER_LIMITS + class FileUploadConfigManager: @classmethod diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index 13ace32fd6..62e0c31d1a 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -1,7 +1,8 @@ import re -from core.app.app_config.entities import RagPipelineVariableEntity from graphon.variables.input_entities import VariableEntity + +from core.app.app_config.entities import RagPipelineVariableEntity from models.workflow import Workflow diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index d69a80e4a9..aa2b65766f 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -18,6 +18,11 @@ from constants import UUID_NIL if TYPE_CHECKING: from controllers.console.app.workflow import LoopNodeRunPayload +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader + from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner @@ -43,10 +48,6 @@ from core.repositories import DifyCoreRepositoryFactory from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.runtime import GraphRuntimeState -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index d21fce144e..a884a1c7f9 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -3,6 +3,12 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels import RedisChannel +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import Variable from sqlalchemy import select from sqlalchemy.orm import Session @@ -37,12 +43,6 @@ from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span -from graphon.enums import WorkflowType -from graphon.graph_engine.command_channels.redis_channel import RedisChannel -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader -from graphon.variables.variables import Variable from models import Workflow from models.model import App, Conversation, Message, MessageAnnotation from models.workflow import ConversationVariable diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 3577ae139b..5203de225c 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -9,6 +9,12 @@ from datetime import datetime from threading import Thread from typing import Any, Union +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState from sqlalchemy import select from sqlalchemy.orm import Session @@ -71,12 +77,6 @@ from core.repositories.human_input_repository import HumanInputFormRepositoryImp from core.workflow.file_reference import resolve_file_record_id from core.workflow.system_variables import build_system_variables from extensions.ext_database import db -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes import BuiltinNodeTypes -from graphon.runtime import GraphRuntimeState from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 1a44cc235e..bb258af4c1 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -6,6 +6,7 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload from flask import Flask, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from configs import dify_config @@ -23,7 +24,6 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, In from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from libs.flask_utils import preserve_flask_contexts from models import Account, App, EndUser from services.conversation_service import ConversationService diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 09ddce327e..a20d3f3c38 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -1,6 +1,9 @@ import logging from typing import cast +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from core.agent.cot_chat_agent_runner import CotChatAgentRunner @@ -16,9 +19,6 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError from extensions.ext_database import db -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.model import App, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 5c9ba4567a..66390116d4 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -3,10 +3,11 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Mapping from typing import Any, Union +from graphon.model_runtime.errors.invoke import InvokeError + from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 8e8ccf2b90..7eccd59d17 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -2,6 +2,9 @@ from collections.abc import Generator, Mapping, Sequence from contextlib import AbstractContextManager, nullcontext from typing import TYPE_CHECKING, Any, Union, final +from graphon.enums import NodeType +from graphon.file import File, FileUploadConfig +from graphon.variables.input_entities import VariableEntityType from sqlalchemy.orm import Session from core.app.apps.draft_variable_saver import ( @@ -13,9 +16,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope from extensions.ext_database import db from factories import file_factory -from graphon.enums import NodeType -from graphon.file import File, FileUploadConfig -from graphon.variables.input_entities import VariableEntityType from libs.orjson import orjson_dumps from models import Account, EndUser from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index d1771452c5..20bf81aeec 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -7,6 +7,7 @@ from enum import IntEnum, auto from typing import Any from cachetools import TTLCache, cachedmethod +from graphon.runtime import GraphRuntimeState from redis.exceptions import RedisError from sqlalchemy.orm import DeclarativeMeta @@ -21,7 +22,6 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from extensions.ext_redis import redis_client -from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 4a4c8b535d..4aebc0cb30 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -5,6 +5,17 @@ from collections.abc import Generator, Mapping, Sequence from mimetypes import guess_extension from typing import TYPE_CHECKING, Any, Union +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError + from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( @@ -30,21 +41,11 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from core.tools.tool_file_manager import ToolFileManager from extensions.ext_database import db -from graphon.file.enums import FileTransferMethod, FileType -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import App, AppMode, Message, MessageAnnotation, MessageFile if TYPE_CHECKING: - from graphon.file.models import File + from graphon.file import File _logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index db3a98c7ac..b675a87382 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -6,6 +6,7 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload from flask import Flask, copy_current_request_context, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from configs import dify_config @@ -23,7 +24,6 @@ from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeF from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models import Account from models.model import App, EndUser from services.conversation_service import ConversationService diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 077c5239f3..050f763e95 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -1,6 +1,8 @@ import logging from typing import cast +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -16,8 +18,6 @@ from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db -from graphon.file import File -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.model import App, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py index 2a90fbdad0..ab277857fe 100644 --- a/api/core/app/apps/common/graph_runtime_state_support.py +++ b/api/core/app/apps/common/graph_runtime_state_support.py @@ -4,9 +4,10 @@ from __future__ import annotations from typing import TYPE_CHECKING -from core.workflow.system_variables import SystemVariableKey, get_system_text from graphon.runtime import GraphRuntimeState +from core.workflow.system_variables import SystemVariableKey, get_system_text + if TYPE_CHECKING: from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index e4aa2ff650..a515531616 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -6,6 +6,19 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, NewType, TypedDict, Union +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import FILE_MODEL_IDENTITY, File +from graphon.runtime import GraphRuntimeState +from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment +from graphon.variables.variables import Variable +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from sqlalchemy.orm import Session @@ -55,19 +68,6 @@ from core.workflow.human_input_forms import load_form_tokens_by_form_id from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db -from graphon.entities.pause_reason import HumanInputRequired -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowExecutionStatus, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import FILE_MODEL_IDENTITY, File -from graphon.runtime import GraphRuntimeState -from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment -from graphon.variables.variables import Variable -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.human_input import HumanInputForm diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index c418fe9759..a62c5b80b5 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -6,6 +6,7 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload from flask import Flask, copy_current_request_context, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from sqlalchemy import select @@ -23,7 +24,6 @@ from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, I from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models import Account, App, EndUser, Message from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 6bb1ecdcb1..b216f7cf7b 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -1,6 +1,8 @@ import logging from typing import cast +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager @@ -14,8 +16,6 @@ from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db -from graphon.file import File -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.model import App, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 48457b5326..fa242003a2 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -10,6 +10,8 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, Union, cast, overload from flask import Flask, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -41,8 +43,6 @@ from core.repositories.factory import ( WorkflowNodeExecutionRepository, ) from extensions.ext_database import db -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 44d2450f74..4c188dac68 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -2,6 +2,14 @@ import logging import time from typing import cast +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowType +from graphon.graph import Graph +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput + from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner @@ -18,13 +26,6 @@ from core.workflow.system_variables import build_bootstrap_variables, build_syst from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db -from graphon.entities.graph_init_params import GraphInitParams -from graphon.enums import WorkflowType -from graphon.graph import Graph -from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader -from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from models.dataset import Document, Pipeline from models.model import EndUser from models.workflow import Workflow diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 8ad6893a15..9618ab35c6 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -8,6 +8,10 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal, Union, overload from flask import Flask, current_app +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -34,10 +38,6 @@ from core.repositories import DifyCoreRepositoryFactory from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.runtime import GraphRuntimeState -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models.account import Account from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index c02c0b16e9..2cb8088971 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -3,6 +3,12 @@ import time from collections.abc import Sequence from typing import cast +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels import RedisChannel +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader + from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner @@ -15,11 +21,6 @@ from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span -from graphon.enums import WorkflowType -from graphon.graph_engine.command_channels.redis_channel import RedisChannel -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader from libs.datetime_utils import naive_utc_now from models.workflow import Workflow diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index e0c5b44ee4..49af169e88 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -4,6 +4,9 @@ from collections.abc import Callable, Generator from contextlib import contextmanager from typing import Union +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState from sqlalchemy.orm import Session from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -58,9 +61,6 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.system_variables import build_system_variables from extensions.ext_database import db -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus -from graphon.runtime import GraphRuntimeState from models import Account from models.enums import CreatorUserRole from models.model import EndUser diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index d7d3bd27de..f68c8e60b4 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -3,6 +3,40 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph import Graph +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import ( + GraphEngineEvent, + GraphRunAbortedEvent, + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunAgentLogEvent, + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunHumanInputFormFilledEvent, + NodeRunHumanInputFormTimeoutEvent, + NodeRunIterationFailedEvent, + NodeRunIterationNextEvent, + NodeRunIterationStartedEvent, + NodeRunIterationSucceededEvent, + NodeRunLoopFailedEvent, + NodeRunLoopNextEvent, + NodeRunLoopStartedEvent, + NodeRunLoopSucceededEvent, + NodeRunRetrieverResourceEvent, + NodeRunRetryEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from pydantic import ValidationError from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -44,40 +78,6 @@ from core.workflow.system_variables import ( from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run -from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph import Graph -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events import ( - GraphEngineEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunAgentLogEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from graphon.graph_events.graph import GraphRunAbortedEvent -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index d8d851c505..0cdbb5f50a 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -2,13 +2,13 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import TYPE_CHECKING, Any, Optional +from graphon.file import File, FileUploadConfig +from graphon.model_runtime.entities.model_entities import AIModelEntity from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from constants import UUID_NIL from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from graphon.file import File, FileUploadConfig -from graphon.model_runtime.entities.model_entities import AIModelEntity if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 63857bfff2..5e56341f89 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -3,14 +3,14 @@ from datetime import datetime from enum import StrEnum, auto from typing import Any +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import PauseReason +from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from pydantic import BaseModel, ConfigDict, Field from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from graphon.entities.pause_reason import PauseReason -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk class QueueEvent(StrEnum): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 719027bd23..ba3b2e356f 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -2,14 +2,14 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import Any +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.nodes.human_input.entities import FormInput, UserAction from pydantic import BaseModel, ConfigDict, Field from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.nodes.human_input.entities import FormInput, UserAction class AnnotationReplyAccount(BaseModel): diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index d59f5125e3..d2d2fea4fb 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -1,8 +1,9 @@ import logging +from graphon.model_runtime.entities.message_entities import PromptMessage + from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.helper import moderation -from graphon.model_runtime.entities.message_entities import PromptMessage logger = logging.getLogger(__name__) diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index eeb9abbbfa..e09869f5f8 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -9,10 +9,11 @@ scope updates that matter to chat applications. import logging +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent + from core.workflow.system_variables import SystemVariableKey, get_system_text from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent from services.conversation_variable_updater import ConversationVariableUpdater logger = logging.getLogger(__name__) diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 98e2257b1f..79a5442130 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -1,15 +1,14 @@ from dataclasses import dataclass from typing import Annotated, Literal, Self, TypeAlias +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent from pydantic import BaseModel, Field from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.workflow.system_variables import SystemVariableKey, get_system_text -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events.base import GraphEngineEvent -from graphon.graph_events.graph import GraphRunPausedEvent from models.model import AppMode from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/core/app/layers/suspend_layer.py b/api/core/app/layers/suspend_layer.py index 172306f271..1a79a9f843 100644 --- a/api/core/app/layers/suspend_layer.py +++ b/api/core/app/layers/suspend_layer.py @@ -1,6 +1,5 @@ -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events.base import GraphEngineEvent -from graphon.graph_events.graph import GraphRunPausedEvent +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent class SuspendLayer(GraphEngineLayer): diff --git a/api/core/app/layers/timeslice_layer.py b/api/core/app/layers/timeslice_layer.py index fef12df504..8c8daf8712 100644 --- a/api/core/app/layers/timeslice_layer.py +++ b/api/core/app/layers/timeslice_layer.py @@ -3,10 +3,10 @@ import uuid from typing import ClassVar from apscheduler.schedulers.background import BackgroundScheduler # type: ignore - from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events.base import GraphEngineEvent +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent + from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index 781a0aa3d3..77c7bec67e 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -2,13 +2,12 @@ import logging from datetime import UTC, datetime from typing import Any, ClassVar +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from pydantic import TypeAdapter from core.db.session_factory import session_factory from core.workflow.system_variables import SystemVariableKey, get_system_text -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events.base import GraphEngineEvent -from graphon.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from models.enums import WorkflowTriggerStatus from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index c49c4eb0ac..278d0cb30b 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -2,15 +2,16 @@ from __future__ import annotations from typing import Any +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.nodes.llm.entities import ModelConfig +from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError +from graphon.nodes.llm.protocols import CredentialsProvider + from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.nodes.llm.entities import ModelConfig -from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from graphon.nodes.llm.protocols import CredentialsProvider class DifyCredentialsProvider: diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index 65a3f39d64..63d2235358 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -1,3 +1,4 @@ +from graphon.model_runtime.entities.llm_entities import LLMUsage from sqlalchemy import update from sqlalchemy.orm import Session @@ -7,7 +8,6 @@ from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance from extensions.ext_database import db -from graphon.model_runtime.entities.llm_entities import LLMUsage from libs.datetime_utils import naive_utc_now from models.provider import Provider, ProviderType from models.provider_ids import ModelProviderID diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 9e688589db..10b9c36d3e 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -1,6 +1,7 @@ import logging import time +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from sqlalchemy import select from sqlalchemy.orm import Session @@ -17,7 +18,6 @@ from core.app.entities.task_entities import ( ) from core.errors.error import QuotaExceededError from core.moderation.output_moderation import ModerationRule, OutputModeration -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus from models.model import Message diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index cf9cb6d051..a410fac558 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -4,6 +4,13 @@ from collections.abc import Generator from threading import Thread from typing import Any, Union, cast +from graphon.file import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + TextPromptMessageContent, +) +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -53,13 +60,6 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.message_event import message_was_created from extensions.ext_database import db -from graphon.file.enums import FileTransferMethod -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - TextPromptMessageContent, -) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile diff --git a/api/core/app/task_pipeline/message_file_utils.py b/api/core/app/task_pipeline/message_file_utils.py index 45f622c469..b23a33923b 100644 --- a/api/core/app/task_pipeline/message_file_utils.py +++ b/api/core/app/task_pipeline/message_file_utils.py @@ -1,8 +1,9 @@ from typing import TypedDict -from core.tools.signature import sign_tool_file +from graphon.file import FileTransferMethod from graphon.file import helpers as file_helpers -from graphon.file.enums import FileTransferMethod + +from core.tools.signature import sign_tool_file from models.model import MessageFile, UploadFile MAX_TOOL_FILE_EXTENSION_LENGTH = 10 diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index aa5291bad5..8604235ef2 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -9,6 +9,10 @@ import urllib.parse from collections.abc import Generator from typing import TYPE_CHECKING, Literal +from graphon.file import FileTransferMethod +from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol +from graphon.file.runtime import set_workflow_file_runtime + from configs import dify_config from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol from core.db.session_factory import session_factory @@ -16,12 +20,9 @@ from core.helper.ssrf_proxy import ssrf_proxy from core.tools.signature import sign_tool_file from core.workflow.file_reference import parse_file_reference from extensions.ext_storage import storage -from graphon.file.enums import FileTransferMethod -from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol -from graphon.file.runtime import set_workflow_file_runtime if TYPE_CHECKING: - from graphon.file.models import File + from graphon.file import File class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index 5666bf1191..48cabaf4d0 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -7,18 +7,17 @@ This layer centralizes model-quota deduction outside node implementations. import logging from typing import TYPE_CHECKING, cast, final +from graphon.enums import BuiltinNodeTypes +from graphon.graph_engine.entities.commands import AbortCommand, CommandType +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent +from graphon.nodes.base.node import Node from typing_extensions import override from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.app.llm import deduct_llm_quota, ensure_llm_quota_available from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance -from graphon.enums import BuiltinNodeTypes -from graphon.graph_engine.entities.commands import AbortCommand, CommandType -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase -from graphon.graph_events.node import NodeRunSucceededEvent -from graphon.nodes.base.node import Node if TYPE_CHECKING: from graphon.nodes.llm.node import LLMNode diff --git a/api/core/app/workflow/layers/observability.py b/api/core/app/workflow/layers/observability.py index 837bf7ff81..8565c3076c 100644 --- a/api/core/app/workflow/layers/observability.py +++ b/api/core/app/workflow/layers/observability.py @@ -11,6 +11,10 @@ import logging from dataclasses import dataclass from typing import cast, final +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node from opentelemetry import context as context_api from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context from typing_extensions import override @@ -24,10 +28,6 @@ from extensions.otel.parser import ( ToolNodeOTelParser, ) from extensions.otel.runtime import is_instrument_flag_enabled -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index e540733de2..ada065a943 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -14,13 +14,6 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, Union -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository -from core.workflow.system_variables import SystemVariableKey -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from graphon.entities import WorkflowExecution, WorkflowNodeExecution from graphon.enums import ( WorkflowExecutionStatus, @@ -28,7 +21,7 @@ from graphon.enums import ( WorkflowNodeExecutionStatus, WorkflowType, ) -from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import ( GraphEngineEvent, GraphRunAbortedEvent, @@ -45,6 +38,14 @@ from graphon.graph_events import ( NodeRunSucceededEvent, ) from graphon.node_events import NodeRunResult + +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from libs.datetime_utils import naive_utc_now diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index 9e3c187210..3d8a7a54f3 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -6,6 +6,9 @@ import re import threading from collections.abc import Iterable +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.entities.queue_entities import ( MessageQueueMessage, QueueAgentMessageEvent, @@ -15,8 +18,6 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from core.model_manager import ModelInstance, ModelManager -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent -from graphon.model_runtime.entities.model_entities import ModelType class AudioTrunk: diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 8a9875e4d7..143d1e696b 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -3,6 +3,9 @@ from collections.abc import Generator from threading import Lock from typing import Any, cast +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType, get_file_type_by_mime_type +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from sqlalchemy import select import contexts @@ -28,11 +31,6 @@ from core.plugin.impl.datasource import PluginDatasourceManager from core.workflow.file_reference import build_file_reference from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam from factories import file_factory -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import WorkflowNodeExecutionMetadataKey -from graphon.file import File, get_file_type_by_mime_type -from graphon.file.enums import FileTransferMethod, FileType -from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from models.model import UploadFile from models.tools import ToolFile from services.datasource_provider_service import DatasourceProviderService diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 84dd653772..14d1af2e8b 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -1,10 +1,10 @@ from typing import Literal, Optional +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter from core.tools.entities.common_entities import I18nObject -from graphon.model_runtime.utils.encoders import jsonable_encoder class DatasourceApiEntity(BaseModel): diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 089b8b8e59..04f15dee31 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -2,10 +2,11 @@ import logging from collections.abc import Generator from mimetypes import guess_extension, guess_type +from graphon.file import File, FileTransferMethod, FileType + from core.datasource.entities.datasource_entities import DatasourceMessage from core.tools.tool_file_manager import ToolFileManager from core.workflow.file_reference import parse_file_reference -from graphon.file import File, FileTransferMethod, FileType from models.tools import ToolFile logger = logging.getLogger(__name__) diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py index 9d970d5db1..72f6590e68 100644 --- a/api/core/entities/execution_extra_content.py +++ b/api/core/entities/execution_extra_content.py @@ -3,9 +3,9 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from typing import Any, TypeAlias +from graphon.nodes.human_input.entities import FormInput, UserAction from pydantic import BaseModel, ConfigDict, Field -from graphon.nodes.human_input.entities import FormInput, UserAction from models.execution_extra_content import ExecutionContentType diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index bfa4f56915..a440829b46 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -6,6 +6,7 @@ from enum import StrEnum from typing import TYPE_CHECKING, Any from urllib.parse import urlparse +from graphon.file import helpers as file_helpers from pydantic import BaseModel from configs import dify_config @@ -15,7 +16,6 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from graphon.file import helpers as file_helpers if TYPE_CHECKING: from models.tools import MCPToolProvider diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index e99a131500..84d95c38c6 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,11 +1,10 @@ from collections.abc import Sequence from enum import StrEnum, auto -from pydantic import BaseModel, ConfigDict - from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import ModelType, ProviderModel from graphon.model_runtime.entities.provider_entities import ProviderEntity +from pydantic import BaseModel, ConfigDict class ModelStatus(StrEnum): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index d90afd3f7b..8b48aa2660 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -7,6 +7,16 @@ from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.runtime import ModelRuntime from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -22,16 +32,6 @@ from core.entities.provider_entities import ( from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from graphon.model_runtime.model_providers.__base.ai_model import AIModel -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from graphon.model_runtime.runtime import ModelRuntime from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index dffc7f2fc1..2c8767a32b 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -3,6 +3,7 @@ from __future__ import annotations from enum import StrEnum, auto from typing import Union +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, ConfigDict, Field from core.entities.parameter_entities import ( @@ -12,7 +13,6 @@ from core.entities.parameter_entities import ( ToolSelectorScope, ) from core.tools.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType class ProviderQuotaType(StrEnum): diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 951e065b2c..35bfcfb6a5 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -4,6 +4,7 @@ from threading import Lock from typing import Any import httpx +from graphon.nodes.code.entities import CodeLanguage from pydantic import BaseModel from yarl import URL @@ -13,7 +14,6 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer from core.helper.http_client_pooling import get_pooled_http_client -from graphon.nodes.code.entities import CodeLanguage logger = logging.getLogger(__name__) code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index dc37a36943..a1e782a094 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -2,13 +2,14 @@ import logging import secrets from typing import cast +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeBadRequestError +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from extensions.ext_hosting_provider import hosting_configuration -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeBadRequestError -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel from models.provider import ProviderType logger = logging.getLogger(__name__) diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index eb762c3508..60f5434bc1 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -1,10 +1,10 @@ from flask import Flask +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel from configs import dify_config from core.entities import DEFAULT_PLUGIN_ID from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel -from graphon.model_runtime.entities.model_entities import ModelType class HostingQuota(BaseModel): diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 46bf1d6937..3ec17bc986 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -9,6 +9,7 @@ from collections.abc import Mapping from typing import Any from flask import Flask, current_app +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from sqlalchemy.orm.exc import ObjectDeletedError @@ -34,7 +35,6 @@ from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage -from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from libs.datetime_utils import naive_utc_now from models import Account diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 3712374305..3d94f1a596 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -5,6 +5,11 @@ from collections.abc import Sequence from typing import Protocol, cast import json_repair +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload @@ -29,11 +34,6 @@ from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from extensions.ext_storage import storage -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey -from graphon.model_runtime.entities.llm_entities import LLMResult -from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models import App, Message, WorkflowNodeExecutionModel from models.workflow import Workflow diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 81672ee7aa..a1710f11ac 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -5,11 +5,6 @@ from enum import StrEnum from typing import Any, Literal, cast, overload import json_repair -from pydantic import TypeAdapter, ValidationError - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT -from core.model_manager import ModelInstance from graphon.model_runtime.callbacks.base_callback import Callback from graphon.model_runtime.entities.llm_entities import ( LLMResult, @@ -26,6 +21,11 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, ) from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule +from pydantic import TypeAdapter, ValidationError + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT +from core.model_manager import ModelInstance class ResponseFormat(StrEnum): diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 92d23c6dc9..27000c947c 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -3,11 +3,12 @@ import logging from collections.abc import Mapping from typing import Any, cast +from graphon.variables.input_entities import VariableEntity, VariableEntityType + from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types as mcp_types -from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index 7b5a7635f1..7e35044176 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -4,11 +4,11 @@ from contextlib import AbstractContextManager import httpx import httpx_sse +from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx_sse import connect_sse from configs import dify_config from core.mcp.types import ErrorData, JSONRPCError -from graphon.model_runtime.utils.encoders import jsonable_encoder HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 658206128d..09c84538a9 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,14 +1,5 @@ from collections.abc import Sequence -from sqlalchemy import select -from sqlalchemy.orm import sessionmaker - -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.file_access import DatabaseFileAccessController -from core.model_manager import ModelInstance -from core.prompt.utils.extract_thread_messages import extract_thread_messages -from extensions.ext_database import db -from factories import file_factory from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -19,6 +10,15 @@ from graphon.model_runtime.entities import ( UserPromptMessage, ) from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker + +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.file_access import DatabaseFileAccessController +from core.model_manager import ModelInstance +from core.prompt.utils.extract_thread_messages import extract_thread_messages +from extensions.ext_database import db +from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile from models.workflow import Workflow from repositories.api_workflow_run_repository import APIWorkflowRunRepository diff --git a/api/core/model_manager.py b/api/core/model_manager.py index f5ff375f65..87d1d7fba6 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -2,14 +2,6 @@ import logging from collections.abc import Callable, Generator, Iterable, Mapping, Sequence from typing import IO, Any, Literal, Optional, Union, cast, overload -from configs import dify_config -from core.entities.embedding_type import EmbeddingInputType -from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -from core.entities.provider_entities import ModelLoadBalancingConfiguration -from core.errors.error import ProviderTokenNotInitError -from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager -from core.provider_manager import ProviderManager -from extensions.ext_redis import redis_client from graphon.model_runtime.callbacks.base_callback import Callback from graphon.model_runtime.entities.llm_entities import LLMResult from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -23,6 +15,15 @@ from graphon.model_runtime.model_providers.__base.rerank_model import RerankMode from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from graphon.model_runtime.model_providers.__base.tts_model import TTSModel + +from configs import dify_config +from core.entities.embedding_type import EmbeddingInputType +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import ModelLoadBalancingConfiguration +from core.errors.error import ProviderTokenNotInitError +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager +from core.provider_manager import ProviderManager +from extensions.ext_redis import redis_client from models.provider import ProviderType from services.enterprise.plugin_manager_service import PluginCredentialType diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 35d4469bc1..dd038c77f1 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -1,6 +1,7 @@ +from graphon.model_runtime.entities.model_entities import ModelType + from core.model_manager import ModelManager from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult -from graphon.model_runtime.entities.model_entities import ModelType class OpenAIModeration(Moderation): diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index 76e81242f4..70aaf2a07b 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -1,6 +1,8 @@ import logging from collections.abc import Sequence +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opentelemetry.trace import SpanKind from sqlalchemy.orm import sessionmaker @@ -58,8 +60,6 @@ from core.ops.entities.trace_entity import ( ) from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/ops/aliyun_trace/utils.py b/api/core/ops/aliyun_trace/utils.py index 956fc60191..d8e105d6a3 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/core/ops/aliyun_trace/utils.py @@ -2,6 +2,8 @@ import json from collections.abc import Mapping from typing import Any +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from opentelemetry.trace import Link, Status, StatusCode from core.ops.aliyun_trace.entities.semconv import ( @@ -15,8 +17,6 @@ from core.ops.aliyun_trace.entities.semconv import ( ) from core.rag.models.document import Document from extensions.ext_database import db -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser # Constants diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index a1ea182f66..39d97e2882 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -6,6 +6,7 @@ from datetime import datetime, timedelta from typing import Any, Union, cast from urllib.parse import urlparse +from graphon.enums import WorkflowNodeExecutionStatus from openinference.semconv.trace import ( MessageAttributes, OpenInferenceMimeTypeValues, @@ -39,7 +40,6 @@ from core.ops.entities.trace_entity import ( ) from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.enums import WorkflowNodeExecutionStatus from models.model import EndUser, MessageFile from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 3bf01eb81c..3644b6b4c2 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -2,6 +2,7 @@ import logging import os from datetime import datetime, timedelta +from graphon.enums import BuiltinNodeTypes from langfuse import Langfuse from sqlalchemy.orm import sessionmaker @@ -29,7 +30,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( from core.ops.utils import filter_none_values from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.enums import BuiltinNodeTypes from models import EndUser, WorkflowNodeExecutionTriggeredFrom from models.enums import MessageStatus diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index d960038f15..490c64af84 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -4,6 +4,7 @@ import uuid from datetime import datetime, timedelta from typing import cast +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from langsmith import Client from langsmith.schemas import RunBase from sqlalchemy.orm import sessionmaker @@ -29,7 +30,6 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/core/ops/mlflow_trace/mlflow_trace.py index 8bf2e5dc13..946d3cdd47 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/core/ops/mlflow_trace/mlflow_trace.py @@ -5,6 +5,7 @@ from datetime import datetime, timedelta from typing import Any, cast import mlflow +from graphon.enums import BuiltinNodeTypes from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey from mlflow.tracing.fluent import start_span_no_context, update_current_trace @@ -25,7 +26,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from extensions.ext_database import db -from graphon.enums import BuiltinNodeTypes from models import EndUser from models.workflow import WorkflowNodeExecutionModel diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index b98cc3ce59..2215bdeb33 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -5,6 +5,7 @@ import uuid from datetime import datetime, timedelta from typing import cast +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 from sqlalchemy.orm import sessionmaker @@ -24,7 +25,6 @@ from core.ops.entities.trace_entity import ( ) from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/core/ops/tencent_trace/span_builder.py index 4f06458157..f79095d966 100644 --- a/api/core/ops/tencent_trace/span_builder.py +++ b/api/core/ops/tencent_trace/span_builder.py @@ -6,6 +6,8 @@ import json import logging from datetime import datetime +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from opentelemetry.trace import Status, StatusCode from core.ops.entities.trace_entity import ( @@ -41,11 +43,6 @@ from core.ops.tencent_trace.entities.semconv import ( from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData from core.ops.tencent_trace.utils import TencentTraceUtils from core.rag.models.document import Document -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) logger = logging.getLogger(__name__) diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py index 1b1b1025bc..2bd6db22bf 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/core/ops/tencent_trace/tencent_trace.py @@ -4,6 +4,10 @@ Tencent APM tracing implementation with separated concerns import logging +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, +) +from graphon.nodes import BuiltinNodeTypes from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -25,10 +29,6 @@ from core.ops.tencent_trace.span_builder import TencentSpanBuilder from core.ops.tencent_trace.utils import TencentTraceUtils from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from extensions.ext_database import db -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, -) -from graphon.nodes import BuiltinNodeTypes from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index f79544f1c7..8d9ba4694d 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -6,6 +6,7 @@ from typing import Any, cast import wandb import weave +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from sqlalchemy.orm import sessionmaker from weave.trace_server.trace_server_interface import ( CallEndReq, @@ -32,7 +33,6 @@ from core.ops.entities.trace_entity import ( from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index 85625fc87d..c715b9171c 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -2,6 +2,20 @@ import tempfile from binascii import hexlify, unhexlify from collections.abc import Generator +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from graphon.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.llm import deduct_llm_quota from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.model_manager import ModelManager @@ -18,19 +32,6 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, -) -from graphon.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelType from models.account import Tenant diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index 248f8ef3e6..9478997494 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -1,8 +1,5 @@ -from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from graphon.enums import BuiltinNodeTypes -from graphon.nodes.parameter_extractor.entities import ( - ModelConfig as ParameterExtractorModelConfig, -) +from graphon.nodes.llm.entities import ModelConfig as LLMModelConfig from graphon.nodes.parameter_extractor.entities import ( ParameterConfig, ParameterExtractorNodeData, @@ -11,9 +8,8 @@ from graphon.nodes.question_classifier.entities import ( ClassConfig, QuestionClassifierNodeData, ) -from graphon.nodes.question_classifier.entities import ( - ModelConfig as QuestionClassifierModelConfig, -) + +from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from services.workflow_service import WorkflowService @@ -24,7 +20,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): tenant_id: str, user_id: str, parameters: list[ParameterConfig], - model_config: ParameterExtractorModelConfig, + model_config: LLMModelConfig, instruction: str, query: str, ): @@ -74,7 +70,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): cls, tenant_id: str, user_id: str, - model_config: QuestionClassifierModelConfig, + model_config: LLMModelConfig, classes: list[ClassConfig], instruction: str, query: str, diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py index 1bd239a831..2177e8af90 100644 --- a/api/core/plugin/entities/marketplace.py +++ b/api/core/plugin/entities/marketplace.py @@ -1,10 +1,10 @@ +from graphon.model_runtime.entities.provider_entities import ProviderEntity from pydantic import BaseModel, Field, computed_field, model_validator from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.plugin.entities.plugin import PluginResourceRequirements from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity -from graphon.model_runtime.entities.provider_entities import ProviderEntity class MarketplacePluginDeclaration(BaseModel): diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 6aefc41400..b095b4998d 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -3,6 +3,7 @@ from collections.abc import Mapping from enum import StrEnum, auto from typing import Any +from graphon.model_runtime.entities.provider_entities import ProviderEntity from packaging.version import InvalidVersion, Version from pydantic import BaseModel, Field, field_validator, model_validator @@ -13,7 +14,6 @@ from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity from core.trigger.entities.entities import TriggerProviderEntity -from graphon.model_runtime.entities.provider_entities import ProviderEntity class PluginInstallationSource(StrEnum): diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 864e4b8dd7..94263ec44e 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -6,6 +6,8 @@ from datetime import datetime from enum import StrEnum from typing import Any, Generic, TypeVar +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity from pydantic import BaseModel, ConfigDict, Field from core.agent.plugin_entities import AgentProviderEntityWithPlugin @@ -16,8 +18,6 @@ from core.plugin.entities.plugin import PluginDeclaration, PluginEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin from core.trigger.entities.entities import TriggerProviderEntity -from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.entities.provider_entities import ProviderEntity T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 704cacae2a..059f3fa9be 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -4,10 +4,6 @@ from collections.abc import Mapping from typing import Any, Literal from flask import Response -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from core.entities.provider_entities import BasicProviderConfig -from core.plugin.utils.http_parser import deserialize_response from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -18,18 +14,17 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) from graphon.model_runtime.entities.model_entities import ModelType -from graphon.nodes.parameter_extractor.entities import ( - ModelConfig as ParameterExtractorModelConfig, -) +from graphon.nodes.llm.entities import ModelConfig as LLMModelConfig from graphon.nodes.parameter_extractor.entities import ( ParameterConfig, ) from graphon.nodes.question_classifier.entities import ( ClassConfig, ) -from graphon.nodes.question_classifier.entities import ( - ModelConfig as QuestionClassifierModelConfig, -) +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from core.entities.provider_entities import BasicProviderConfig +from core.plugin.utils.http_parser import deserialize_response class InvokeCredentials(BaseModel): @@ -176,7 +171,7 @@ class RequestInvokeParameterExtractorNode(BaseModel): """ parameters: list[ParameterConfig] - model: ParameterExtractorModelConfig + model: LLMModelConfig instruction: str query: str @@ -187,7 +182,7 @@ class RequestInvokeQuestionClassifierNode(BaseModel): """ query: str - model: QuestionClassifierModelConfig + model: LLMModelConfig classes: list[ClassConfig] instruction: str diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 44047911da..2d0ab3fcd7 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -5,6 +5,14 @@ from collections.abc import Callable, Generator from typing import Any, TypeVar, cast import httpx +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from yarl import URL @@ -28,14 +36,6 @@ from core.trigger.errors import ( TriggerPluginInvokeError, TriggerProviderCredentialValidationError, ) -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) _plugin_daemon_timeout_config = cast( diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index c91fa71374..1e38c24717 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -2,6 +2,13 @@ import binascii from collections.abc import Generator, Sequence from typing import IO, Any +from graphon.model_runtime.entities.llm_entities import LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.utils.encoders import jsonable_encoder + from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, PluginDaemonInnerError, @@ -13,12 +20,6 @@ from core.plugin.entities.plugin_daemon import ( PluginVoicesResponse, ) from core.plugin.impl.base import BasePluginClient -from graphon.model_runtime.entities.llm_entities import LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult -from graphon.model_runtime.utils.encoders import jsonable_encoder class PluginModelClient(BasePluginClient): diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index e3fba4ef3a..22c846b6de 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -6,6 +6,13 @@ from collections.abc import Generator, Iterable, Sequence from threading import Lock from typing import IO, Any, Union +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult +from graphon.model_runtime.runtime import ModelRuntime from pydantic import ValidationError from redis import RedisError @@ -14,13 +21,6 @@ from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.model import PluginModelClient from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.model_runtime.entities.provider_entities import ProviderEntity -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult -from graphon.model_runtime.runtime import ModelRuntime from models.provider_ids import ModelProviderID logger = logging.getLogger(__name__) diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py index 35abd2ae8c..4b29a6fc56 100644 --- a/api/core/plugin/impl/model_runtime_factory.py +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -2,9 +2,10 @@ from __future__ import annotations from typing import TYPE_CHECKING -from core.plugin.impl.model import PluginModelClient from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model import PluginModelClient + if TYPE_CHECKING: from core.model_manager import ModelManager from core.plugin.impl.model_runtime import PluginModelRuntime diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 322f78ab4e..90350f8400 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -1,7 +1,8 @@ from typing import Any +from graphon.file import File + from core.tools.entities.tool_entities import ToolSelector -from graphon.file.models import File def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]: diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index de87a09652..19b5e9223a 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,15 +1,7 @@ from collections.abc import Mapping, Sequence from typing import cast -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from graphon.file import file_manager -from graphon.file.models import File +from graphon.file import File, file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, @@ -21,6 +13,14 @@ from graphon.model_runtime.entities import ( from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from graphon.runtime import VariablePool +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser + class AdvancedPromptTransform(PromptTransform): """ diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index 8f1d51f08a..9be70199b7 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -1,10 +1,5 @@ from typing import cast -from core.app.entities.app_invoke_entities import ( - ModelConfigWithCredentialsEntity, -) -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.prompt_transform import PromptTransform from graphon.model_runtime.entities.message_entities import ( PromptMessage, SystemPromptMessage, @@ -12,6 +7,12 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.app.entities.app_invoke_entities import ( + ModelConfigWithCredentialsEntity, +) +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.prompt_transform import PromptTransform + class AgentHistoryPromptTransform(PromptTransform): """ diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 6ff2f44cdc..4539ae9f11 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,11 +1,12 @@ from typing import Any +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from graphon.model_runtime.entities.message_entities import PromptMessage -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey class PromptTransform: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index e091215b80..c706353ffe 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -4,12 +4,6 @@ from collections.abc import Mapping, Sequence from enum import StrEnum, auto from typing import TYPE_CHECKING, Any, cast -from core.app.app_config.entities import PromptTemplateEntity -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import file_manager from graphon.model_runtime.entities.message_entities import ( ImagePromptMessageContent, @@ -19,10 +13,17 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) + +from core.app.app_config.entities import PromptTemplateEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode if TYPE_CHECKING: - from graphon.file.models import File + from graphon.file import File class ModelMode(StrEnum): diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index ba76eb0c4e..dbda749925 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,7 +1,6 @@ from collections.abc import Sequence from typing import Any, cast -from core.prompt.simple_prompt_transform import ModelMode from graphon.model_runtime.entities import ( AssistantPromptMessage, AudioPromptMessageContent, @@ -12,6 +11,8 @@ from graphon.model_runtime.entities import ( TextPromptMessageContent, ) +from core.prompt.simple_prompt_transform import ModelMode + class PromptMessageUtil: @staticmethod diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 79fd78fe80..30933239f6 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -7,6 +7,14 @@ from collections.abc import Sequence from json import JSONDecodeError from typing import TYPE_CHECKING, Any, cast +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -33,14 +41,6 @@ from core.helper.position_helper import is_filtered from extensions import ext_hosting_provider from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ( LoadBalancingModelConfig, Provider, diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 2c81653559..b872ea8a8f 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,3 +1,5 @@ +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from typing_extensions import TypedDict from core.model_manager import ModelInstance, ModelManager @@ -8,8 +10,6 @@ from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_type import RerankMode -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError class RerankingModelDict(TypedDict): diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 1e4aa24287..cc6ec12c75 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, NotRequired from flask import Flask, current_app +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from sqlalchemy.orm import Session, load_only from typing_extensions import TypedDict @@ -24,7 +25,6 @@ from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ( ChildChunk, Dataset, diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index a77458706a..5a8d3a2f3f 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -4,6 +4,7 @@ import time from abc import ABC, abstractmethod from typing import Any +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from configs import dify_config @@ -17,7 +18,6 @@ from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Whitelist from models.model import UploadFile diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 369159767e..e5b794f80d 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -3,13 +3,13 @@ from __future__ import annotations from collections.abc import Sequence from typing import Any +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import func, select from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index b12a0ae2d6..3bdad00712 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -4,6 +4,8 @@ import pickle from typing import Any, cast import numpy as np +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from sqlalchemy.exc import IntegrityError from configs import dify_config @@ -12,8 +14,6 @@ from core.model_manager import ModelInstance from core.rag.embedding.embedding_base import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from libs import helper from models.dataset import Embedding diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 9f36b7a225..5c10ffbf2d 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -8,6 +8,17 @@ from typing import Any, cast logger = logging.getLogger(__name__) +from graphon.file import File, FileTransferMethod, FileType, file_manager +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentUnionTypes, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType + from core.app.file_access import DatabaseFileAccessController from core.app.llm import deduct_llm_quota from core.entities.knowledge_entities import PreviewDetail @@ -31,16 +42,6 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols from core.workflow.file_reference import build_file_reference from extensions.ext_database import db from factories.file_factory import build_from_mapping -from graphon.file import File, FileTransferMethod, FileType, file_manager -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentUnionTypes, - TextPromptMessageContent, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from libs import helper from models import UploadFile from models.account import Account diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 4ebf095904..087736d0b0 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -2,9 +2,8 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any -from pydantic import BaseModel, Field - from graphon.file import File +from pydantic import BaseModel, Field class ChildDocument(BaseModel): diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 6c6b077cc2..211a9f5c5c 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,5 +1,8 @@ import base64 +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.rerank_entities import RerankResult + from core.model_manager import ModelInstance, ModelManager from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.query_type import QueryType @@ -7,8 +10,6 @@ from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner from extensions.ext_database import db from extensions.ext_storage import storage -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.rerank_entities import RerankResult from models.model import UploadFile diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index d0732b269a..49123e13d0 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -2,6 +2,7 @@ import math from collections import Counter import numpy as np +from graphon.model_runtime.entities.model_entities import ModelType from core.model_manager import ModelManager from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler @@ -11,7 +12,6 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner -from graphon.model_runtime.entities.model_entities import ModelType class WeightRerankRunner(BaseRerankRunner): diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 49b91707ec..1abea6639e 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -9,6 +9,11 @@ from collections.abc import Generator, Mapping from typing import Any, Union, cast from flask import Flask, current_app +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import and_, func, literal, or_, select from sqlalchemy.orm import Session @@ -66,11 +71,6 @@ from core.workflow.nodes.knowledge_retrieval.retrieval import ( ) from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.file import File, FileTransferMethod, FileType -from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from libs.helper import parse_uuid_str_or_none from libs.json_in_md_parser import parse_and_check_json_markdown from models import UploadFile diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index e617a9660e..dce7b6226c 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -1,9 +1,10 @@ from typing import Union +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage class FunctionCallMultiDatasetRouter: diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 83e58fe0f9..dd280cdf6a 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -1,6 +1,10 @@ from collections.abc import Generator, Sequence from typing import Union +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.llm import deduct_llm_quota from core.model_manager import ModelInstance, ModelManager @@ -8,9 +12,6 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from graphon.model_runtime.entities.model_entities import ModelType PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 2c27ac3cf6..e6aec4a3af 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -6,6 +6,8 @@ import codecs import re from typing import Any +from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer + from core.model_manager import ModelInstance from core.rag.splitter.text_splitter import ( TS, @@ -15,7 +17,6 @@ from core.rag.splitter.text_splitter import ( Set, Union, ) -from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index d0164b76dc..465f43da73 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -8,11 +8,11 @@ providing improved performance by offloading database operations to background w import logging from typing import Union +from graphon.entities import WorkflowExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository -from graphon.entities.workflow_execution import WorkflowExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 52361cf6dc..22ef44b3dc 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -9,6 +9,7 @@ import logging from collections.abc import Sequence from typing import Union +from graphon.entities import WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -16,7 +17,6 @@ from core.repositories.factory import ( OrderConfig, WorkflowNodeExecutionRepository, ) -from graphon.entities.workflow_node_execution import WorkflowNodeExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index dafdbf641a..ed6d44f434 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -9,11 +9,11 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Literal, Protocol, Union +from graphon.entities import WorkflowExecution, WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from configs import dify_config -from graphon.entities import WorkflowExecution, WorkflowNodeExecution from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 02625e242f..72d9394149 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -4,6 +4,8 @@ from collections.abc import Mapping, Sequence from datetime import datetime from typing import Any, Protocol +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import select from sqlalchemy.orm import Session, selectinload @@ -17,8 +19,6 @@ from core.workflow.human_input_compat import ( InteractiveSurfaceDeliveryMethod, is_human_input_webapp_enabled, ) -from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.account import Account, TenantAccountJoin diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 1ee5d4ae77..85d20b675d 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -6,13 +6,13 @@ import json import logging from typing import Union +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus, WorkflowType -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 749ab44a14..a72bfa378b 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -10,6 +10,10 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, TypeVar, Union import psycopg2.errors +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import UnaryExpression, asc, desc, select from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError @@ -19,10 +23,6 @@ from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_att from configs import dify_config from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from extensions.ext_storage import storage -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from libs.uuid_utils import uuidv7 from models import ( diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index 40bf2e98c2..e539074303 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -2,14 +2,15 @@ import io from collections.abc import Generator from typing import Any +from graphon.file import FileType +from graphon.file.file_manager import download +from graphon.model_runtime.entities.model_entities import ModelType + from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from graphon.file.enums import FileType -from graphon.file.file_manager import download -from graphon.model_runtime.entities.model_entities import ModelType from services.model_provider_service import ModelProviderService diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index ac3820f1ab..f49c669fe0 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -2,12 +2,13 @@ import io from collections.abc import Generator from typing import Any +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType + from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from services.model_provider_service import ModelProviderService diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index d41503e1e6..14af63a962 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -1,11 +1,12 @@ from __future__ import annotations +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage + from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from graphon.model_runtime.entities.llm_entities import LLMResult -from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language and you can quickly aimed at the main point of an webpage and reproduce it in your own words but diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 168e5f4493..0a2c37c563 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -6,6 +6,7 @@ from typing import Any, Union from urllib.parse import urlencode import httpx +from graphon.file.file_manager import download from core.helper import ssrf_proxy from core.tools.__base.tool import Tool @@ -13,7 +14,6 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError -from graphon.file.file_manager import download API_TOOL_DEFAULT_TIMEOUT = ( int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")), diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 08640befb4..d5d3d1b1d9 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -2,6 +2,7 @@ from collections.abc import Mapping from datetime import datetime from typing import Any, Literal +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration @@ -9,7 +10,6 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from graphon.model_runtime.utils.encoders import jsonable_encoder class ToolApiEntity(BaseModel): diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 00fc8a8282..f6d09472b3 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -6,6 +6,8 @@ import logging from collections.abc import Generator, Mapping from typing import Any, cast +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata + from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPConnectionError from core.mcp.types import ( @@ -21,7 +23,6 @@ from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError -from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata logger = logging.getLogger(__name__) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 1fd259f3bb..685d687d8c 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -7,6 +7,7 @@ from datetime import UTC, datetime from mimetypes import guess_type from typing import Any, Union, cast +from graphon.file import FileTransferMethod, FileType from yarl import URL from core.app.entities.app_invoke_entities import InvokeFrom @@ -32,8 +33,6 @@ from core.tools.errors import ( from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db -from graphon.file import FileType -from graphon.file.models import FileTransferMethod from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import Message, MessageFile diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 2ec292602c..7ac29cf069 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -10,13 +10,13 @@ from typing import Union from uuid import uuid4 import httpx +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from configs import dify_config from core.db.session_factory import session_factory from core.helper import ssrf_proxy from core.workflow.file_reference import build_file_reference from extensions.ext_storage import storage -from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from models.model import MessageFile from models.tools import ToolFile diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 4a10c7e23e..584bae39b9 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -8,6 +8,7 @@ from threading import Lock from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict, Union, cast import sqlalchemy as sa +from graphon.runtime import VariablePool from sqlalchemy import select from sqlalchemy.orm import Session from yarl import URL @@ -25,7 +26,6 @@ from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from extensions.ext_database import db -from graphon.runtime.variable_pool import VariablePool from models.provider_ids import ToolProviderID from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService @@ -33,6 +33,8 @@ from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: pass +from graphon.model_runtime.utils.encoders import jsonable_encoder + from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source @@ -56,7 +58,6 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool -from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index dad5133a7a..6a77fda7ef 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,6 +1,7 @@ import threading from flask import Flask, current_app +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import select @@ -15,7 +16,6 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model: DefaultRetrievalModelDict = { diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 5cf46b2564..bb5b3ba76e 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -8,11 +8,11 @@ from uuid import UUID import numpy as np import pytz +from graphon.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager from core.workflow.file_reference import parse_file_reference -from graphon.file import File, FileTransferMethod, FileType from libs.login import current_user from models import Account diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 9e1d41cb39..8d6f83dc07 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -8,9 +8,6 @@ import json from decimal import Decimal from typing import cast -from core.model_manager import ModelManager -from core.tools.entities.tool_entities import ToolProviderType -from extensions.ext_database import db from graphon.model_runtime.entities.llm_entities import LLMResult from graphon.model_runtime.entities.message_entities import PromptMessage from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType @@ -23,6 +20,10 @@ from graphon.model_runtime.errors.invoke import ( ) from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from graphon.model_runtime.utils.encoders import jsonable_encoder + +from core.model_manager import ModelManager +from core.tools.entities.tool_entities import ToolProviderType +from extensions.ext_database import db from models.tools import ToolModelInvoke diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 1e4f3ed2a7..c4b7d57449 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -1,12 +1,13 @@ from collections.abc import Mapping, Sequence from typing import Any -from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration -from core.tools.errors import WorkflowToolHumanInputNotSupportedError from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.entities import OutputVariableEntity from graphon.variables.input_entities import VariableEntity +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration +from core.tools.errors import WorkflowToolHumanInputNotSupportedError + class WorkflowToolConfigurationUtils: @classmethod diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 716368c191..f48b24be30 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -2,6 +2,7 @@ from __future__ import annotations from collections.abc import Mapping +from graphon.variables.input_entities import VariableEntity, VariableEntityType from pydantic import Field from sqlalchemy.orm import Session @@ -23,7 +24,6 @@ from core.tools.entities.tool_entities import ( from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db -from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 495fcd48b3..a3fb4eda92 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -5,6 +5,8 @@ import logging from collections.abc import Generator, Mapping, Sequence from typing import Any, cast +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from sqlalchemy import select from core.app.file_access import DatabaseFileAccessController @@ -20,8 +22,6 @@ from core.tools.entities.tool_entities import ( from core.tools.errors import ToolInvokeError from core.workflow.file_reference import resolve_file_record_id from factories.file_factory import build_from_mapping -from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod -from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from models import Account, Tenant from models.model import App, EndUser from models.utils.file_input_compat import build_file_from_stored_mapping diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py index 24c1271488..61d1cd8540 100644 --- a/api/core/trigger/debug/event_selectors.py +++ b/api/core/trigger/debug/event_selectors.py @@ -8,6 +8,7 @@ from collections.abc import Mapping from datetime import datetime from typing import Any +from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from core.plugin.entities.request import TriggerInvokeEventResponse @@ -27,7 +28,6 @@ from core.trigger.debug.events import ( from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig from extensions.ext_redis import redis_client -from graphon.entities.graph_config import NodeConfigDict from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.schedule_utils import calculate_next_run_at from models.model import App diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_compat.py index 75a0a0c202..c95516a240 100644 --- a/api/core/workflow/human_input_compat.py +++ b/api/core/workflow/human_input_compat.py @@ -14,13 +14,12 @@ from typing import Annotated, Any, ClassVar, Literal import bleach import markdown -from markdown.extensions.tables import TableExtension -from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter - from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.variable_template_parser import VariableTemplateParser from graphon.runtime import VariablePool from graphon.variables.consts import SELECTORS_LENGTH +from markdown.extensions.tables import TableExtension +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter class DeliveryMethodType(enum.StrEnum): diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 028e38fbee..8cc21d2cd9 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -4,6 +4,22 @@ from collections.abc import Callable, Iterator, Mapping, MutableMapping from functools import lru_cache from typing import TYPE_CHECKING, Any, TypeAlias, cast, final +from graphon.entities.base_node_data import BaseNodeData +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.file.file_manager import file_manager +from graphon.graph.graph import NodeFactory +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.nodes.base.node import Node +from graphon.nodes.code.code_node import WorkflowCodeExecutor +from graphon.nodes.code.entities import CodeLanguage +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.nodes.document_extractor import UnstructuredApiConfig +from graphon.nodes.http_request import build_http_request_config +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from sqlalchemy import select from sqlalchemy.orm import Session from typing_extensions import override @@ -40,22 +56,6 @@ from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer from extensions.ext_database import db -from graphon.entities.base_node_data import BaseNodeData -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.file.file_manager import file_manager -from graphon.graph.graph import NodeFactory -from graphon.model_runtime.memory import PromptMessageMemory -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.nodes.base.node import Node -from graphon.nodes.code.code_node import WorkflowCodeExecutor -from graphon.nodes.code.entities import CodeLanguage -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.nodes.document_extractor import UnstructuredApiConfig -from graphon.nodes.http_request import build_http_request_config -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from models.model import Conversation if TYPE_CHECKING: diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index 2e632e56f0..19cb3a7b0a 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -4,32 +4,6 @@ from collections.abc import Callable, Generator, Mapping, Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any, cast -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.app.file_access import DatabaseFileAccessController -from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output -from core.model_manager import ModelInstance -from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError -from core.plugin.impl.plugin import PluginInstaller -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.repositories.human_input_repository import ( - FormCreateParams, - HumanInputFormRepository, - HumanInputFormRepositoryImpl, -) -from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType -from core.tools.errors import ToolInvokeError -from core.tools.tool_engine import ToolEngine -from core.tools.tool_file_manager import ToolFileManager -from core.tools.tool_manager import ToolManager -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.file_reference import build_file_reference -from extensions.ext_database import db -from factories import file_factory from graphon.file import FileTransferMethod, FileType from graphon.model_runtime.entities import LLMMode from graphon.model_runtime.entities.llm_entities import ( @@ -60,6 +34,32 @@ from graphon.nodes.tool_runtime_entities import ( ToolRuntimeMessage, ToolRuntimeParameter, ) +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.file_access import DatabaseFileAccessController +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output +from core.model_manager import ModelInstance +from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError +from core.plugin.impl.plugin import PluginInstaller +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormRepository, + HumanInputFormRepositoryImpl, +) +from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType +from core.tools.errors import ToolInvokeError +from core.tools.tool_engine import ToolEngine +from core.tools.tool_file_manager import ToolFileManager +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.file_reference import build_file_reference +from extensions.ext_database import db +from factories import file_factory from models.dataset import SegmentAttachmentBinding from models.model import UploadFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService @@ -76,12 +76,13 @@ from .human_input_compat import ( from .system_variables import SystemVariableKey, get_system_text if TYPE_CHECKING: - from core.tools.__base.tool import Tool - from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage from graphon.file import File from graphon.nodes.llm.file_saver import LLMFileSaver from graphon.nodes.tool.entities import ToolNodeData + from core.tools.__base.tool import Tool + from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage + _file_access_controller = DatabaseFileAccessController() diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 7b000101b0..bfd5536e4a 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -3,14 +3,15 @@ from __future__ import annotations from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.workflow.system_variables import SystemVariableKey, get_system_text from graphon.entities.graph_config import NodeConfigDict from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent from graphon.nodes.base.node import Node from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.workflow.system_variables import SystemVariableKey, get_system_text + from .entities import AgentNodeData from .exceptions import ( AgentInvocationError, diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 51452c29a3..c52aad150b 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -1,12 +1,12 @@ from enum import IntEnum, StrEnum, auto from typing import Any, Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType class AgentNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/agent/message_transformer.py b/api/core/workflow/nodes/agent/message_transformer.py index f44681377d..db74590ed7 100644 --- a/api/core/workflow/nodes/agent/message_transformer.py +++ b/api/core/workflow/nodes/agent/message_transformer.py @@ -3,14 +3,6 @@ from __future__ import annotations from collections.abc import Generator, Mapping from typing import Any, cast -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.app.file_access import DatabaseFileAccessController -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from extensions.ext_database import db -from factories import file_factory from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata @@ -23,6 +15,14 @@ from graphon.node_events import ( StreamCompletedEvent, ) from graphon.variables.segments import ArrayFileSegment +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.file_access import DatabaseFileAccessController +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from extensions.ext_database import db +from factories import file_factory from models import ToolFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py index a872774c98..be50edbc4d 100644 --- a/api/core/workflow/nodes/agent/runtime_support.py +++ b/api/core/workflow/nodes/agent/runtime_support.py @@ -4,6 +4,8 @@ import json from collections.abc import Sequence from typing import Any, cast +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.runtime import VariablePool from packaging.version import Version from pydantic import ValidationError from sqlalchemy import select @@ -19,8 +21,6 @@ from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolP from core.tools.tool_manager import ToolManager from core.workflow.system_variables import SystemVariableKey, get_system_text from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.runtime import VariablePool from models.model import Conversation from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 38f39b3f94..d9247b2593 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,18 +1,23 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( + BuiltinNodeTypes, + NodeExecutionType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.node_events import NodeRunResult, StreamCompletedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser + from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError from core.workflow.file_reference import resolve_file_record_id from core.workflow.system_variables import SystemVariableKey, get_system_segment -from graphon.entities.graph_config import NodeConfigDict -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionMetadataKey -from graphon.node_events import NodeRunResult, StreamCompletedEvent -from graphon.nodes.base.node import Node -from graphon.nodes.base.variable_template_parser import VariableTemplateParser from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam from .exc import DatasourceNodeError diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index 28966f2392..cad32f8d5b 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -1,10 +1,9 @@ from typing import Any, Literal, Union -from pydantic import BaseModel, field_validator -from pydantic_core.core_schema import ValidationInfo - from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType +from pydantic import BaseModel, field_validator +from pydantic_core.core_schema import ValidationInfo class DatasourceEntity(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 11339bb122..cba6c12dca 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -1,12 +1,12 @@ from typing import Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType class RerankingModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index b465a2d8ff..bb72fe3881 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -2,17 +2,17 @@ import logging from collections.abc import Mapping from typing import TYPE_CHECKING, Any +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template + from core.rag.index_processor.index_processor import IndexProcessor from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text -from graphon.entities.graph_config import NodeConfigDict -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import NodeExecutionType -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.base.template import Template from .entities import KnowledgeIndexNodeData from .exc import ( diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 3f7cc364d3..b1fa8593ef 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,11 +1,10 @@ from collections.abc import Sequence from typing import Literal -from pydantic import BaseModel, Field - from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.llm.entities import ModelConfig, VisionConfig +from pydantic import BaseModel, Field class RerankingModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 117f426ade..13624b27b3 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -8,11 +8,6 @@ import logging from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal -from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.workflow.file_reference import parse_file_reference from graphon.entities import GraphInitParams from graphon.entities.graph_config import NodeConfigDict from graphon.enums import ( @@ -32,6 +27,12 @@ from graphon.variables import ( ) from graphon.variables.segments import ArrayObjectSegment +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.workflow.file_reference import parse_file_reference + from .entities import ( Condition, KnowledgeRetrievalNodeData, @@ -44,7 +45,7 @@ from .exc import ( from .retrieval import KnowledgeRetrievalRequest, Source if TYPE_CHECKING: - from graphon.file.models import File + from graphon.file import File from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) diff --git a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py index ea45dcf5c2..39e2008a2c 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py +++ b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py @@ -1,10 +1,10 @@ from typing import Any, Literal, Protocol +from graphon.model_runtime.entities import LLMUsage +from graphon.nodes.llm.entities import ModelConfig from pydantic import BaseModel, Field from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from graphon.model_runtime.entities import LLMUsage -from graphon.nodes.llm.entities import ModelConfig from .entities import MetadataFilteringCondition diff --git a/api/core/workflow/nodes/trigger_plugin/entities.py b/api/core/workflow/nodes/trigger_plugin/entities.py index 23ed2cd408..bf5be2379a 100644 --- a/api/core/workflow/nodes/trigger_plugin/entities.py +++ b/api/core/workflow/nodes/trigger_plugin/entities.py @@ -1,12 +1,12 @@ from collections.abc import Mapping from typing import Any, Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel, Field, ValidationInfo, field_validator from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.entities.entities import EventParameter -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType from .exc import TriggerEventParameterError diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index a2c952a899..e50de11bb9 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -1,13 +1,13 @@ from collections.abc import Mapping from typing import Any -from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult from graphon.nodes.base.node import Node +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID + from .entities import TriggerEventNodeData diff --git a/api/core/workflow/nodes/trigger_schedule/entities.py b/api/core/workflow/nodes/trigger_schedule/entities.py index 207c1e7253..f14ca893c9 100644 --- a/api/core/workflow/nodes/trigger_schedule/entities.py +++ b/api/core/workflow/nodes/trigger_schedule/entities.py @@ -1,10 +1,10 @@ from typing import Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel, Field from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType class TriggerScheduleNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index dd80617dfc..a9753ab387 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,11 +1,11 @@ from collections.abc import Mapping +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node + from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import NodeExecutionType -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node from .entities import TriggerScheduleNodeData diff --git a/api/core/workflow/nodes/trigger_webhook/entities.py b/api/core/workflow/nodes/trigger_webhook/entities.py index 3125fe17e6..4d5ad72154 100644 --- a/api/core/workflow/nodes/trigger_webhook/entities.py +++ b/api/core/workflow/nodes/trigger_webhook/entities.py @@ -1,12 +1,12 @@ from collections.abc import Sequence from enum import StrEnum -from pydantic import BaseModel, Field, field_validator - -from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from graphon.entities.base_node_data import BaseNodeData from graphon.enums import NodeType from graphon.variables.types import SegmentType +from pydantic import BaseModel, Field, field_validator + +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE _WEBHOOK_HEADER_ALLOWED_TYPES = frozenset( { diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 6858d6dc35..ebaac93934 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -2,12 +2,7 @@ import logging from collections.abc import Mapping from typing import Any -from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from core.workflow.file_reference import resolve_file_record_id -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from factories.variable_factory import build_segment_with_type -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import NodeExecutionType +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus from graphon.file import FileTransferMethod from graphon.node_events import NodeRunResult from graphon.nodes.base.node import Node @@ -15,6 +10,11 @@ from graphon.nodes.protocols import FileReferenceFactoryProtocol from graphon.variables.types import SegmentType from graphon.variables.variables import FileVariable +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from factories.variable_factory import build_segment_with_type + from .entities import ContentType, WebhookData logger = logging.getLogger(__name__) diff --git a/api/core/workflow/template_rendering.py b/api/core/workflow/template_rendering.py index b4ffb37549..d51cfadd09 100644 --- a/api/core/workflow/template_rendering.py +++ b/api/core/workflow/template_rendering.py @@ -3,10 +3,11 @@ from __future__ import annotations from collections.abc import Mapping from typing import Any -from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor from graphon.nodes.code.entities import CodeLanguage from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor + class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): """Sandbox-backed Jinja2 renderer for workflow-owned node composition.""" diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 7429c95c7c..2346a95d6a 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -3,6 +3,20 @@ import time from collections.abc import Generator, Mapping, Sequence from typing import Any +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file import File +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import CommandChannel, InMemoryChannel +from graphon.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool + from configs import dify_config from context import capture_current_context from core.app.apps.exc import GenerateTaskStoppedError @@ -21,20 +35,6 @@ from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add from core.workflow.variable_prefixes import ENVIRONMENT_VARIABLE_NODE_ID from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory -from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.errors import WorkflowNodeRunFailedError -from graphon.file.models import File -from graphon.graph import Graph -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer -from graphon.graph_engine.protocols.command_channel import CommandChannel -from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.base.node import Node -from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow logger = logging.getLogger(__name__) diff --git a/api/enterprise/telemetry/draft_trace.py b/api/enterprise/telemetry/draft_trace.py index dff558988c..5a8d0ee6f4 100644 --- a/api/enterprise/telemetry/draft_trace.py +++ b/api/enterprise/telemetry/draft_trace.py @@ -3,9 +3,10 @@ from __future__ import annotations from collections.abc import Mapping from typing import Any +from graphon.enums import WorkflowNodeExecutionMetadataKey + from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName from core.telemetry import emit as telemetry_emit -from graphon.enums import WorkflowNodeExecutionMetadataKey from models.workflow import WorkflowNodeExecutionModel diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index ba9758175f..7bd8e88231 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,11 +1,12 @@ import logging +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.tool.entities import ToolEntity + from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_draft_workflow_was_synced -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.tool.entities import ToolEntity logger = logging.getLogger(__name__) diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 6769b94cde..86b5b2bbf0 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,11 +1,11 @@ from typing import cast +from graphon.nodes import BuiltinNodeTypes from sqlalchemy import delete, select from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db -from graphon.nodes import BuiltinNodeTypes from models.dataset import AppDatasetJoin from models.workflow import Workflow diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 120febecfb..651f8ed898 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,13 +5,12 @@ from dify_app import DifyApp def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import sentry_sdk + from graphon.model_runtime.errors.invoke import InvokeRateLimitError from langfuse import parse_error from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException - from graphon.model_runtime.errors.invoke import InvokeRateLimitError - def before_send(event, hint): if "exc_info" in hint: _, exc_value, _ = hint["exc_info"] diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index 64ff0f0674..db599c5d49 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -11,12 +11,12 @@ from collections.abc import Sequence from datetime import datetime from typing import Any +from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value -from graphon.enums import WorkflowNodeExecutionStatus from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 5208f8f37e..3c83ab4f84 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -20,12 +20,12 @@ from collections.abc import Sequence from datetime import datetime from typing import Any, cast +from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string -from graphon.enums import WorkflowExecutionStatus from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index ea4a2b3dd1..f71b2fa1df 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -4,14 +4,14 @@ import os import time from typing import Union +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from extensions.logstore.aliyun_logstore import AliyunLogStore -from graphon.entities import WorkflowExecution -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index 976b5db8e3..b725436681 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -13,6 +13,10 @@ from collections.abc import Sequence from datetime import datetime from typing import Any, Union +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -21,10 +25,6 @@ from core.repositories.factory import OrderConfig, WorkflowNodeExecutionReposito from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier -from graphon.entities import WorkflowNodeExecution -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index eefcaa126e..23d324f9ea 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -10,17 +10,17 @@ Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when import json from typing import Any, Protocol +from graphon.enums import BuiltinNodeTypes +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel from configs import dify_config from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes -from graphon.enums import BuiltinNodeTypes -from graphon.file.models import File -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.variables import Segment def should_include_content() -> bool: diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py index ec3c78a12d..335c5cc29e 100644 --- a/api/extensions/otel/parser/llm.py +++ b/api/extensions/otel/parser/llm.py @@ -6,12 +6,12 @@ import logging from collections.abc import Mapping from typing import Any +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import LLMAttributes -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py index 56672d1fd4..6df5f62c15 100644 --- a/api/extensions/otel/parser/retrieval.py +++ b/api/extensions/otel/parser/retrieval.py @@ -6,13 +6,13 @@ import logging from collections.abc import Sequence from typing import Any +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import RetrieverAttributes -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.variables import Segment logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py index 75ddbba448..b9fdd9e1ca 100644 --- a/api/extensions/otel/parser/tool.py +++ b/api/extensions/otel/parser/tool.py @@ -2,14 +2,14 @@ Parser for tool nodes that captures tool-specific metadata. """ -from opentelemetry.trace import Span - -from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps -from extensions.otel.semconv.gen_ai import ToolAttributes from graphon.enums import WorkflowNodeExecutionMetadataKey from graphon.graph_events import GraphNodeEventBase from graphon.nodes.base.node import Node from graphon.nodes.tool.entities import ToolNodeData +from opentelemetry.trace import Span + +from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps +from extensions.otel.semconv.gen_ai import ToolAttributes class ToolNodeOTelParser: diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py index bc87510d43..7516d18c8e 100644 --- a/api/factories/file_factory/builders.py +++ b/api/factories/file_factory/builders.py @@ -7,13 +7,12 @@ import uuid from collections.abc import Mapping, Sequence from typing import Any +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type from sqlalchemy import select from core.app.file_access import FileAccessControllerProtocol from core.workflow.file_reference import build_file_reference from extensions.ext_database import db -from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers -from graphon.file.file_factory import standardize_file_type from models import ToolFile, UploadFile from .common import resolve_mapping_file_id diff --git a/api/factories/file_factory/message_files.py b/api/factories/file_factory/message_files.py index 4b3d514238..5582b85c95 100644 --- a/api/factories/file_factory/message_files.py +++ b/api/factories/file_factory/message_files.py @@ -4,8 +4,9 @@ from __future__ import annotations from collections.abc import Sequence -from core.app.file_access import FileAccessControllerProtocol from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig + +from core.app.file_access import FileAccessControllerProtocol from models import MessageFile from .builders import build_from_mapping diff --git a/api/factories/file_factory/storage_keys.py b/api/factories/file_factory/storage_keys.py index dba4c84407..db3a7f3015 100644 --- a/api/factories/file_factory/storage_keys.py +++ b/api/factories/file_factory/storage_keys.py @@ -5,12 +5,12 @@ from __future__ import annotations import uuid from collections.abc import Mapping, Sequence +from graphon.file import File, FileTransferMethod from sqlalchemy import select from sqlalchemy.orm import Session from core.app.file_access import FileAccessControllerProtocol from core.workflow.file_reference import build_file_reference, parse_file_reference -from graphon.file import File, FileTransferMethod from models import ToolFile, UploadFile diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index fd7acb14d3..57205b5739 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -8,11 +8,6 @@ shared conversion functions for legacy callers and tests. from collections.abc import Mapping, Sequence from typing import Any, cast -from configs import dify_config -from core.workflow.variable_prefixes import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, -) from graphon.variables.exc import VariableError from graphon.variables.factory import ( TypeMismatchError, @@ -36,6 +31,12 @@ from graphon.variables.variables import ( VariableBase, ) +from configs import dify_config +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, +) + __all__ = [ "TypeMismatchError", "UnsupportedSegmentTypeError", diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 801949747e..30d02aeedc 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -3,9 +3,8 @@ from __future__ import annotations from datetime import datetime from typing import Any, TypeAlias -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator - from graphon.file import File +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator JSONValue: TypeAlias = Any diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 4e201e66e6..b8daa5af30 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -3,9 +3,8 @@ from __future__ import annotations from datetime import datetime from flask_restx import fields -from pydantic import BaseModel, ConfigDict, computed_field, field_validator - from graphon.file import helpers as file_helpers +from pydantic import BaseModel, ConfigDict, computed_field, field_validator simple_account_fields = { "id": fields.String, diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 86c4f285cd..d982c31aee 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -4,11 +4,11 @@ from datetime import datetime from typing import TypeAlias from uuid import uuid4 +from graphon.file import File from pydantic import BaseModel, ConfigDict, Field, field_validator from core.entities.execution_extra_content import ExecutionExtraContentDomainModel from fields.conversation_fields import AgentThought, JSONValue, MessageFile -from graphon.file import File JSONValueType: TypeAlias = JSONValue diff --git a/api/fields/raws.py b/api/fields/raws.py index ee6f53b360..4c65cdab7a 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,5 +1,4 @@ from flask_restx import fields - from graphon.file import File diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index f9b5e98936..b0b6cc0b48 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,8 +1,8 @@ from flask_restx import fields +from graphon.variables import SecretVariable, SegmentType, VariableBase from core.helper import encrypter from fields.member_fields import simple_account_fields -from graphon.variables import SecretVariable, SegmentType, VariableBase from libs.helper import TimestampField from ._value_type_serializer import serialize_value_type diff --git a/api/graphon/README.md b/api/graphon/README.md deleted file mode 100644 index 725f122cd8..0000000000 --- a/api/graphon/README.md +++ /dev/null @@ -1,135 +0,0 @@ -# Workflow - -## Project Overview - -This is the workflow graph engine module of Dify, implementing a queue-based distributed workflow execution system. The engine handles agentic AI workflows with support for parallel execution, node iteration, conditional logic, and external command control. - -## Architecture - -### Core Components - -The graph engine follows a layered architecture with strict dependency rules: - -1. **Graph Engine** (`graph_engine/`) - Orchestrates workflow execution - - - **Manager** - External control interface for stop/pause/resume commands - - **Worker** - Node execution runtime - - **Command Processing** - Handles control commands (abort, pause, resume) - - **Event Management** - Event propagation and layer notifications - - **Graph Traversal** - Edge processing and skip propagation - - **Response Coordinator** - Path tracking and session management - - **Layers** - Pluggable middleware (debug logging, execution limits) - - **Command Channels** - Communication channels (InMemory, Redis) - -1. **Graph** (`graph/`) - Graph structure and runtime state - - - **Graph Template** - Workflow definition - - **Edge** - Node connections with conditions - - **Runtime State Protocol** - State management interface - -1. **Nodes** (`nodes/`) - Node implementations - - - **Base** - Abstract node classes and variable parsing - - **Specific Nodes** - LLM, Agent, Code, HTTP Request, Iteration, Loop, etc. - -1. **Events** (`node_events/`) - Event system - - - **Base** - Event protocols - - **Node Events** - Node lifecycle events - -1. **Entities** (`entities/`) - Domain models - - - **Variable Pool** - Variable storage - - **Graph Init Params** - Initialization configuration - -## Key Design Patterns - -### Command Channel Pattern - -External workflow control via Redis or in-memory channels: - -```python -# Send stop command to running workflow -channel = RedisChannel(redis_client, f"workflow:{task_id}:commands") -channel.send_command(AbortCommand(reason="User requested")) -``` - -### Layer System - -Extensible middleware for cross-cutting concerns: - -```python -engine = GraphEngine(graph) -engine.layer(DebugLoggingLayer(level="INFO")) -engine.layer(ExecutionLimitsLayer(max_nodes=100)) -``` - -`engine.layer()` binds the read-only runtime state before execution, so layer hooks -can assume `graph_runtime_state` is available. - -### Event-Driven Architecture - -All node executions emit events for monitoring and integration: - -- `NodeRunStartedEvent` - Node execution begins -- `NodeRunSucceededEvent` - Node completes successfully -- `NodeRunFailedEvent` - Node encounters error -- `GraphRunStartedEvent/GraphRunCompletedEvent` - Workflow lifecycle - -### Variable Pool - -Centralized variable storage with namespace isolation: - -```python -# Variables scoped by node_id -pool.add(["node1", "output"], value) -result = pool.get(["node1", "output"]) -``` - -## Import Architecture Rules - -The codebase enforces strict layering via import-linter: - -1. **Workflow Layers** (top to bottom): - - - graph_engine → graph_events → graph → nodes → node_events → entities - -1. **Graph Engine Internal Layers**: - - - orchestration → command_processing → event_management → graph_traversal → domain - -1. **Domain Isolation**: - - - Domain models cannot import from infrastructure layers - -1. **Command Channel Independence**: - - - InMemory and Redis channels must remain independent - -## Common Tasks - -### Adding a New Node Type - -1. Create node class in `nodes//` -1. Inherit from `BaseNode` or appropriate base class -1. Implement `_run()` method -1. Ensure the node module is importable under `nodes//` -1. Add tests in `tests/unit_tests/graphon/nodes/` - -### Implementing a Custom Layer - -1. Create class inheriting from `Layer` base -1. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()` -1. Add to engine via `engine.layer()` - -### Debugging Workflow Execution - -Enable debug logging layer: - -```python -debug_layer = DebugLoggingLayer( - level="DEBUG", - include_inputs=True, - include_outputs=True -) -``` diff --git a/api/graphon/__init__.py b/api/graphon/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/entities/__init__.py b/api/graphon/entities/__init__.py deleted file mode 100644 index ef7789c49c..0000000000 --- a/api/graphon/entities/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .graph_init_params import GraphInitParams -from .workflow_execution import WorkflowExecution -from .workflow_node_execution import WorkflowNodeExecution -from .workflow_start_reason import WorkflowStartReason - -__all__ = [ - "GraphInitParams", - "WorkflowExecution", - "WorkflowNodeExecution", - "WorkflowStartReason", -] diff --git a/api/graphon/entities/base_node_data.py b/api/graphon/entities/base_node_data.py deleted file mode 100644 index e8267043a9..0000000000 --- a/api/graphon/entities/base_node_data.py +++ /dev/null @@ -1,178 +0,0 @@ -from __future__ import annotations - -import json -from abc import ABC -from builtins import type as type_ -from enum import StrEnum -from typing import Any, Union - -from pydantic import BaseModel, ConfigDict, Field, model_validator - -from graphon.entities.exc import DefaultValueTypeError -from graphon.enums import ErrorStrategy, NodeType - -# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`. -_NumberType = Union[int, float] - - -class RetryConfig(BaseModel): - """node retry config""" - - max_retries: int = 0 # max retry times - retry_interval: int = 0 # retry interval in milliseconds - retry_enabled: bool = False # whether retry is enabled - - @property - def retry_interval_seconds(self) -> float: - return self.retry_interval / 1000 - - -class DefaultValueType(StrEnum): - STRING = "string" - NUMBER = "number" - OBJECT = "object" - ARRAY_NUMBER = "array[number]" - ARRAY_STRING = "array[string]" - ARRAY_OBJECT = "array[object]" - ARRAY_FILES = "array[file]" - - -class DefaultValue(BaseModel): - value: Any = None - type: DefaultValueType - key: str - - @staticmethod - def _parse_json(value: str): - """Unified JSON parsing handler""" - try: - return json.loads(value) - except json.JSONDecodeError: - raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") - - @staticmethod - def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool: - """Unified array type validation""" - return isinstance(value, list) and all(isinstance(x, element_type) for x in value) - - @staticmethod - def _convert_number(value: str) -> float: - """Unified number conversion handler""" - try: - return float(value) - except ValueError: - raise DefaultValueTypeError(f"Cannot convert to number: {value}") - - @model_validator(mode="after") - def validate_value_type(self) -> DefaultValue: - # Type validation configuration - type_validators: dict[DefaultValueType, dict[str, Any]] = { - DefaultValueType.STRING: { - "type": str, - "converter": lambda x: x, - }, - DefaultValueType.NUMBER: { - "type": _NumberType, - "converter": self._convert_number, - }, - DefaultValueType.OBJECT: { - "type": dict, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_NUMBER: { - "type": list, - "element_type": _NumberType, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_STRING: { - "type": list, - "element_type": str, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_OBJECT: { - "type": list, - "element_type": dict, - "converter": self._parse_json, - }, - } - - validator: dict[str, Any] = type_validators.get(self.type, {}) - if not validator: - if self.type == DefaultValueType.ARRAY_FILES: - # Handle files type - return self - raise DefaultValueTypeError(f"Unsupported type: {self.type}") - - # Handle string input cases - if isinstance(self.value, str) and self.type != DefaultValueType.STRING: - self.value = validator["converter"](self.value) - - # Validate base type - if not isinstance(self.value, validator["type"]): - raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") - - # Validate array element types - if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]): - raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") - - return self - - -class BaseNodeData(ABC, BaseModel): - # Raw graph payloads are first validated through `NodeConfigDictAdapter`, where - # `node["data"]` is typed as `BaseNodeData` before the concrete node class is known. - # `type` therefore accepts downstream string node kinds; unknown node implementations - # are rejected later when the node factory resolves the node registry. - # At that boundary, node-specific fields are still "extra" relative to this shared DTO, - # and persisted templates/workflows also carry undeclared compatibility keys such as - # `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive - # here until graph parsing becomes discriminated by node type or those legacy payloads - # are normalized. - model_config = ConfigDict(extra="allow") - - type: NodeType - title: str = "" - desc: str | None = None - version: str = "1" - error_strategy: ErrorStrategy | None = None - default_value: list[DefaultValue] | None = None - retry_config: RetryConfig = Field(default_factory=RetryConfig) - - @property - def default_value_dict(self) -> dict[str, Any]: - if self.default_value: - return {item.key: item.value for item in self.default_value} - return {} - - def __getitem__(self, key: str) -> Any: - """ - Dict-style access without calling model_dump() on every lookup. - Prefer using model fields and Pydantic's extra storage. - """ - # First, check declared model fields - if key in self.__class__.model_fields: - return getattr(self, key) - - # Then, check undeclared compatibility fields stored in Pydantic's extra dict. - extras = getattr(self, "__pydantic_extra__", None) - if extras is None: - extras = getattr(self, "model_extra", None) - if extras is not None and key in extras: - return extras[key] - - raise KeyError(key) - - def get(self, key: str, default: Any = None) -> Any: - """ - Dict-style .get() without calling model_dump() on every lookup. - """ - if key in self.__class__.model_fields: - return getattr(self, key) - - extras = getattr(self, "__pydantic_extra__", None) - if extras is None: - extras = getattr(self, "model_extra", None) - if extras is not None and key in extras: - return extras.get(key, default) - - return default diff --git a/api/graphon/entities/exc.py b/api/graphon/entities/exc.py deleted file mode 100644 index aeecf40640..0000000000 --- a/api/graphon/entities/exc.py +++ /dev/null @@ -1,10 +0,0 @@ -class BaseNodeError(ValueError): - """Base class for node errors.""" - - pass - - -class DefaultValueTypeError(BaseNodeError): - """Raised when the default value type is invalid.""" - - pass diff --git a/api/graphon/entities/graph_config.py b/api/graphon/entities/graph_config.py deleted file mode 100644 index 392241c631..0000000000 --- a/api/graphon/entities/graph_config.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -import sys - -from pydantic import TypeAdapter, with_config - -from graphon.entities.base_node_data import BaseNodeData - -if sys.version_info >= (3, 12): - from typing import TypedDict -else: - from typing_extensions import TypedDict - - -@with_config(extra="allow") -class NodeConfigDict(TypedDict): - id: str - # This is the permissive raw graph boundary. Node factories re-validate `data` - # with the concrete `NodeData` subtype after resolving the node implementation. - data: BaseNodeData - - -NodeConfigDictAdapter = TypeAdapter(NodeConfigDict) diff --git a/api/graphon/entities/graph_init_params.py b/api/graphon/entities/graph_init_params.py deleted file mode 100644 index f785d58a52..0000000000 --- a/api/graphon/entities/graph_init_params.py +++ /dev/null @@ -1,24 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import BaseModel, Field - -DIFY_RUN_CONTEXT_KEY = "_dify" - - -class GraphInitParams(BaseModel): - """GraphInitParams encapsulates the configurations and contextual information - that remain constant throughout a single execution of the graph engine. - - A single execution is defined as follows: as long as the execution has not reached - its conclusion, it is considered one execution. For instance, if a workflow is suspended - and later resumed, it is still regarded as a single execution, not two. - - For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`. - """ - - # init params - workflow_id: str = Field(..., description="workflow id") - graph_config: Mapping[str, Any] = Field(..., description="graph config") - run_context: Mapping[str, Any] = Field(..., description="runtime context") - call_depth: int = Field(..., description="call depth") diff --git a/api/graphon/entities/pause_reason.py b/api/graphon/entities/pause_reason.py deleted file mode 100644 index ba2973fd45..0000000000 --- a/api/graphon/entities/pause_reason.py +++ /dev/null @@ -1,42 +0,0 @@ -from collections.abc import Mapping -from enum import StrEnum, auto -from typing import Annotated, Any, Literal, TypeAlias - -from pydantic import BaseModel, Field - -from graphon.nodes.human_input.entities import FormInput, UserAction - - -class PauseReasonType(StrEnum): - HUMAN_INPUT_REQUIRED = auto() - SCHEDULED_PAUSE = auto() - - -class HumanInputRequired(BaseModel): - TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED - form_id: str - form_content: str - inputs: list[FormInput] = Field(default_factory=list) - actions: list[UserAction] = Field(default_factory=list) - node_id: str - node_title: str - - # The `resolved_default_values` stores the resolved values of variable defaults. It's a mapping from - # `output_variable_name` to their resolved values. - # - # For example, The form contains a input with output variable name `name` and placeholder type `VARIABLE`, its - # selector is ["start", "name"]. While the HumanInputNode is executed, the correspond value of variable - # `start.name` in variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`. The - # `resolved_default_values` is `{"name": "John"}`. - # - # Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`. - resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) - - -class SchedulingPause(BaseModel): - TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE - - message: str - - -PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")] diff --git a/api/graphon/entities/workflow_execution.py b/api/graphon/entities/workflow_execution.py deleted file mode 100644 index b8de7eed1a..0000000000 --- a/api/graphon/entities/workflow_execution.py +++ /dev/null @@ -1,71 +0,0 @@ -""" -Domain entities for workflow execution. - -Models describe graph runtime state and avoid infrastructure-specific details. -""" - -from __future__ import annotations - -from collections.abc import Mapping -from datetime import UTC, datetime -from typing import Any - -from pydantic import BaseModel, Field - -from graphon.enums import WorkflowExecutionStatus, WorkflowType - - -class WorkflowExecution(BaseModel): - """ - Domain model for a workflow execution within the graph runtime. - """ - - id_: str = Field(...) - workflow_id: str = Field(...) - workflow_version: str = Field(...) - workflow_type: WorkflowType = Field(...) - graph: Mapping[str, Any] = Field(...) - - inputs: Mapping[str, Any] = Field(...) - outputs: Mapping[str, Any] | None = None - - status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING - error_message: str = Field(default="") - total_tokens: int = Field(default=0) - total_steps: int = Field(default=0) - exceptions_count: int = Field(default=0) - - started_at: datetime = Field(...) - finished_at: datetime | None = None - - @property - def elapsed_time(self) -> float: - """ - Calculate elapsed time in seconds. - If workflow is not finished, use current time. - """ - end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None) - return (end_time - self.started_at).total_seconds() - - @classmethod - def new( - cls, - *, - id_: str, - workflow_id: str, - workflow_type: WorkflowType, - workflow_version: str, - graph: Mapping[str, Any], - inputs: Mapping[str, Any], - started_at: datetime, - ) -> WorkflowExecution: - return WorkflowExecution( - id_=id_, - workflow_id=workflow_id, - workflow_type=workflow_type, - workflow_version=workflow_version, - graph=graph, - inputs=inputs, - status=WorkflowExecutionStatus.RUNNING, - started_at=started_at, - ) diff --git a/api/graphon/entities/workflow_node_execution.py b/api/graphon/entities/workflow_node_execution.py deleted file mode 100644 index 5458572e7e..0000000000 --- a/api/graphon/entities/workflow_node_execution.py +++ /dev/null @@ -1,141 +0,0 @@ -""" -Domain entities for workflow node execution. - -These models capture node-level execution state for the graph runtime without -describing storage or application-layer concerns. -""" - -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import BaseModel, Field, PrivateAttr - -from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus - - -class WorkflowNodeExecution(BaseModel): - """ - Domain model for workflow node execution. - - This model represents the graph-level record of a node execution and - contains only execution state relevant to the runtime. - """ - - # --------- Core identification fields --------- - - # Unique identifier for this execution record, used when persisting to storage. - # Value is a UUID string (e.g., '09b3e04c-f9ae-404c-ad82-290b8d7bd382'). - id: str - - # Optional secondary ID for cross-referencing purposes. - # - # NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`. - # While `node_execution_id` may sometimes be a UUID string, this is not guaranteed. - # In most scenarios, `id` should be used as the primary identifier. - node_execution_id: str | None = None - workflow_id: str # ID of the workflow this node belongs to - workflow_execution_id: str | None = None # ID of the workflow execution (null for single-step debugging) - # --------- Core identification fields ends --------- - - # Execution positioning and flow - index: int # Sequence number for ordering in trace visualization - predecessor_node_id: str | None = None # ID of the node that executed before this one - node_id: str # ID of the node being executed - node_type: NodeType # Type of node (e.g., start, llm, downstream response node) - title: str # Display title of the node - - # Execution data - # The `inputs` and `outputs` fields hold the full content - inputs: Mapping[str, Any] | None = None # Input variables used by this node - process_data: Mapping[str, Any] | None = None # Intermediate processing data - outputs: Mapping[str, Any] | None = None # Output variables produced by this node - - # Execution state - status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status - error: str | None = None # Error message if execution failed - elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds - - # Additional metadata - metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # Execution metadata (tokens, cost, etc.) - - # Timing information - created_at: datetime # When execution started - finished_at: datetime | None = None # When execution completed - - _truncated_inputs: Mapping[str, Any] | None = PrivateAttr(None) - _truncated_outputs: Mapping[str, Any] | None = PrivateAttr(None) - _truncated_process_data: Mapping[str, Any] | None = PrivateAttr(None) - - def get_truncated_inputs(self) -> Mapping[str, Any] | None: - return self._truncated_inputs - - def get_truncated_outputs(self) -> Mapping[str, Any] | None: - return self._truncated_outputs - - def get_truncated_process_data(self) -> Mapping[str, Any] | None: - return self._truncated_process_data - - def set_truncated_inputs(self, truncated_inputs: Mapping[str, Any] | None): - self._truncated_inputs = truncated_inputs - - def set_truncated_outputs(self, truncated_outputs: Mapping[str, Any] | None): - self._truncated_outputs = truncated_outputs - - def set_truncated_process_data(self, truncated_process_data: Mapping[str, Any] | None): - self._truncated_process_data = truncated_process_data - - def get_response_inputs(self) -> Mapping[str, Any] | None: - inputs = self.get_truncated_inputs() - if inputs: - return inputs - return self.inputs - - @property - def inputs_truncated(self): - return self._truncated_inputs is not None - - @property - def outputs_truncated(self): - return self._truncated_outputs is not None - - @property - def process_data_truncated(self): - return self._truncated_process_data is not None - - def get_response_outputs(self) -> Mapping[str, Any] | None: - outputs = self.get_truncated_outputs() - if outputs is not None: - return outputs - return self.outputs - - def get_response_process_data(self) -> Mapping[str, Any] | None: - process_data = self.get_truncated_process_data() - if process_data is not None: - return process_data - return self.process_data - - def update_from_mapping( - self, - inputs: Mapping[str, Any] | None = None, - process_data: Mapping[str, Any] | None = None, - outputs: Mapping[str, Any] | None = None, - metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None, - ): - """ - Update the model from mappings. - - Args: - inputs: The inputs to update - process_data: The process data to update - outputs: The outputs to update - metadata: The metadata to update - """ - if inputs is not None: - self.inputs = dict(inputs) - if process_data is not None: - self.process_data = dict(process_data) - if outputs is not None: - self.outputs = dict(outputs) - if metadata is not None: - self.metadata = dict(metadata) diff --git a/api/graphon/entities/workflow_start_reason.py b/api/graphon/entities/workflow_start_reason.py deleted file mode 100644 index df0f75383b..0000000000 --- a/api/graphon/entities/workflow_start_reason.py +++ /dev/null @@ -1,8 +0,0 @@ -from enum import StrEnum - - -class WorkflowStartReason(StrEnum): - """Reason for workflow start events across graph/queue/SSE layers.""" - - INITIAL = "initial" # First start of a workflow run. - RESUMPTION = "resumption" # Start triggered after resuming a paused run. diff --git a/api/graphon/enums.py b/api/graphon/enums.py deleted file mode 100644 index bbc973abe5..0000000000 --- a/api/graphon/enums.py +++ /dev/null @@ -1,262 +0,0 @@ -from enum import StrEnum -from typing import ClassVar, TypeAlias - - -class NodeState(StrEnum): - """State of a node or edge during workflow execution.""" - - UNKNOWN = "unknown" - TAKEN = "taken" - SKIPPED = "skipped" - - -NodeType: TypeAlias = str - - -class BuiltinNodeTypes: - """Built-in node type string constants. - - `node_type` values are plain strings throughout the graph runtime. This namespace - only exposes the built-in values shipped by `graphon`; downstream packages can - use additional strings without extending this class. - """ - - START: ClassVar[NodeType] = "start" - END: ClassVar[NodeType] = "end" - ANSWER: ClassVar[NodeType] = "answer" - LLM: ClassVar[NodeType] = "llm" - KNOWLEDGE_RETRIEVAL: ClassVar[NodeType] = "knowledge-retrieval" - IF_ELSE: ClassVar[NodeType] = "if-else" - CODE: ClassVar[NodeType] = "code" - TEMPLATE_TRANSFORM: ClassVar[NodeType] = "template-transform" - QUESTION_CLASSIFIER: ClassVar[NodeType] = "question-classifier" - HTTP_REQUEST: ClassVar[NodeType] = "http-request" - TOOL: ClassVar[NodeType] = "tool" - DATASOURCE: ClassVar[NodeType] = "datasource" - VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-aggregator" - LEGACY_VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-assigner" - LOOP: ClassVar[NodeType] = "loop" - LOOP_START: ClassVar[NodeType] = "loop-start" - LOOP_END: ClassVar[NodeType] = "loop-end" - ITERATION: ClassVar[NodeType] = "iteration" - ITERATION_START: ClassVar[NodeType] = "iteration-start" - PARAMETER_EXTRACTOR: ClassVar[NodeType] = "parameter-extractor" - VARIABLE_ASSIGNER: ClassVar[NodeType] = "assigner" - DOCUMENT_EXTRACTOR: ClassVar[NodeType] = "document-extractor" - LIST_OPERATOR: ClassVar[NodeType] = "list-operator" - AGENT: ClassVar[NodeType] = "agent" - HUMAN_INPUT: ClassVar[NodeType] = "human-input" - - -BUILT_IN_NODE_TYPES: tuple[NodeType, ...] = ( - BuiltinNodeTypes.START, - BuiltinNodeTypes.END, - BuiltinNodeTypes.ANSWER, - BuiltinNodeTypes.LLM, - BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, - BuiltinNodeTypes.IF_ELSE, - BuiltinNodeTypes.CODE, - BuiltinNodeTypes.TEMPLATE_TRANSFORM, - BuiltinNodeTypes.QUESTION_CLASSIFIER, - BuiltinNodeTypes.HTTP_REQUEST, - BuiltinNodeTypes.TOOL, - BuiltinNodeTypes.DATASOURCE, - BuiltinNodeTypes.VARIABLE_AGGREGATOR, - BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR, - BuiltinNodeTypes.LOOP, - BuiltinNodeTypes.LOOP_START, - BuiltinNodeTypes.LOOP_END, - BuiltinNodeTypes.ITERATION, - BuiltinNodeTypes.ITERATION_START, - BuiltinNodeTypes.PARAMETER_EXTRACTOR, - BuiltinNodeTypes.VARIABLE_ASSIGNER, - BuiltinNodeTypes.DOCUMENT_EXTRACTOR, - BuiltinNodeTypes.LIST_OPERATOR, - BuiltinNodeTypes.AGENT, - BuiltinNodeTypes.HUMAN_INPUT, -) - - -class NodeExecutionType(StrEnum): - """Node execution type classification.""" - - EXECUTABLE = "executable" # Regular nodes that execute and produce outputs - RESPONSE = "response" # Response nodes that stream outputs (Answer, End) - BRANCH = "branch" # Nodes that can choose different branches (if-else, question-classifier) - CONTAINER = "container" # Container nodes that manage subgraphs (iteration, loop, graph) - ROOT = "root" # Nodes that can serve as execution entry points - - -class ErrorStrategy(StrEnum): - FAIL_BRANCH = "fail-branch" - DEFAULT_VALUE = "default-value" - - -class FailBranchSourceHandle(StrEnum): - FAILED = "fail-branch" - SUCCESS = "success-branch" - - -class WorkflowType(StrEnum): - """ - Workflow Type Enum for domain layer - """ - - WORKFLOW = "workflow" - CHAT = "chat" - RAG_PIPELINE = "rag-pipeline" - - -class WorkflowExecutionStatus(StrEnum): - # State diagram for the workflw status: - # (@) means start, (*) means end - # - # ┌------------------>------------------------->------------------->--------------┐ - # | | - # | ┌-----------------------<--------------------┐ | - # ^ | | | - # | | ^ | - # | V | | - # ┌-----------┐ ┌-----------------------┐ ┌-----------┐ V - # | Scheduled |------->| Running |---------------------->| paused | | - # └-----------┘ └-----------------------┘ └-----------┘ | - # | | | | | | | - # | | | | | | | - # ^ | | | V V | - # | | | | | ┌---------┐ | - # (@) | | | └------------------------>| Stopped |<----┘ - # | | | └---------┘ - # | | | | - # | | V V - # | | ┌-----------┐ | - # | | | Succeeded |------------->--------------┤ - # | | └-----------┘ | - # | V V - # | +--------┐ | - # | | Failed |---------------------->----------------┤ - # | └--------┘ | - # V V - # ┌---------------------┐ | - # | Partially Succeeded |---------------------->-----------------┘--------> (*) - # └---------------------┘ - # - # Mermaid diagram: - # - # --- - # title: State diagram for Workflow run state - # --- - # stateDiagram-v2 - # scheduled: Scheduled - # running: Running - # succeeded: Succeeded - # failed: Failed - # partial_succeeded: Partial Succeeded - # paused: Paused - # stopped: Stopped - # - # [*] --> scheduled: - # scheduled --> running: Start Execution - # running --> paused: Human input required - # paused --> running: human input added - # paused --> stopped: User stops execution - # running --> succeeded: Execution finishes without any error - # running --> failed: Execution finishes with errors - # running --> stopped: User stops execution - # running --> partial_succeeded: some execution occurred and handled during execution - # - # scheduled --> stopped: User stops execution - # - # succeeded --> [*] - # failed --> [*] - # partial_succeeded --> [*] - # stopped --> [*] - - # `SCHEDULED` means that the workflow is scheduled to run, but has not - # started running yet. (maybe due to possible worker saturation.) - # - # This enum value is currently unused. - SCHEDULED = "scheduled" - - # `RUNNING` means the workflow is exeuting. - RUNNING = "running" - - # `SUCCEEDED` means the execution of workflow succeed without any error. - SUCCEEDED = "succeeded" - - # `FAILED` means the execution of workflow failed without some errors. - FAILED = "failed" - - # `STOPPED` means the execution of workflow was stopped, either manually - # by the user, or automatically by the Dify application (E.G. the moderation - # mechanism.) - STOPPED = "stopped" - - # `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow - # execution, but they were successfully handled (e.g., by using an error - # strategy such as "fail branch" or "default value"). - PARTIAL_SUCCEEDED = "partial-succeeded" - - # `PAUSED` indicates that the workflow execution is temporarily paused - # (e.g., awaiting human input) and is expected to resume later. - PAUSED = "paused" - - def is_ended(self) -> bool: - return self in _END_STATE - - @classmethod - def ended_values(cls) -> list[str]: - return [status.value for status in _END_STATE] - - -_END_STATE = frozenset( - [ - WorkflowExecutionStatus.SUCCEEDED, - WorkflowExecutionStatus.FAILED, - WorkflowExecutionStatus.PARTIAL_SUCCEEDED, - WorkflowExecutionStatus.STOPPED, - ] -) - - -class WorkflowNodeExecutionMetadataKey(StrEnum): - """ - Node Run Metadata Key. - - Values in this enum are persisted as execution metadata and must stay in sync - with every node that writes `NodeRunResult.metadata`. - """ - - TOTAL_TOKENS = "total_tokens" - TOTAL_PRICE = "total_price" - CURRENCY = "currency" - TOOL_INFO = "tool_info" - AGENT_LOG = "agent_log" - ITERATION_ID = "iteration_id" - ITERATION_INDEX = "iteration_index" - LOOP_ID = "loop_id" - LOOP_INDEX = "loop_index" - PARALLEL_ID = "parallel_id" - PARALLEL_START_NODE_ID = "parallel_start_node_id" - PARENT_PARALLEL_ID = "parent_parallel_id" - PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" - PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" - ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs - LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs - ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field - LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output - DATASOURCE_INFO = "datasource_info" - TRIGGER_INFO = "trigger_info" - COMPLETED_REASON = "completed_reason" # completed reason for loop node - - -class WorkflowNodeExecutionStatus(StrEnum): - PENDING = "pending" # Node is scheduled but not yet executing - RUNNING = "running" - SUCCEEDED = "succeeded" - FAILED = "failed" - EXCEPTION = "exception" - STOPPED = "stopped" - PAUSED = "paused" - - # Legacy statuses - kept for backward compatibility - RETRY = "retry" # Legacy: replaced by retry mechanism in error handling diff --git a/api/graphon/errors.py b/api/graphon/errors.py deleted file mode 100644 index 7eb007524d..0000000000 --- a/api/graphon/errors.py +++ /dev/null @@ -1,16 +0,0 @@ -from graphon.nodes.base.node import Node - - -class WorkflowNodeRunFailedError(Exception): - def __init__(self, node: Node, err_msg: str): - self._node = node - self._error = err_msg - super().__init__(f"Node {node.title} run failed: {err_msg}") - - @property - def node(self) -> Node: - return self._node - - @property - def error(self) -> str: - return self._error diff --git a/api/graphon/file/__init__.py b/api/graphon/file/__init__.py deleted file mode 100644 index 4908ae9795..0000000000 --- a/api/graphon/file/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from .constants import FILE_MODEL_IDENTITY -from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType -from .file_factory import get_file_type_by_mime_type, standardize_file_type -from .models import ( - File, - FileUploadConfig, - ImageConfig, -) - -__all__ = [ - "FILE_MODEL_IDENTITY", - "ArrayFileAttribute", - "File", - "FileAttribute", - "FileBelongsTo", - "FileTransferMethod", - "FileType", - "FileUploadConfig", - "ImageConfig", - "get_file_type_by_mime_type", - "standardize_file_type", -] diff --git a/api/graphon/file/constants.py b/api/graphon/file/constants.py deleted file mode 100644 index 56b95b5f0d..0000000000 --- a/api/graphon/file/constants.py +++ /dev/null @@ -1,48 +0,0 @@ -from collections.abc import Iterable -from typing import Any - -# TODO(QuantumGhost): Refactor variable type identification. Instead of directly -# comparing `dify_model_identity` with constants throughout the codebase, extract -# this logic into a dedicated function. This would encapsulate the implementation -# details of how different variable types are identified. -FILE_MODEL_IDENTITY = "__dify__file__" -DEFAULT_MIME_TYPE = "application/octet-stream" -DEFAULT_EXTENSION = ".bin" - - -def _with_case_variants(extensions: Iterable[str]) -> frozenset[str]: - normalized = {extension.lower() for extension in extensions} - return frozenset(normalized | {extension.upper() for extension in normalized}) - - -IMAGE_EXTENSIONS = _with_case_variants({"jpg", "jpeg", "png", "webp", "gif", "svg"}) -VIDEO_EXTENSIONS = _with_case_variants({"mp4", "mov", "mpeg", "webm"}) -AUDIO_EXTENSIONS = _with_case_variants({"mp3", "m4a", "wav", "amr", "mpga"}) -DOCUMENT_EXTENSIONS = _with_case_variants( - { - "txt", - "markdown", - "md", - "mdx", - "pdf", - "html", - "htm", - "xlsx", - "xls", - "vtt", - "properties", - "doc", - "docx", - "csv", - "eml", - "msg", - "ppt", - "pptx", - "xml", - "epub", - } -) - - -def maybe_file_object(o: Any) -> bool: - return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY diff --git a/api/graphon/file/enums.py b/api/graphon/file/enums.py deleted file mode 100644 index 170eb4fc23..0000000000 --- a/api/graphon/file/enums.py +++ /dev/null @@ -1,57 +0,0 @@ -from enum import StrEnum - - -class FileType(StrEnum): - IMAGE = "image" - DOCUMENT = "document" - AUDIO = "audio" - VIDEO = "video" - CUSTOM = "custom" - - @staticmethod - def value_of(value): - for member in FileType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileTransferMethod(StrEnum): - REMOTE_URL = "remote_url" - LOCAL_FILE = "local_file" - TOOL_FILE = "tool_file" - DATASOURCE_FILE = "datasource_file" - - @staticmethod - def value_of(value): - for member in FileTransferMethod: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileBelongsTo(StrEnum): - USER = "user" - ASSISTANT = "assistant" - - @staticmethod - def value_of(value): - for member in FileBelongsTo: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileAttribute(StrEnum): - TYPE = "type" - SIZE = "size" - NAME = "name" - MIME_TYPE = "mime_type" - TRANSFER_METHOD = "transfer_method" - URL = "url" - EXTENSION = "extension" - RELATED_ID = "related_id" - - -class ArrayFileAttribute(StrEnum): - LENGTH = "length" diff --git a/api/graphon/file/file_factory.py b/api/graphon/file/file_factory.py deleted file mode 100644 index 3d20b9377d..0000000000 --- a/api/graphon/file/file_factory.py +++ /dev/null @@ -1,39 +0,0 @@ -from .constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS -from .enums import FileType - - -def standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType: - """ - Infer the actual file type from extension and mime type. - """ - guessed_type = None - if extension: - guessed_type = _get_file_type_by_extension(extension) - if guessed_type is None and mime_type: - guessed_type = get_file_type_by_mime_type(mime_type) - return guessed_type or FileType.CUSTOM - - -def _get_file_type_by_extension(extension: str) -> FileType | None: - normalized_extension = extension.lstrip(".") - if normalized_extension in IMAGE_EXTENSIONS: - return FileType.IMAGE - if normalized_extension in VIDEO_EXTENSIONS: - return FileType.VIDEO - if normalized_extension in AUDIO_EXTENSIONS: - return FileType.AUDIO - if normalized_extension in DOCUMENT_EXTENSIONS: - return FileType.DOCUMENT - return None - - -def get_file_type_by_mime_type(mime_type: str) -> FileType: - if "image" in mime_type: - return FileType.IMAGE - if "video" in mime_type: - return FileType.VIDEO - if "audio" in mime_type: - return FileType.AUDIO - if "text" in mime_type or "pdf" in mime_type: - return FileType.DOCUMENT - return FileType.CUSTOM diff --git a/api/graphon/file/file_manager.py b/api/graphon/file/file_manager.py deleted file mode 100644 index d7e4d472e7..0000000000 --- a/api/graphon/file/file_manager.py +++ /dev/null @@ -1,129 +0,0 @@ -from __future__ import annotations - -import base64 -from collections.abc import Mapping - -from graphon.model_runtime.entities import ( - AudioPromptMessageContent, - DocumentPromptMessageContent, - ImagePromptMessageContent, - TextPromptMessageContent, - VideoPromptMessageContent, -) -from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes - -from .enums import FileAttribute -from .models import File, FileTransferMethod, FileType -from .runtime import get_workflow_file_runtime - - -def get_attr(*, file: File, attr: FileAttribute): - match attr: - case FileAttribute.TYPE: - return file.type.value - case FileAttribute.SIZE: - return file.size - case FileAttribute.NAME: - return file.filename - case FileAttribute.MIME_TYPE: - return file.mime_type - case FileAttribute.TRANSFER_METHOD: - return file.transfer_method.value - case FileAttribute.URL: - return _to_url(file) - case FileAttribute.EXTENSION: - return file.extension - case FileAttribute.RELATED_ID: - return file.related_id - - -def to_prompt_message_content( - f: File, - /, - *, - image_detail_config: ImagePromptMessageContent.DETAIL | None = None, -) -> PromptMessageContentUnionTypes: - """Convert a file to prompt message content.""" - if f.extension is None: - raise ValueError("Missing file extension") - if f.mime_type is None: - raise ValueError("Missing file mime_type") - - prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = { - FileType.IMAGE: ImagePromptMessageContent, - FileType.AUDIO: AudioPromptMessageContent, - FileType.VIDEO: VideoPromptMessageContent, - FileType.DOCUMENT: DocumentPromptMessageContent, - } - - if f.type not in prompt_class_map: - return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]") - - send_format = get_workflow_file_runtime().multimodal_send_format - params = { - "base64_data": _get_encoded_string(f) if send_format == "base64" else "", - "url": _to_url(f) if send_format == "url" else "", - "format": f.extension.removeprefix("."), - "mime_type": f.mime_type, - "filename": f.filename or "", - } - if f.type == FileType.IMAGE: - params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - - return prompt_class_map[f.type].model_validate(params) - - -def download(f: File, /) -> bytes: - if f.transfer_method in ( - FileTransferMethod.TOOL_FILE, - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.DATASOURCE_FILE, - ): - return _download_file_content(f) - elif f.transfer_method == FileTransferMethod.REMOTE_URL: - if f.remote_url is None: - raise ValueError("Missing file remote_url") - response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True) - response.raise_for_status() - return response.content - raise ValueError(f"unsupported transfer method: {f.transfer_method}") - - -def _download_file_content(file: File, /) -> bytes: - """Download and return a file from storage as bytes.""" - return get_workflow_file_runtime().load_file_bytes(file=file) - - -def _get_encoded_string(f: File, /) -> str: - match f.transfer_method: - case FileTransferMethod.REMOTE_URL: - if f.remote_url is None: - raise ValueError("Missing file remote_url") - response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True) - response.raise_for_status() - data = response.content - case FileTransferMethod.LOCAL_FILE: - data = _download_file_content(f) - case FileTransferMethod.TOOL_FILE: - data = _download_file_content(f) - case FileTransferMethod.DATASOURCE_FILE: - data = _download_file_content(f) - - return base64.b64encode(data).decode("utf-8") - - -def _to_url(f: File, /): - url = f.generate_url() - if url is None: - raise ValueError(f"Unsupported transfer method: {f.transfer_method}") - return url - - -class FileManager: - """Adapter exposing file manager helpers behind FileManagerProtocol.""" - - def download(self, f: File, /) -> bytes: - return download(f) - - -file_manager = FileManager() diff --git a/api/graphon/file/helpers.py b/api/graphon/file/helpers.py deleted file mode 100644 index dade761227..0000000000 --- a/api/graphon/file/helpers.py +++ /dev/null @@ -1,48 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from .runtime import get_workflow_file_runtime - -if TYPE_CHECKING: - from .models import File - - -def resolve_file_url(file: File, /, *, for_external: bool = True) -> str | None: - return get_workflow_file_runtime().resolve_file_url(file=file, for_external=for_external) - - -def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str: - return get_workflow_file_runtime().resolve_upload_file_url( - upload_file_id=upload_file_id, - as_attachment=as_attachment, - for_external=for_external, - ) - - -def get_signed_tool_file_url(tool_file_id: str, extension: str, for_external: bool = True) -> str: - return get_workflow_file_runtime().resolve_tool_file_url( - tool_file_id=tool_file_id, - extension=extension, - for_external=for_external, - ) - - -def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - return get_workflow_file_runtime().verify_preview_signature( - preview_kind="image", - file_id=upload_file_id, - timestamp=timestamp, - nonce=nonce, - sign=sign, - ) - - -def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - return get_workflow_file_runtime().verify_preview_signature( - preview_kind="file", - file_id=upload_file_id, - timestamp=timestamp, - nonce=nonce, - sign=sign, - ) diff --git a/api/graphon/file/models.py b/api/graphon/file/models.py deleted file mode 100644 index ccd7584371..0000000000 --- a/api/graphon/file/models.py +++ /dev/null @@ -1,215 +0,0 @@ -from __future__ import annotations - -import base64 -import json -from collections.abc import Mapping, Sequence -from typing import Any - -from pydantic import BaseModel, Field, model_validator - -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent - -from . import helpers -from .constants import FILE_MODEL_IDENTITY -from .enums import FileTransferMethod, FileType - -_FILE_REFERENCE_PREFIX = "dify-file-ref:" - - -def sign_tool_file(*, tool_file_id: str, extension: str, for_external: bool = True) -> str: - """Compatibility shim for tests and legacy callers patching ``models.sign_tool_file``.""" - return helpers.get_signed_tool_file_url( - tool_file_id=tool_file_id, - extension=extension, - for_external=for_external, - ) - - -class ImageConfig(BaseModel): - """ - NOTE: This part of validation is deprecated, but still used in app features "Image Upload". - """ - - number_limits: int = 0 - transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list) - detail: ImagePromptMessageContent.DETAIL | None = None - - -class FileUploadConfig(BaseModel): - """ - File Upload Entity. - """ - - image_config: ImageConfig | None = None - allowed_file_types: Sequence[FileType] = Field(default_factory=list) - allowed_file_extensions: Sequence[str] = Field(default_factory=list) - allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) - number_limits: int = 0 - - -def _parse_reference(reference: str | None) -> tuple[str | None, str | None]: - """Best-effort parser for record references and historical storage-key payloads.""" - if not reference: - return None, None - - if not reference.startswith(_FILE_REFERENCE_PREFIX): - return reference, None - - encoded_payload = reference.removeprefix(_FILE_REFERENCE_PREFIX) - try: - payload = json.loads(base64.urlsafe_b64decode(encoded_payload.encode())) - except (ValueError, json.JSONDecodeError): - return reference, None - - record_id = payload.get("record_id") - if not isinstance(record_id, str) or not record_id: - return reference, None - - storage_key = payload.get("storage_key") - if not isinstance(storage_key, str): - storage_key = None - - return record_id, storage_key - - -class File(BaseModel): - """Graph-owned file reference. - - The graph layer deliberately keeps only the metadata required to route, - serialize, and render files. Application ownership concerns such as - tenant/user/conversation identity stay in the workflow/storage layer. - """ - - # NOTE: dify_model_identity is a special identifier used to distinguish between - # new and old data formats during serialization and deserialization. - dify_model_identity: str = FILE_MODEL_IDENTITY - - id: str | None = None # message file id - type: FileType - transfer_method: FileTransferMethod - # If `transfer_method` is `FileTransferMethod.remote_url`, the - # `remote_url` attribute must not be `None`. - remote_url: str | None = None # remote url - # Opaque workflow-layer reference for files resolved outside ``graphon``. - # New payloads only carry the backing record id; historical payloads may - # still include storage_key and must remain readable. - reference: str | None = None - filename: str | None = None - extension: str | None = Field(default=None, description="File extension, should contain dot") - mime_type: str | None = None - size: int = -1 - _storage_key: str - - def __init__( - self, - *, - id: str | None = None, - tenant_id: str | None = None, - type: FileType, - transfer_method: FileTransferMethod, - remote_url: str | None = None, - reference: str | None = None, - related_id: str | None = None, - filename: str | None = None, - extension: str | None = None, - mime_type: str | None = None, - size: int = -1, - storage_key: str | None = None, - dify_model_identity: str | None = FILE_MODEL_IDENTITY, - url: str | None = None, - # Legacy compatibility fields - explicitly accept known extra fields - tool_file_id: str | None = None, - upload_file_id: str | None = None, - datasource_file_id: str | None = None, - ): - legacy_record_id = related_id or tool_file_id or upload_file_id or datasource_file_id - normalized_reference = reference - if normalized_reference is None and legacy_record_id is not None: - normalized_reference = str(legacy_record_id) - _, parsed_storage_key = _parse_reference(normalized_reference) - - super().__init__( - id=id, - type=type, - transfer_method=transfer_method, - remote_url=remote_url, - reference=normalized_reference, - filename=filename, - extension=extension, - mime_type=mime_type, - size=size, - dify_model_identity=dify_model_identity, - url=url, - ) - # Accept legacy constructor fields without promoting them back into the graph model. - _ = tenant_id - self._storage_key = storage_key or parsed_storage_key or "" - - def to_dict(self) -> Mapping[str, str | int | None]: - data = self.model_dump(mode="json") - return { - **data, - "related_id": self.related_id, - "url": self.generate_url(), - } - - @property - def markdown(self) -> str: - url = self.generate_url() - if self.type == FileType.IMAGE: - text = f"![{self.filename or ''}]({url})" - else: - text = f"[{self.filename or url}]({url})" - - return text - - def generate_url(self, for_external: bool = True) -> str | None: - return helpers.resolve_file_url(self, for_external=for_external) - - def to_plugin_parameter(self) -> dict[str, Any]: - return { - "dify_model_identity": FILE_MODEL_IDENTITY, - "mime_type": self.mime_type, - "filename": self.filename, - "extension": self.extension, - "size": self.size, - "type": self.type, - "url": self.generate_url(for_external=False), - } - - @model_validator(mode="after") - def validate_after(self) -> File: - match self.transfer_method: - case FileTransferMethod.REMOTE_URL: - if not self.remote_url: - raise ValueError("Missing file url") - if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"): - raise ValueError("Invalid file url") - case FileTransferMethod.LOCAL_FILE: - if not self.reference: - raise ValueError("Missing file reference") - case FileTransferMethod.TOOL_FILE: - if not self.reference: - raise ValueError("Missing file reference") - case FileTransferMethod.DATASOURCE_FILE: - if not self.reference: - raise ValueError("Missing file reference") - return self - - @property - def related_id(self) -> str | None: - record_id, _ = _parse_reference(self.reference) - return record_id - - @related_id.setter - def related_id(self, value: str | None) -> None: - self.reference = value - - @property - def storage_key(self) -> str: - _, storage_key = _parse_reference(self.reference) - return storage_key or self._storage_key - - @storage_key.setter - def storage_key(self, value: str) -> None: - self._storage_key = value diff --git a/api/graphon/file/protocols.py b/api/graphon/file/protocols.py deleted file mode 100644 index 0acabe35e5..0000000000 --- a/api/graphon/file/protocols.py +++ /dev/null @@ -1,56 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator -from typing import TYPE_CHECKING, Literal, Protocol - -if TYPE_CHECKING: - from .models import File - - -class HttpResponseProtocol(Protocol): - """Subset of response behavior needed by workflow file helpers.""" - - @property - def content(self) -> bytes: ... - - def raise_for_status(self) -> object: ... - - -class WorkflowFileRuntimeProtocol(Protocol): - """Runtime dependencies required by ``graphon.file``. - - Implementations are expected to be provided by integration layers (for example, - ``core.app.workflow.file_runtime``) so the workflow package avoids importing - application infrastructure modules directly. - """ - - @property - def multimodal_send_format(self) -> str: ... - - def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: ... - - def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ... - - def load_file_bytes(self, *, file: File) -> bytes: ... - - def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: ... - - def resolve_upload_file_url( - self, - *, - upload_file_id: str, - as_attachment: bool = False, - for_external: bool = True, - ) -> str: ... - - def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ... - - def verify_preview_signature( - self, - *, - preview_kind: Literal["image", "file"], - file_id: str, - timestamp: str, - nonce: str, - sign: str, - ) -> bool: ... diff --git a/api/graphon/file/runtime.py b/api/graphon/file/runtime.py deleted file mode 100644 index 1c5d1c3ca4..0000000000 --- a/api/graphon/file/runtime.py +++ /dev/null @@ -1,71 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator -from typing import TYPE_CHECKING, Literal, NoReturn - -from .protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol - -if TYPE_CHECKING: - from .models import File - - -class WorkflowFileRuntimeNotConfiguredError(RuntimeError): - """Raised when workflow file runtime dependencies were not configured.""" - - -class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol): - def _raise(self) -> NoReturn: - raise WorkflowFileRuntimeNotConfiguredError( - "workflow file runtime is not configured, call set_workflow_file_runtime(...) first" - ) - - @property - def multimodal_send_format(self) -> str: - self._raise() - - def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: - self._raise() - - def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: - self._raise() - - def load_file_bytes(self, *, file: File) -> bytes: - self._raise() - - def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: - self._raise() - - def resolve_upload_file_url( - self, - *, - upload_file_id: str, - as_attachment: bool = False, - for_external: bool = True, - ) -> str: - self._raise() - - def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: - self._raise() - - def verify_preview_signature( - self, - *, - preview_kind: Literal["image", "file"], - file_id: str, - timestamp: str, - nonce: str, - sign: str, - ) -> bool: - self._raise() - - -_runtime: WorkflowFileRuntimeProtocol = _UnconfiguredWorkflowFileRuntime() - - -def set_workflow_file_runtime(runtime: WorkflowFileRuntimeProtocol) -> None: - global _runtime - _runtime = runtime - - -def get_workflow_file_runtime() -> WorkflowFileRuntimeProtocol: - return _runtime diff --git a/api/graphon/file/tool_file_parser.py b/api/graphon/file/tool_file_parser.py deleted file mode 100644 index 2d7a3d43df..0000000000 --- a/api/graphon/file/tool_file_parser.py +++ /dev/null @@ -1,9 +0,0 @@ -from collections.abc import Callable -from typing import Any - -_tool_file_manager_factory: Callable[[], Any] | None = None - - -def set_tool_file_manager_factory(factory: Callable[[], Any]): - global _tool_file_manager_factory - _tool_file_manager_factory = factory diff --git a/api/graphon/graph/__init__.py b/api/graphon/graph/__init__.py deleted file mode 100644 index 4830ea83d3..0000000000 --- a/api/graphon/graph/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .edge import Edge -from .graph import Graph, GraphBuilder, NodeFactory -from .graph_template import GraphTemplate - -__all__ = [ - "Edge", - "Graph", - "GraphBuilder", - "GraphTemplate", - "NodeFactory", -] diff --git a/api/graphon/graph/edge.py b/api/graphon/graph/edge.py deleted file mode 100644 index 1f8a2884e3..0000000000 --- a/api/graphon/graph/edge.py +++ /dev/null @@ -1,15 +0,0 @@ -import uuid -from dataclasses import dataclass, field - -from graphon.enums import NodeState - - -@dataclass -class Edge: - """Edge connecting two nodes in a workflow graph.""" - - id: str = field(default_factory=lambda: str(uuid.uuid4())) - tail: str = "" # tail node id (source) - head: str = "" # head node id (target) - source_handle: str = "source" # source handle for conditional branching - state: NodeState = field(default=NodeState.UNKNOWN) # edge execution state diff --git a/api/graphon/graph/graph.py b/api/graphon/graph/graph.py deleted file mode 100644 index 0f4cd8925f..0000000000 --- a/api/graphon/graph/graph.py +++ /dev/null @@ -1,438 +0,0 @@ -from __future__ import annotations - -import logging -from collections import defaultdict -from collections.abc import Mapping, Sequence -from typing import Protocol, cast, final - -from pydantic import TypeAdapter - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ErrorStrategy, NodeExecutionType, NodeState -from graphon.nodes.base.node import Node - -from .edge import Edge -from .validation import get_graph_validator - -logger = logging.getLogger(__name__) - -_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict]) - - -class NodeFactory(Protocol): - """ - Protocol for creating Node instances from node data dictionaries. - - This protocol decouples the Graph class from specific node mapping implementations, - allowing for different node creation strategies while maintaining type safety. - """ - - def create_node(self, node_config: NodeConfigDict) -> Node: - """ - Create a Node instance from node configuration data. - - :param node_config: node configuration dictionary containing type and other data - :return: initialized Node instance - :raises ValueError: if node type is unknown or no implementation exists for the resolved version - :raises ValidationError: if node_config does not satisfy NodeConfigDict/BaseNodeData validation - """ - ... - - -@final -class Graph: - """Graph representation with nodes and edges for workflow execution.""" - - def __init__( - self, - *, - nodes: dict[str, Node] | None = None, - edges: dict[str, Edge] | None = None, - in_edges: dict[str, list[str]] | None = None, - out_edges: dict[str, list[str]] | None = None, - root_node: Node, - ): - """ - Initialize Graph instance. - - :param nodes: graph nodes mapping (node id: node object) - :param edges: graph edges mapping (edge id: edge object) - :param in_edges: incoming edges mapping (node id: list of edge ids) - :param out_edges: outgoing edges mapping (node id: list of edge ids) - :param root_node: root node object - """ - self.nodes = nodes or {} - self.edges = edges or {} - self.in_edges = in_edges or {} - self.out_edges = out_edges or {} - self.root_node = root_node - - @classmethod - def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]: - """ - Parse node configurations and build a mapping of node IDs to configs. - - :param node_configs: list of node configuration dictionaries - :return: mapping of node ID to node config - """ - node_configs_map: dict[str, NodeConfigDict] = {} - - for node_config in node_configs: - node_configs_map[node_config["id"]] = node_config - - return node_configs_map - - @classmethod - def _build_edges( - cls, edge_configs: list[dict[str, object]] - ) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]: - """ - Build edge objects and mappings from edge configurations. - - :param edge_configs: list of edge configurations - :return: tuple of (edges dict, in_edges dict, out_edges dict) - """ - edges: dict[str, Edge] = {} - in_edges: dict[str, list[str]] = defaultdict(list) - out_edges: dict[str, list[str]] = defaultdict(list) - - edge_counter = 0 - for edge_config in edge_configs: - source = edge_config.get("source") - target = edge_config.get("target") - - if not isinstance(source, str) or not isinstance(target, str): - continue - - # Create edge - edge_id = f"edge_{edge_counter}" - edge_counter += 1 - - source_handle = edge_config.get("sourceHandle", "source") - if not isinstance(source_handle, str): - continue - - edge = Edge( - id=edge_id, - tail=source, - head=target, - source_handle=source_handle, - ) - - edges[edge_id] = edge - out_edges[source].append(edge_id) - in_edges[target].append(edge_id) - - return edges, dict(in_edges), dict(out_edges) - - @classmethod - def _create_node_instances( - cls, - node_configs_map: dict[str, NodeConfigDict], - node_factory: NodeFactory, - ) -> dict[str, Node]: - """ - Create node instances from configurations using the node factory. - - :param node_configs_map: mapping of node ID to node config - :param node_factory: factory for creating node instances - :return: mapping of node ID to node instance - """ - nodes: dict[str, Node] = {} - - for node_id, node_config in node_configs_map.items(): - try: - node_instance = node_factory.create_node(node_config) - except Exception: - logger.exception("Failed to create node instance for node_id %s", node_id) - raise - nodes[node_id] = node_instance - - return nodes - - @classmethod - def new(cls) -> GraphBuilder: - """Create a fluent builder for assembling a graph programmatically.""" - - return GraphBuilder(graph_cls=cls) - - @staticmethod - def _filter_canvas_only_nodes(node_configs: Sequence[Mapping[str, object]]) -> list[dict[str, object]]: - """ - Remove editor-only nodes before `NodeConfigDict` validation. - - Persisted note widgets use a top-level `type == "custom-note"` but leave - `data.type` empty because they are never executable graph nodes. Filter - them while configs are still raw dicts so Pydantic does not validate - their placeholder payloads against `BaseNodeData.type: NodeType`. - """ - filtered_node_configs: list[dict[str, object]] = [] - for node_config in node_configs: - if node_config.get("type", "") == "custom-note": - continue - filtered_node_configs.append(dict(node_config)) - return filtered_node_configs - - @classmethod - def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None: - """ - Promote nodes configured with FAIL_BRANCH error strategy to branch execution type. - - :param nodes: mapping of node ID to node instance - """ - for node in nodes.values(): - if node.error_strategy == ErrorStrategy.FAIL_BRANCH: - node.execution_type = NodeExecutionType.BRANCH - - @classmethod - def _mark_inactive_root_branches( - cls, - nodes: dict[str, Node], - edges: dict[str, Edge], - in_edges: dict[str, list[str]], - out_edges: dict[str, list[str]], - active_root_id: str, - ) -> None: - """ - Mark nodes and edges from inactive root branches as skipped. - - Algorithm: - 1. Mark inactive root nodes as skipped - 2. For skipped nodes, mark all their outgoing edges as skipped - 3. For each edge marked as skipped, check its target node: - - If ALL incoming edges are skipped, mark the node as skipped - - Otherwise, leave the node state unchanged - - :param nodes: mapping of node ID to node instance - :param edges: mapping of edge ID to edge instance - :param in_edges: mapping of node ID to incoming edge IDs - :param out_edges: mapping of node ID to outgoing edge IDs - :param active_root_id: ID of the active root node - """ - # Find all top-level root nodes (nodes with ROOT execution type and no incoming edges) - top_level_roots: list[str] = [ - node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT - ] - - # If there's only one root or the active root is not a top-level root, no marking needed - if len(top_level_roots) <= 1 or active_root_id not in top_level_roots: - return - - # Mark inactive root nodes as skipped - inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id] - for root_id in inactive_roots: - if root_id in nodes: - nodes[root_id].state = NodeState.SKIPPED - - # Recursively mark downstream nodes and edges - def mark_downstream(node_id: str) -> None: - """Recursively mark downstream nodes and edges as skipped.""" - if nodes[node_id].state != NodeState.SKIPPED: - return - # If this node is skipped, mark all its outgoing edges as skipped - out_edge_ids = out_edges.get(node_id, []) - for edge_id in out_edge_ids: - edge = edges[edge_id] - edge.state = NodeState.SKIPPED - - # Check the target node of this edge - target_node = nodes[edge.head] - in_edge_ids = in_edges.get(target_node.id, []) - in_edge_states = [edges[eid].state for eid in in_edge_ids] - - # If all incoming edges are skipped, mark the node as skipped - if all(state == NodeState.SKIPPED for state in in_edge_states): - target_node.state = NodeState.SKIPPED - # Recursively process downstream nodes - mark_downstream(target_node.id) - - # Process each inactive root and its downstream nodes - for root_id in inactive_roots: - mark_downstream(root_id) - - @classmethod - def init( - cls, - *, - graph_config: Mapping[str, object], - node_factory: NodeFactory, - root_node_id: str, - skip_validation: bool = False, - ) -> Graph: - """ - Initialize a graph with an explicit execution entry point. - - :param graph_config: graph config containing nodes and edges - :param node_factory: factory for creating node instances from config data - :param root_node_id: active root node id - :return: graph instance - """ - # Parse configs - edge_configs = graph_config.get("edges", []) - node_configs = graph_config.get("nodes", []) - - edge_configs = cast(list[dict[str, object]], edge_configs) - node_configs = cast(list[dict[str, object]], node_configs) - node_configs = cls._filter_canvas_only_nodes(node_configs) - node_configs = _ListNodeConfigDict.validate_python(node_configs) - - if not node_configs: - raise ValueError("Graph must have at least one node") - - # Parse node configurations - node_configs_map = cls._parse_node_configs(node_configs) - - if root_node_id not in node_configs_map: - raise ValueError(f"Root node id {root_node_id} not found in the graph") - - # Build edges - edges, in_edges, out_edges = cls._build_edges(edge_configs) - - # Create node instances - nodes = cls._create_node_instances(node_configs_map, node_factory) - - # Promote fail-branch nodes to branch execution type at graph level - cls._promote_fail_branch_nodes(nodes) - - # Get root node instance - root_node = nodes[root_node_id] - - # Mark inactive root branches as skipped - cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id) - - # Create and return the graph - graph = cls( - nodes=nodes, - edges=edges, - in_edges=in_edges, - out_edges=out_edges, - root_node=root_node, - ) - - if not skip_validation: - # Validate the graph structure using built-in validators - get_graph_validator().validate(graph) - - return graph - - @property - def node_ids(self) -> list[str]: - """ - Get list of node IDs (compatibility property for existing code) - - :return: list of node IDs - """ - return list(self.nodes.keys()) - - def get_outgoing_edges(self, node_id: str) -> list[Edge]: - """ - Get all outgoing edges from a node (V2 method) - - :param node_id: node id - :return: list of outgoing edges - """ - edge_ids = self.out_edges.get(node_id, []) - return [self.edges[eid] for eid in edge_ids if eid in self.edges] - - def get_incoming_edges(self, node_id: str) -> list[Edge]: - """ - Get all incoming edges to a node (V2 method) - - :param node_id: node id - :return: list of incoming edges - """ - edge_ids = self.in_edges.get(node_id, []) - return [self.edges[eid] for eid in edge_ids if eid in self.edges] - - -@final -class GraphBuilder: - """Fluent helper for constructing simple graphs, primarily for tests.""" - - def __init__(self, *, graph_cls: type[Graph]): - self._graph_cls = graph_cls - self._nodes: list[Node] = [] - self._nodes_by_id: dict[str, Node] = {} - self._edges: list[Edge] = [] - self._edge_counter = 0 - - def add_root(self, node: Node) -> GraphBuilder: - """Register the root node. Must be called exactly once.""" - - if self._nodes: - raise ValueError("Root node has already been added") - self._register_node(node) - self._nodes.append(node) - return self - - def add_node( - self, - node: Node, - *, - from_node_id: str | None = None, - source_handle: str = "source", - ) -> GraphBuilder: - """Append a node and connect it from the specified predecessor.""" - - if not self._nodes: - raise ValueError("Root node must be added before adding other nodes") - - predecessor_id = from_node_id or self._nodes[-1].id - if predecessor_id not in self._nodes_by_id: - raise ValueError(f"Predecessor node '{predecessor_id}' not found") - - predecessor = self._nodes_by_id[predecessor_id] - self._register_node(node) - self._nodes.append(node) - - edge_id = f"edge_{self._edge_counter}" - self._edge_counter += 1 - edge = Edge(id=edge_id, tail=predecessor.id, head=node.id, source_handle=source_handle) - self._edges.append(edge) - - return self - - def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder: - """Connect two existing nodes without adding a new node.""" - - if tail not in self._nodes_by_id: - raise ValueError(f"Tail node '{tail}' not found") - if head not in self._nodes_by_id: - raise ValueError(f"Head node '{head}' not found") - - edge_id = f"edge_{self._edge_counter}" - self._edge_counter += 1 - edge = Edge(id=edge_id, tail=tail, head=head, source_handle=source_handle) - self._edges.append(edge) - - return self - - def build(self) -> Graph: - """Materialize the graph instance from the accumulated nodes and edges.""" - - if not self._nodes: - raise ValueError("Cannot build an empty graph") - - nodes = {node.id: node for node in self._nodes} - edges = {edge.id: edge for edge in self._edges} - in_edges: dict[str, list[str]] = defaultdict(list) - out_edges: dict[str, list[str]] = defaultdict(list) - - for edge in self._edges: - out_edges[edge.tail].append(edge.id) - in_edges[edge.head].append(edge.id) - - return self._graph_cls( - nodes=nodes, - edges=edges, - in_edges=dict(in_edges), - out_edges=dict(out_edges), - root_node=self._nodes[0], - ) - - def _register_node(self, node: Node) -> None: - if not node.id: - raise ValueError("Node must have a non-empty id") - if node.id in self._nodes_by_id: - raise ValueError(f"Duplicate node id detected: {node.id}") - self._nodes_by_id[node.id] = node diff --git a/api/graphon/graph/graph_template.py b/api/graphon/graph/graph_template.py deleted file mode 100644 index 34e2dc19e6..0000000000 --- a/api/graphon/graph/graph_template.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Any - -from pydantic import BaseModel, Field - - -class GraphTemplate(BaseModel): - """ - Graph Template for container nodes and subgraph expansion - - According to GraphEngine V2 spec, GraphTemplate contains: - - nodes: mapping of node definitions - - edges: mapping of edge definitions - - root_ids: list of root node IDs - - output_selectors: list of output selectors for the template - """ - - nodes: dict[str, dict[str, Any]] = Field(default_factory=dict, description="node definitions mapping") - edges: dict[str, dict[str, Any]] = Field(default_factory=dict, description="edge definitions mapping") - root_ids: list[str] = Field(default_factory=list, description="root node IDs") - output_selectors: list[str] = Field(default_factory=list, description="output selectors") diff --git a/api/graphon/graph/validation.py b/api/graphon/graph/validation.py deleted file mode 100644 index 04b501fd33..0000000000 --- a/api/graphon/graph/validation.py +++ /dev/null @@ -1,125 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from dataclasses import dataclass -from typing import TYPE_CHECKING, Protocol - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeType - -if TYPE_CHECKING: - from .graph import Graph - - -@dataclass(frozen=True, slots=True) -class GraphValidationIssue: - """Immutable value object describing a single validation issue.""" - - code: str - message: str - node_id: str | None = None - - -class GraphValidationError(ValueError): - """Raised when graph validation fails.""" - - def __init__(self, issues: Sequence[GraphValidationIssue]) -> None: - if not issues: - raise ValueError("GraphValidationError requires at least one issue.") - self.issues: tuple[GraphValidationIssue, ...] = tuple(issues) - message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues) - super().__init__(message) - - -class GraphValidationRule(Protocol): - """Protocol that individual validation rules must satisfy.""" - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - """Validate the provided graph and return any discovered issues.""" - ... - - -@dataclass(frozen=True, slots=True) -class _EdgeEndpointValidator: - """Ensures all edges reference existing nodes.""" - - missing_node_code: str = "MISSING_NODE" - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - issues: list[GraphValidationIssue] = [] - for edge in graph.edges.values(): - if edge.tail not in graph.nodes: - issues.append( - GraphValidationIssue( - code=self.missing_node_code, - message=f"Edge {edge.id} references unknown source node '{edge.tail}'.", - node_id=edge.tail, - ) - ) - if edge.head not in graph.nodes: - issues.append( - GraphValidationIssue( - code=self.missing_node_code, - message=f"Edge {edge.id} references unknown target node '{edge.head}'.", - node_id=edge.head, - ) - ) - return issues - - -@dataclass(frozen=True, slots=True) -class _RootNodeValidator: - """Validates root node invariants.""" - - invalid_root_code: str = "INVALID_ROOT" - container_entry_types: tuple[NodeType, ...] = (BuiltinNodeTypes.ITERATION_START, BuiltinNodeTypes.LOOP_START) - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - root_node = graph.root_node - issues: list[GraphValidationIssue] = [] - if root_node.id not in graph.nodes: - issues.append( - GraphValidationIssue( - code=self.invalid_root_code, - message=f"Root node '{root_node.id}' is missing from the node registry.", - node_id=root_node.id, - ) - ) - return issues - - node_type = root_node.node_type - if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types: - issues.append( - GraphValidationIssue( - code=self.invalid_root_code, - message=f"Root node '{root_node.id}' must declare execution type 'root'.", - node_id=root_node.id, - ) - ) - return issues - - -@dataclass(frozen=True, slots=True) -class GraphValidator: - """Coordinates execution of graph validation rules.""" - - rules: tuple[GraphValidationRule, ...] - - def validate(self, graph: Graph) -> None: - """Validate the graph against all configured rules.""" - issues: list[GraphValidationIssue] = [] - for rule in self.rules: - issues.extend(rule.validate(graph)) - - if issues: - raise GraphValidationError(issues) - - -_DEFAULT_RULES: tuple[GraphValidationRule, ...] = ( - _EdgeEndpointValidator(), - _RootNodeValidator(), -) - - -def get_graph_validator() -> GraphValidator: - """Construct the validator composed of default rules.""" - return GraphValidator(_DEFAULT_RULES) diff --git a/api/graphon/graph_engine/__init__.py b/api/graphon/graph_engine/__init__.py deleted file mode 100644 index 0e1c7dd60a..0000000000 --- a/api/graphon/graph_engine/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .config import GraphEngineConfig -from .graph_engine import GraphEngine - -__all__ = ["GraphEngine", "GraphEngineConfig"] diff --git a/api/graphon/graph_engine/_engine_utils.py b/api/graphon/graph_engine/_engine_utils.py deleted file mode 100644 index 28898268fe..0000000000 --- a/api/graphon/graph_engine/_engine_utils.py +++ /dev/null @@ -1,15 +0,0 @@ -import time - - -def get_timestamp() -> float: - """Retrieve a timestamp as a float point numer representing the number of seconds - since the Unix epoch. - - This function is primarily used to measure the execution time of the workflow engine. - Since workflow execution may be paused and resumed on a different machine, - `time.perf_counter` cannot be used as it is inconsistent across machines. - - To address this, the function uses the wall clock as the time source. - However, it assumes that the clocks of all servers are properly synchronized. - """ - return round(time.time()) diff --git a/api/graphon/graph_engine/command_channels/README.md b/api/graphon/graph_engine/command_channels/README.md deleted file mode 100644 index e35e12054a..0000000000 --- a/api/graphon/graph_engine/command_channels/README.md +++ /dev/null @@ -1,33 +0,0 @@ -# Command Channels - -Channel implementations for external workflow control. - -## Components - -### InMemoryChannel - -Thread-safe in-memory queue for single-process deployments. - -- `fetch_commands()` - Get pending commands -- `send_command()` - Add command to queue - -### RedisChannel - -Redis-based queue for distributed deployments. - -- `fetch_commands()` - Get commands with JSON deserialization -- `send_command()` - Store commands with TTL - -## Usage - -```python -# Local execution -channel = InMemoryChannel() -channel.send_command(AbortCommand(graph_id="workflow-123")) - -# Distributed execution -redis_channel = RedisChannel( - redis_client=redis_client, - channel_key="workflow:123:commands" -) -``` diff --git a/api/graphon/graph_engine/command_channels/__init__.py b/api/graphon/graph_engine/command_channels/__init__.py deleted file mode 100644 index 863e6032d6..0000000000 --- a/api/graphon/graph_engine/command_channels/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Command channel implementations for GraphEngine.""" - -from .in_memory_channel import InMemoryChannel -from .redis_channel import RedisChannel - -__all__ = ["InMemoryChannel", "RedisChannel"] diff --git a/api/graphon/graph_engine/command_channels/in_memory_channel.py b/api/graphon/graph_engine/command_channels/in_memory_channel.py deleted file mode 100644 index bdaf236796..0000000000 --- a/api/graphon/graph_engine/command_channels/in_memory_channel.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -In-memory implementation of CommandChannel for local/testing scenarios. - -This implementation uses a thread-safe queue for command communication -within a single process. Each instance handles commands for one workflow execution. -""" - -from queue import Queue -from typing import final - -from ..entities.commands import GraphEngineCommand - - -@final -class InMemoryChannel: - """ - In-memory command channel implementation using a thread-safe queue. - - Each instance is dedicated to a single GraphEngine/workflow execution. - Suitable for local development, testing, and single-instance deployments. - """ - - def __init__(self) -> None: - """Initialize the in-memory channel with a single queue.""" - self._queue: Queue[GraphEngineCommand] = Queue() - - def fetch_commands(self) -> list[GraphEngineCommand]: - """ - Fetch all pending commands from the queue. - - Returns: - List of pending commands (drains the queue) - """ - commands: list[GraphEngineCommand] = [] - - # Drain all available commands from the queue - while not self._queue.empty(): - try: - command = self._queue.get_nowait() - commands.append(command) - except Exception: - break - - return commands - - def send_command(self, command: GraphEngineCommand) -> None: - """ - Send a command to this channel's queue. - - Args: - command: The command to send - """ - self._queue.put(command) diff --git a/api/graphon/graph_engine/command_channels/redis_channel.py b/api/graphon/graph_engine/command_channels/redis_channel.py deleted file mode 100644 index 77cf884c67..0000000000 --- a/api/graphon/graph_engine/command_channels/redis_channel.py +++ /dev/null @@ -1,153 +0,0 @@ -""" -Redis-based implementation of CommandChannel for distributed scenarios. - -This implementation uses Redis lists for command queuing, supporting -multi-instance deployments and cross-server communication. -Each instance uses a unique key for its command queue. -""" - -import json -from contextlib import AbstractContextManager -from typing import Any, Protocol, final - -from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand - - -class RedisPipelineProtocol(Protocol): - """Minimal Redis pipeline contract used by the command channel.""" - - def lrange(self, name: str, start: int, end: int) -> Any: ... - def delete(self, *names: str) -> Any: ... - def execute(self) -> list[Any]: ... - def rpush(self, name: str, *values: str) -> Any: ... - def expire(self, name: str, time: int) -> Any: ... - def set(self, name: str, value: str, ex: int | None = None) -> Any: ... - def get(self, name: str) -> Any: ... - - -class RedisClientProtocol(Protocol): - """Redis client contract required by the command channel.""" - - def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ... - - -@final -class RedisChannel: - """ - Redis-based command channel implementation for distributed systems. - - Each instance uses a unique Redis key for its command queue. - Commands are JSON-serialized for transport. - """ - - def __init__( - self, - redis_client: RedisClientProtocol, - channel_key: str, - command_ttl: int = 3600, - ) -> None: - """ - Initialize the Redis channel. - - Args: - redis_client: Redis client instance - channel_key: Unique key for this channel's command queue - command_ttl: TTL for command keys in seconds (default: 3600) - """ - self._redis = redis_client - self._key = channel_key - self._command_ttl = command_ttl - self._pending_key = f"{channel_key}:pending" - - def fetch_commands(self) -> list[GraphEngineCommand]: - """ - Fetch all pending commands from Redis. - - Returns: - List of pending commands (drains the Redis list) - """ - if not self._has_pending_commands(): - return [] - - commands: list[GraphEngineCommand] = [] - - # Use pipeline for atomic operations - with self._redis.pipeline() as pipe: - # Get all commands and clear the list atomically - pipe.lrange(self._key, 0, -1) - pipe.delete(self._key) - results = pipe.execute() - - # Parse commands from JSON - if results[0]: - for command_json in results[0]: - try: - command_data = json.loads(command_json) - command = self._deserialize_command(command_data) - if command: - commands.append(command) - except (json.JSONDecodeError, ValueError): - # Skip invalid commands - continue - - return commands - - def send_command(self, command: GraphEngineCommand) -> None: - """ - Send a command to Redis. - - Args: - command: The command to send - """ - command_json = json.dumps(command.model_dump()) - - # Push to list and set expiry - with self._redis.pipeline() as pipe: - pipe.rpush(self._key, command_json) - pipe.expire(self._key, self._command_ttl) - pipe.set(self._pending_key, "1", ex=self._command_ttl) - pipe.execute() - - def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None: - """ - Deserialize a command from dictionary data. - - Args: - data: Command data dictionary - - Returns: - Deserialized command or None if invalid - """ - command_type_value = data.get("command_type") - if not isinstance(command_type_value, str): - return None - - try: - command_type = CommandType(command_type_value) - - if command_type == CommandType.ABORT: - return AbortCommand.model_validate(data) - if command_type == CommandType.PAUSE: - return PauseCommand.model_validate(data) - if command_type == CommandType.UPDATE_VARIABLES: - return UpdateVariablesCommand.model_validate(data) - - # For other command types, use base class - return GraphEngineCommand.model_validate(data) - - except (ValueError, TypeError): - return None - - def _has_pending_commands(self) -> bool: - """ - Check and consume the pending marker to avoid unnecessary list reads. - - Returns: - True if commands should be fetched from Redis. - """ - with self._redis.pipeline() as pipe: - pipe.get(self._pending_key) - pipe.delete(self._pending_key) - pending_value, _ = pipe.execute() - - return pending_value is not None diff --git a/api/graphon/graph_engine/command_processing/__init__.py b/api/graphon/graph_engine/command_processing/__init__.py deleted file mode 100644 index 7b4f0dfff7..0000000000 --- a/api/graphon/graph_engine/command_processing/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Command processing subsystem for graph engine. - -This package handles external commands sent to the engine -during execution. -""" - -from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler -from .command_processor import CommandProcessor - -__all__ = [ - "AbortCommandHandler", - "CommandProcessor", - "PauseCommandHandler", - "UpdateVariablesCommandHandler", -] diff --git a/api/graphon/graph_engine/command_processing/command_handlers.py b/api/graphon/graph_engine/command_processing/command_handlers.py deleted file mode 100644 index ad92fd1abb..0000000000 --- a/api/graphon/graph_engine/command_processing/command_handlers.py +++ /dev/null @@ -1,56 +0,0 @@ -import logging -from typing import final - -from typing_extensions import override - -from graphon.entities.pause_reason import SchedulingPause -from graphon.runtime import VariablePool - -from ..domain.graph_execution import GraphExecution -from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand -from .command_processor import CommandHandler - -logger = logging.getLogger(__name__) - - -@final -class AbortCommandHandler(CommandHandler): - @override - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: - assert isinstance(command, AbortCommand) - logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason) - execution.abort(command.reason or "User requested abort") - - -@final -class PauseCommandHandler(CommandHandler): - @override - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: - assert isinstance(command, PauseCommand) - logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason) - # Convert string reason to PauseReason if needed - reason = command.reason - pause_reason = SchedulingPause(message=reason) - execution.pause(pause_reason) - - -@final -class UpdateVariablesCommandHandler(CommandHandler): - def __init__(self, variable_pool: VariablePool) -> None: - self._variable_pool = variable_pool - - @override - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: - assert isinstance(command, UpdateVariablesCommand) - for update in command.updates: - try: - variable = update.value - self._variable_pool.add(variable.selector, variable) - logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id) - except ValueError as exc: - logger.warning( - "Skipping invalid variable selector %s for workflow %s: %s", - getattr(update.value, "selector", None), - execution.workflow_id, - exc, - ) diff --git a/api/graphon/graph_engine/command_processing/command_processor.py b/api/graphon/graph_engine/command_processing/command_processor.py deleted file mode 100644 index 942c2d77a5..0000000000 --- a/api/graphon/graph_engine/command_processing/command_processor.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Main command processor for handling external commands. -""" - -import logging -from typing import Protocol, final - -from ..domain.graph_execution import GraphExecution -from ..entities.commands import GraphEngineCommand -from ..protocols.command_channel import CommandChannel - -logger = logging.getLogger(__name__) - - -class CommandHandler(Protocol): - """Protocol for command handlers.""" - - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ... - - -@final -class CommandProcessor: - """ - Processes external commands sent to the engine. - - This polls the command channel and dispatches commands to - appropriate handlers. - """ - - def __init__( - self, - command_channel: CommandChannel, - graph_execution: GraphExecution, - ) -> None: - """ - Initialize the command processor. - - Args: - command_channel: Channel for receiving commands - graph_execution: Graph execution aggregate - """ - self._command_channel = command_channel - self._graph_execution = graph_execution - self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {} - - def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None: - """ - Register a handler for a command type. - - Args: - command_type: Type of command to handle - handler: Handler for the command - """ - self._handlers[command_type] = handler - - def process_commands(self) -> None: - """Check for and process any pending commands.""" - try: - commands = self._command_channel.fetch_commands() - for command in commands: - self._handle_command(command) - except Exception as e: - logger.warning("Error processing commands: %s", e) - - def _handle_command(self, command: GraphEngineCommand) -> None: - """ - Handle a single command. - - Args: - command: The command to handle - """ - handler = self._handlers.get(type(command)) - if handler: - try: - handler.handle(command, self._graph_execution) - except Exception: - logger.exception("Error handling command %s", command.__class__.__name__) - else: - logger.warning("No handler registered for command: %s", command.__class__.__name__) diff --git a/api/graphon/graph_engine/config.py b/api/graphon/graph_engine/config.py deleted file mode 100644 index d56a69cee0..0000000000 --- a/api/graphon/graph_engine/config.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -GraphEngine configuration models. -""" - -from pydantic import BaseModel, ConfigDict - - -class GraphEngineConfig(BaseModel): - """Configuration for GraphEngine worker pool scaling.""" - - model_config = ConfigDict(frozen=True) - - min_workers: int = 1 - max_workers: int = 5 - scale_up_threshold: int = 3 - scale_down_idle_time: float = 5.0 diff --git a/api/graphon/graph_engine/domain/__init__.py b/api/graphon/graph_engine/domain/__init__.py deleted file mode 100644 index 9e9afe4c21..0000000000 --- a/api/graphon/graph_engine/domain/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Domain models for graph engine. - -This package contains the core domain entities, value objects, and aggregates -that represent the business concepts of workflow graph execution. -""" - -from .graph_execution import GraphExecution -from .node_execution import NodeExecution - -__all__ = [ - "GraphExecution", - "NodeExecution", -] diff --git a/api/graphon/graph_engine/domain/graph_execution.py b/api/graphon/graph_engine/domain/graph_execution.py deleted file mode 100644 index 9c0c7d1624..0000000000 --- a/api/graphon/graph_engine/domain/graph_execution.py +++ /dev/null @@ -1,242 +0,0 @@ -"""GraphExecution aggregate root managing the overall graph execution state.""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from importlib import import_module -from typing import Literal - -from pydantic import BaseModel, Field - -from graphon.entities.pause_reason import PauseReason -from graphon.enums import NodeState -from graphon.runtime.graph_runtime_state import GraphExecutionProtocol - -from .node_execution import NodeExecution - - -class GraphExecutionErrorState(BaseModel): - """Serializable representation of an execution error.""" - - module: str = Field(description="Module containing the exception class") - qualname: str = Field(description="Qualified name of the exception class") - message: str | None = Field(default=None, description="Exception message string") - - -class NodeExecutionState(BaseModel): - """Serializable representation of a node execution entity.""" - - node_id: str - state: NodeState = Field(default=NodeState.UNKNOWN) - retry_count: int = Field(default=0) - execution_id: str | None = Field(default=None) - error: str | None = Field(default=None) - - -class GraphExecutionState(BaseModel): - """Pydantic model describing serialized GraphExecution state.""" - - type: Literal["GraphExecution"] = Field(default="GraphExecution") - version: str = Field(default="1.0") - workflow_id: str - started: bool = Field(default=False) - completed: bool = Field(default=False) - aborted: bool = Field(default=False) - paused: bool = Field(default=False) - pause_reasons: list[PauseReason] = Field(default_factory=list) - error: GraphExecutionErrorState | None = Field(default=None) - exceptions_count: int = Field(default=0) - node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState]) - - -def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None: - """Convert an exception into its serializable representation.""" - - if error is None: - return None - - return GraphExecutionErrorState( - module=error.__class__.__module__, - qualname=error.__class__.__qualname__, - message=str(error), - ) - - -def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]: - """Locate an exception class from its module and qualified name.""" - - module = import_module(module_name) - attr: object = module - for part in qualname.split("."): - attr = getattr(attr, part) - - if isinstance(attr, type) and issubclass(attr, Exception): - return attr - - raise TypeError(f"{qualname} in {module_name} is not an Exception subclass") - - -def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None: - """Reconstruct an exception instance from serialized data.""" - - if state is None: - return None - - try: - exception_class = _resolve_exception_class(state.module, state.qualname) - if state.message is None: - return exception_class() - return exception_class(state.message) - except Exception: - # Fallback to RuntimeError when reconstruction fails - if state.message is None: - return RuntimeError(state.qualname) - return RuntimeError(state.message) - - -@dataclass -class GraphExecution: - """ - Aggregate root for graph execution. - - This manages the overall execution state of a workflow graph, - coordinating between multiple node executions. - """ - - workflow_id: str - started: bool = False - completed: bool = False - aborted: bool = False - paused: bool = False - pause_reasons: list[PauseReason] = field(default_factory=list) - error: Exception | None = None - node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution]) - exceptions_count: int = 0 - - def start(self) -> None: - """Mark the graph execution as started.""" - if self.started: - raise RuntimeError("Graph execution already started") - self.started = True - - def complete(self) -> None: - """Mark the graph execution as completed.""" - if not self.started: - raise RuntimeError("Cannot complete execution that hasn't started") - if self.completed: - raise RuntimeError("Graph execution already completed") - self.completed = True - - def abort(self, reason: str) -> None: - """Abort the graph execution.""" - self.aborted = True - self.error = RuntimeError(f"Aborted: {reason}") - - def pause(self, reason: PauseReason) -> None: - """Pause the graph execution without marking it complete.""" - if self.completed: - raise RuntimeError("Cannot pause execution that has completed") - if self.aborted: - raise RuntimeError("Cannot pause execution that has been aborted") - self.paused = True - self.pause_reasons.append(reason) - - def fail(self, error: Exception) -> None: - """Mark the graph execution as failed.""" - self.error = error - self.completed = True - - def get_or_create_node_execution(self, node_id: str) -> NodeExecution: - """Get or create a node execution entity.""" - if node_id not in self.node_executions: - self.node_executions[node_id] = NodeExecution(node_id=node_id) - return self.node_executions[node_id] - - @property - def is_running(self) -> bool: - """Check if the execution is currently running.""" - return self.started and not self.completed and not self.aborted and not self.paused - - @property - def is_paused(self) -> bool: - """Check if the execution is currently paused.""" - return self.paused - - @property - def has_error(self) -> bool: - """Check if the execution has encountered an error.""" - return self.error is not None - - @property - def error_message(self) -> str | None: - """Get the error message if an error exists.""" - if not self.error: - return None - return str(self.error) - - def dumps(self) -> str: - """Serialize the aggregate state into a JSON string.""" - - node_states = [ - NodeExecutionState( - node_id=node_id, - state=node_execution.state, - retry_count=node_execution.retry_count, - execution_id=node_execution.execution_id, - error=node_execution.error, - ) - for node_id, node_execution in sorted(self.node_executions.items()) - ] - - state = GraphExecutionState( - workflow_id=self.workflow_id, - started=self.started, - completed=self.completed, - aborted=self.aborted, - paused=self.paused, - pause_reasons=self.pause_reasons, - error=_serialize_error(self.error), - exceptions_count=self.exceptions_count, - node_executions=node_states, - ) - - return state.model_dump_json() - - def loads(self, data: str) -> None: - """Restore aggregate state from a serialized JSON string.""" - - state = GraphExecutionState.model_validate_json(data) - - if state.type != "GraphExecution": - raise ValueError(f"Invalid serialized data type: {state.type}") - - if state.version != "1.0": - raise ValueError(f"Unsupported serialized version: {state.version}") - - if self.workflow_id != state.workflow_id: - raise ValueError("Serialized workflow_id does not match aggregate identity") - - self.started = state.started - self.completed = state.completed - self.aborted = state.aborted - self.paused = state.paused - self.pause_reasons = state.pause_reasons - self.error = _deserialize_error(state.error) - self.exceptions_count = state.exceptions_count - self.node_executions = { - item.node_id: NodeExecution( - node_id=item.node_id, - state=item.state, - retry_count=item.retry_count, - execution_id=item.execution_id, - error=item.error, - ) - for item in state.node_executions - } - - def record_node_failure(self) -> None: - """Increment the count of node failures encountered during execution.""" - self.exceptions_count += 1 - - -_: GraphExecutionProtocol = GraphExecution(workflow_id="") diff --git a/api/graphon/graph_engine/domain/node_execution.py b/api/graphon/graph_engine/domain/node_execution.py deleted file mode 100644 index dafd6ccd8a..0000000000 --- a/api/graphon/graph_engine/domain/node_execution.py +++ /dev/null @@ -1,45 +0,0 @@ -""" -NodeExecution entity representing a node's execution state. -""" - -from dataclasses import dataclass - -from graphon.enums import NodeState - - -@dataclass -class NodeExecution: - """ - Entity representing the execution state of a single node. - - This is a mutable entity that tracks the runtime state of a node - during graph execution. - """ - - node_id: str - state: NodeState = NodeState.UNKNOWN - retry_count: int = 0 - execution_id: str | None = None - error: str | None = None - - def mark_started(self, execution_id: str) -> None: - """Mark the node as started with an execution ID.""" - self.state = NodeState.TAKEN - self.execution_id = execution_id - - def mark_taken(self) -> None: - """Mark the node as successfully completed.""" - self.state = NodeState.TAKEN - self.error = None - - def mark_failed(self, error: str) -> None: - """Mark the node as failed with an error.""" - self.error = error - - def mark_skipped(self) -> None: - """Mark the node as skipped.""" - self.state = NodeState.SKIPPED - - def increment_retry(self) -> None: - """Increment the retry count for this node.""" - self.retry_count += 1 diff --git a/api/graphon/graph_engine/entities/__init__.py b/api/graphon/graph_engine/entities/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/graph_engine/entities/commands.py b/api/graphon/graph_engine/entities/commands.py deleted file mode 100644 index 25ebc804b6..0000000000 --- a/api/graphon/graph_engine/entities/commands.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -GraphEngine command entities for external control. - -This module defines command types that can be sent to a running GraphEngine -instance to control its execution flow. -""" - -from collections.abc import Sequence -from enum import StrEnum, auto -from typing import Any - -from pydantic import BaseModel, Field - -from graphon.variables.variables import Variable - - -class CommandType(StrEnum): - """Types of commands that can be sent to GraphEngine.""" - - ABORT = auto() - PAUSE = auto() - UPDATE_VARIABLES = auto() - - -class GraphEngineCommand(BaseModel): - """Base class for all GraphEngine commands.""" - - command_type: CommandType = Field(..., description="Type of command") - payload: dict[str, Any] | None = Field(default=None, description="Optional command payload") - - -class AbortCommand(GraphEngineCommand): - """Command to abort a running workflow execution.""" - - command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command") - reason: str | None = Field(default=None, description="Optional reason for abort") - - -class PauseCommand(GraphEngineCommand): - """Command to pause a running workflow execution.""" - - command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command") - reason: str = Field(default="unknown reason", description="reason for pause") - - -class VariableUpdate(BaseModel): - """Represents a single variable update instruction.""" - - value: Variable = Field(description="New variable value") - - -class UpdateVariablesCommand(GraphEngineCommand): - """Command to update a group of variables in the variable pool.""" - - command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command") - updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates") diff --git a/api/graphon/graph_engine/error_handler.py b/api/graphon/graph_engine/error_handler.py deleted file mode 100644 index 43ce8bb502..0000000000 --- a/api/graphon/graph_engine/error_handler.py +++ /dev/null @@ -1,213 +0,0 @@ -""" -Main error handler that coordinates error strategies. -""" - -import logging -import time -from typing import TYPE_CHECKING, final - -from graphon.enums import ( - ErrorStrategy as ErrorStrategyEnum, -) -from graphon.enums import ( - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.graph import Graph -from graphon.graph_events import ( - GraphNodeEventBase, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunRetryEvent, -) -from graphon.node_events import NodeRunResult - -if TYPE_CHECKING: - from .domain import GraphExecution - -logger = logging.getLogger(__name__) - - -@final -class ErrorHandler: - """ - Coordinates error handling strategies for node failures. - - This acts as a facade for the various error strategies, - selecting and applying the appropriate strategy based on - node configuration. - """ - - def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None: - """ - Initialize the error handler. - - Args: - graph: The workflow graph - graph_execution: The graph execution state - """ - self._graph = graph - self._graph_execution = graph_execution - - def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None: - """ - Handle a node failure event. - - Selects and applies the appropriate error strategy based on - the node's configuration. - - Args: - event: The node failure event - - Returns: - Optional new event to process, or None to abort - """ - node = self._graph.nodes[event.node_id] - # Get retry count from NodeExecution - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - retry_count = node_execution.retry_count - - # First check if retry is configured and not exhausted - if node.retry and retry_count < node.retry_config.max_retries: - result = self._handle_retry(event, retry_count) - if result: - # Retry count will be incremented when NodeRunRetryEvent is handled - return result - - # Apply configured error strategy - strategy = node.error_strategy - - match strategy: - case None: - return self._handle_abort(event) - case ErrorStrategyEnum.FAIL_BRANCH: - return self._handle_fail_branch(event) - case ErrorStrategyEnum.DEFAULT_VALUE: - return self._handle_default_value(event) - - def _handle_abort(self, event: NodeRunFailedEvent): - """ - Handle error by aborting execution. - - This is the default strategy when no other strategy is specified. - It stops the entire graph execution when a node fails. - - Args: - event: The failure event - - Returns: - None - signals abortion - """ - logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error) - # Return None to signal that execution should stop - - def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int): - """ - Handle error by retrying the node. - - This strategy re-attempts node execution up to a configured - maximum number of retries with configurable intervals. - - Args: - event: The failure event - retry_count: Current retry attempt count - - Returns: - NodeRunRetryEvent if retry should occur, None otherwise - """ - node = self._graph.nodes[event.node_id] - - # Check if we've exceeded max retries - if not node.retry or retry_count >= node.retry_config.max_retries: - return None - - # Wait for retry interval - time.sleep(node.retry_config.retry_interval_seconds) - - # Create retry event - return NodeRunRetryEvent( - id=event.id, - node_title=node.title, - node_id=event.node_id, - node_type=event.node_type, - node_run_result=event.node_run_result, - start_at=event.start_at, - error=event.error, - retry_index=retry_count + 1, - ) - - def _handle_fail_branch(self, event: NodeRunFailedEvent): - """ - Handle error by taking the fail branch. - - This strategy converts failures to exceptions and routes execution - through a designated fail-branch edge. - - Args: - event: The failure event - - Returns: - NodeRunExceptionEvent to continue via fail branch - """ - outputs = { - "error_message": event.node_run_result.error, - "error_type": event.node_run_result.error_type, - } - - return NodeRunExceptionEvent( - id=event.id, - node_id=event.node_id, - node_type=event.node_type, - start_at=event.start_at, - finished_at=event.finished_at, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - inputs=event.node_run_result.inputs, - process_data=event.node_run_result.process_data, - outputs=outputs, - edge_source_handle="fail-branch", - metadata={ - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH, - }, - ), - error=event.error, - ) - - def _handle_default_value(self, event: NodeRunFailedEvent): - """ - Handle error by using default values. - - This strategy allows nodes to fail gracefully by providing - predefined default output values. - - Args: - event: The failure event - - Returns: - NodeRunExceptionEvent with default values - """ - node = self._graph.nodes[event.node_id] - - outputs = { - **node.default_value_dict, - "error_message": event.node_run_result.error, - "error_type": event.node_run_result.error_type, - } - - return NodeRunExceptionEvent( - id=event.id, - node_id=event.node_id, - node_type=event.node_type, - start_at=event.start_at, - finished_at=event.finished_at, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - inputs=event.node_run_result.inputs, - process_data=event.node_run_result.process_data, - outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE, - }, - ), - error=event.error, - ) diff --git a/api/graphon/graph_engine/event_management/__init__.py b/api/graphon/graph_engine/event_management/__init__.py deleted file mode 100644 index f6c3c0f753..0000000000 --- a/api/graphon/graph_engine/event_management/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Event management subsystem for graph engine. - -This package handles event routing, collection, and emission for -workflow graph execution events. -""" - -from .event_handlers import EventHandler -from .event_manager import EventManager - -__all__ = [ - "EventHandler", - "EventManager", -] diff --git a/api/graphon/graph_engine/event_management/event_handlers.py b/api/graphon/graph_engine/event_management/event_handlers.py deleted file mode 100644 index 184148280d..0000000000 --- a/api/graphon/graph_engine/event_management/event_handlers.py +++ /dev/null @@ -1,367 +0,0 @@ -""" -Event handler implementations for different event types. -""" - -import logging -from collections.abc import Mapping -from functools import singledispatchmethod -from typing import TYPE_CHECKING, final - -from graphon.enums import ErrorStrategy, NodeExecutionType, NodeState -from graphon.graph import Graph -from graphon.graph_events import ( - GraphNodeEventBase, - NodeRunAgentLogEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, -) -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState - -from ..domain.graph_execution import GraphExecution -from ..response_coordinator import ResponseStreamCoordinator - -if TYPE_CHECKING: - from ..error_handler import ErrorHandler - from ..graph_state_manager import GraphStateManager - from ..graph_traversal import EdgeProcessor - from .event_manager import EventManager - -logger = logging.getLogger(__name__) - - -@final -class EventHandler: - """ - Registry of event handlers for different event types. - - This centralizes the business logic for handling specific events, - keeping it separate from the routing and collection infrastructure. - """ - - def __init__( - self, - graph: Graph, - graph_runtime_state: GraphRuntimeState, - graph_execution: GraphExecution, - response_coordinator: ResponseStreamCoordinator, - event_collector: "EventManager", - edge_processor: "EdgeProcessor", - state_manager: "GraphStateManager", - error_handler: "ErrorHandler", - ) -> None: - """ - Initialize the event handler registry. - - Args: - graph: The workflow graph - graph_runtime_state: Runtime state with variable pool - graph_execution: Graph execution aggregate - response_coordinator: Response stream coordinator - event_collector: Event manager for collecting events - edge_processor: Edge processor for edge traversal - state_manager: Unified state manager - error_handler: Error handler - """ - self._graph = graph - self._graph_runtime_state = graph_runtime_state - self._graph_execution = graph_execution - self._response_coordinator = response_coordinator - self._event_collector = event_collector - self._edge_processor = edge_processor - self._state_manager = state_manager - self._error_handler = error_handler - - def dispatch(self, event: GraphNodeEventBase) -> None: - """ - Handle any node event by dispatching to the appropriate handler. - - Args: - event: The event to handle - """ - if isinstance(event, NodeRunVariableUpdatedEvent): - self._dispatch(event) - return - - # Events in loops or iterations are always collected - if event.in_loop_id or event.in_iteration_id: - self._event_collector.collect(event) - return - return self._dispatch(event) - - @singledispatchmethod - def _dispatch(self, event: GraphNodeEventBase) -> None: - self._event_collector.collect(event) - logger.warning("Unhandled event type: %s", type(event).__name__) - - @_dispatch.register(NodeRunIterationStartedEvent) - @_dispatch.register(NodeRunIterationNextEvent) - @_dispatch.register(NodeRunIterationSucceededEvent) - @_dispatch.register(NodeRunIterationFailedEvent) - @_dispatch.register(NodeRunLoopStartedEvent) - @_dispatch.register(NodeRunLoopNextEvent) - @_dispatch.register(NodeRunLoopSucceededEvent) - @_dispatch.register(NodeRunLoopFailedEvent) - @_dispatch.register(NodeRunAgentLogEvent) - @_dispatch.register(NodeRunRetrieverResourceEvent) - def _(self, event: GraphNodeEventBase) -> None: - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunStartedEvent) -> None: - """ - Handle node started event. - - Args: - event: The node started event - """ - # Track execution in domain model - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - is_initial_attempt = node_execution.retry_count == 0 - node_execution.mark_started(event.id) - self._graph_runtime_state.increment_node_run_steps() - - # Track in response coordinator for stream ordering - self._response_coordinator.track_node_execution(event.node_id, event.id) - - # Collect the event only for the first attempt; retries remain silent - if is_initial_attempt: - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunStreamChunkEvent) -> None: - """ - Handle stream chunk event with full processing. - - Args: - event: The stream chunk event - """ - # Process with response coordinator - streaming_events = list(self._response_coordinator.intercept_event(event)) - - # Collect all events - for stream_event in streaming_events: - self._event_collector.collect(stream_event) - - @_dispatch.register - def _(self, event: NodeRunVariableUpdatedEvent) -> None: - """ - Apply a node-requested variable mutation before downstream observers run. - - The event is collected like other node events so parent/container engines can - forward the updated payload to outer layers, including persistence listeners. - """ - self._graph_runtime_state.variable_pool.add(event.variable.selector, event.variable) - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunSucceededEvent) -> None: - """ - Handle node success by coordinating subsystems. - - This method coordinates between different subsystems to process - node completion, handle edges, and trigger downstream execution. - - Args: - event: The node succeeded event - """ - # Update domain model - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.mark_taken() - - self._accumulate_node_usage(event.node_run_result.llm_usage) - - # Store outputs in variable pool - self._store_node_outputs(event.node_id, event.node_run_result.outputs) - - # Forward to response coordinator and emit streaming events - streaming_events = self._response_coordinator.intercept_event(event) - for stream_event in streaming_events: - self._event_collector.collect(stream_event) - - # Process edges and get ready nodes - node = self._graph.nodes[event.node_id] - if node.execution_type == NodeExecutionType.BRANCH: - ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion( - event.node_id, event.node_run_result.edge_source_handle - ) - else: - ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id) - - # Collect streaming events from edge processing - for edge_event in edge_streaming_events: - self._event_collector.collect(edge_event) - - # Enqueue ready nodes - if self._graph_execution.is_paused: - for node_id in ready_nodes: - self._graph_runtime_state.register_deferred_node(node_id) - else: - for node_id in ready_nodes: - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) - - # Update execution tracking - self._state_manager.finish_execution(event.node_id) - - # Handle response node outputs - if node.execution_type == NodeExecutionType.RESPONSE: - self._update_response_outputs(event.node_run_result.outputs) - - # Collect the event - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunPauseRequestedEvent) -> None: - """Handle pause requests emitted by nodes.""" - - pause_reason = event.reason - self._graph_execution.pause(pause_reason) - self._state_manager.finish_execution(event.node_id) - if event.node_id in self._graph.nodes: - self._graph.nodes[event.node_id].state = NodeState.UNKNOWN - self._graph_runtime_state.register_paused_node(event.node_id) - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunFailedEvent) -> None: - """ - Handle node failure using error handler. - - Args: - event: The node failed event - """ - # Update domain model - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.mark_failed(event.error) - self._graph_execution.record_node_failure() - - self._accumulate_node_usage(event.node_run_result.llm_usage) - - result = self._error_handler.handle_node_failure(event) - - if result: - # Process the resulting event (retry, exception, etc.) - self.dispatch(result) - else: - # Abort execution - self._graph_execution.fail(RuntimeError(event.error)) - self._event_collector.collect(event) - self._state_manager.finish_execution(event.node_id) - - @_dispatch.register - def _(self, event: NodeRunExceptionEvent) -> None: - """ - Handle node exception event (fail-branch strategy). - - Args: - event: The node exception event - """ - # Node continues via fail-branch/default-value, treat as completion - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.mark_taken() - - self._accumulate_node_usage(event.node_run_result.llm_usage) - - # Persist outputs produced by the exception strategy (e.g. default values) - self._store_node_outputs(event.node_id, event.node_run_result.outputs) - - node = self._graph.nodes[event.node_id] - - if node.error_strategy == ErrorStrategy.DEFAULT_VALUE: - ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id) - elif node.error_strategy == ErrorStrategy.FAIL_BRANCH: - ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion( - event.node_id, event.node_run_result.edge_source_handle - ) - else: - raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}") - - for edge_event in edge_streaming_events: - self._event_collector.collect(edge_event) - - for node_id in ready_nodes: - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) - - # Update response outputs if applicable - if node.execution_type == NodeExecutionType.RESPONSE: - self._update_response_outputs(event.node_run_result.outputs) - - self._state_manager.finish_execution(event.node_id) - - # Collect the exception event for observers - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunRetryEvent) -> None: - """ - Handle node retry event. - - Args: - event: The node retry event - """ - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.increment_retry() - - # Finish the previous attempt before re-queuing the node - self._state_manager.finish_execution(event.node_id) - - # Emit retry event for observers - self._event_collector.collect(event) - - # Re-queue node for execution - self._state_manager.enqueue_node(event.node_id) - self._state_manager.start_execution(event.node_id) - - def _accumulate_node_usage(self, usage: LLMUsage) -> None: - """Accumulate token usage into the shared runtime state.""" - if usage.total_tokens <= 0: - return - - self._graph_runtime_state.add_tokens(usage.total_tokens) - - current_usage = self._graph_runtime_state.llm_usage - if current_usage.total_tokens == 0: - self._graph_runtime_state.llm_usage = usage - else: - self._graph_runtime_state.llm_usage = current_usage.plus(usage) - - def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None: - """ - Store node outputs in the variable pool. - - Args: - event: The node succeeded event containing outputs - """ - for variable_name, variable_value in outputs.items(): - self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value) - - def _update_response_outputs(self, outputs: Mapping[str, object]) -> None: - """Update response outputs for response nodes.""" - # TODO: Design a mechanism for nodes to notify the engine about how to update outputs - # in runtime state, rather than allowing nodes to directly access runtime state. - for key, value in outputs.items(): - if key == "answer": - existing = self._graph_runtime_state.get_output("answer", "") - if existing: - self._graph_runtime_state.set_output("answer", f"{existing}{value}") - else: - self._graph_runtime_state.set_output("answer", value) - else: - self._graph_runtime_state.set_output(key, value) diff --git a/api/graphon/graph_engine/event_management/event_manager.py b/api/graphon/graph_engine/event_management/event_manager.py deleted file mode 100644 index 5b2fb365e9..0000000000 --- a/api/graphon/graph_engine/event_management/event_manager.py +++ /dev/null @@ -1,186 +0,0 @@ -""" -Unified event manager for collecting and emitting events. -""" - -import logging -import threading -import time -from collections.abc import Generator -from contextlib import contextmanager -from typing import final - -from graphon.graph_events import GraphEngineEvent - -from ..layers.base import GraphEngineLayer - -_logger = logging.getLogger(__name__) - - -@final -class ReadWriteLock: - """ - A read-write lock implementation that allows multiple concurrent readers - but only one writer at a time. - """ - - def __init__(self) -> None: - self._read_ready = threading.Condition(threading.RLock()) - self._readers = 0 - - def acquire_read(self) -> None: - """Acquire a read lock.""" - _ = self._read_ready.acquire() - try: - self._readers += 1 - finally: - self._read_ready.release() - - def release_read(self) -> None: - """Release a read lock.""" - _ = self._read_ready.acquire() - try: - self._readers -= 1 - if self._readers == 0: - self._read_ready.notify_all() - finally: - self._read_ready.release() - - def acquire_write(self) -> None: - """Acquire a write lock.""" - _ = self._read_ready.acquire() - while self._readers > 0: - _ = self._read_ready.wait() - - def release_write(self) -> None: - """Release a write lock.""" - self._read_ready.release() - - @contextmanager - def read_lock(self): - """Return a context manager for read locking.""" - self.acquire_read() - try: - yield - finally: - self.release_read() - - @contextmanager - def write_lock(self): - """Return a context manager for write locking.""" - self.acquire_write() - try: - yield - finally: - self.release_write() - - -@final -class EventManager: - """ - Unified event manager that collects, buffers, and emits events. - - This class combines event collection with event emission, providing - thread-safe event management with support for notifying layers and - streaming events to external consumers. - """ - - def __init__(self) -> None: - """Initialize the event manager.""" - self._events: list[GraphEngineEvent] = [] - self._lock = ReadWriteLock() - self._layers: list[GraphEngineLayer] = [] - self._execution_complete = threading.Event() - - def set_layers(self, layers: list[GraphEngineLayer]) -> None: - """ - Set the layers to notify on event collection. - - Args: - layers: List of layers to notify - """ - self._layers = layers - - def notify_layers(self, event: GraphEngineEvent) -> None: - """Notify registered layers about an event without buffering it.""" - self._notify_layers(event) - - def collect(self, event: GraphEngineEvent) -> None: - """ - Thread-safe method to collect an event. - - Args: - event: The event to collect - """ - with self._lock.write_lock(): - self._events.append(event) - - # NOTE: `_notify_layers` is intentionally called outside the critical section - # to minimize lock contention and avoid blocking other readers or writers. - # - # The public `notify_layers` method also does not use a write lock, - # so protecting `_notify_layers` with a lock here is unnecessary. - self._notify_layers(event) - - def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]: - """ - Get new events starting from a specific index. - - Args: - start_index: The index to start from - - Returns: - List of new events - """ - with self._lock.read_lock(): - return list(self._events[start_index:]) - - def _event_count(self) -> int: - """ - Get the current count of collected events. - - Returns: - Number of collected events - """ - with self._lock.read_lock(): - return len(self._events) - - def mark_complete(self) -> None: - """Mark execution as complete to stop the event emission generator.""" - self._execution_complete.set() - - def emit_events(self) -> Generator[GraphEngineEvent, None, None]: - """ - Generator that yields events as they're collected. - - Yields: - GraphEngineEvent instances as they're processed - """ - yielded_count = 0 - - while not self._execution_complete.is_set() or yielded_count < self._event_count(): - # Get new events since last yield - new_events = self._get_new_events(yielded_count) - - # Yield any new events - for event in new_events: - yield event - yielded_count += 1 - - # Small sleep to avoid busy waiting - if not self._execution_complete.is_set() and not new_events: - time.sleep(0.001) - - def _notify_layers(self, event: GraphEngineEvent) -> None: - """ - Notify all layers of an event. - - Layer exceptions are caught and logged to prevent disrupting collection. - - Args: - event: The event to send to layers - """ - for layer in self._layers: - try: - layer.on_event(event) - except Exception: - _logger.exception("Error in layer on_event, layer_type=%s", type(layer)) diff --git a/api/graphon/graph_engine/graph_engine.py b/api/graphon/graph_engine/graph_engine.py deleted file mode 100644 index 32e0e60502..0000000000 --- a/api/graphon/graph_engine/graph_engine.py +++ /dev/null @@ -1,377 +0,0 @@ -""" -QueueBasedGraphEngine - Main orchestrator for queue-based workflow execution. - -This engine uses a modular architecture with separated packages following -Domain-Driven Design principles for improved maintainability and testability. -""" - -from __future__ import annotations - -import logging -import queue -from collections.abc import Generator -from typing import TYPE_CHECKING, cast, final - -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.enums import NodeExecutionType -from graphon.graph import Graph -from graphon.graph_events import ( - GraphEngineEvent, - GraphNodeEventBase, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) -from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool -from graphon.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol - -if TYPE_CHECKING: # pragma: no cover - used only for static analysis - from graphon.runtime.graph_runtime_state import GraphProtocol - -from .command_processing import ( - AbortCommandHandler, - CommandProcessor, - PauseCommandHandler, - UpdateVariablesCommandHandler, -) -from .config import GraphEngineConfig -from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand -from .error_handler import ErrorHandler -from .event_management import EventHandler, EventManager -from .graph_state_manager import GraphStateManager -from .graph_traversal import EdgeProcessor, SkipPropagator -from .layers.base import GraphEngineLayer -from .orchestration import Dispatcher, ExecutionCoordinator -from .protocols.command_channel import CommandChannel -from .worker_management import WorkerPool - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.graph_engine.domain.graph_execution import GraphExecution - from graphon.graph_engine.response_coordinator import ResponseStreamCoordinator - -logger = logging.getLogger(__name__) - - -_DEFAULT_CONFIG = GraphEngineConfig() - - -@final -class GraphEngine: - """ - Queue-based graph execution engine. - - Uses a modular architecture that delegates responsibilities to specialized - subsystems, following Domain-Driven Design and SOLID principles. - """ - - def __init__( - self, - workflow_id: str, - graph: Graph, - graph_runtime_state: GraphRuntimeState, - command_channel: CommandChannel, - config: GraphEngineConfig = _DEFAULT_CONFIG, - child_engine_builder: ChildGraphEngineBuilderProtocol | None = None, - ) -> None: - """Initialize the graph engine with all subsystems and dependencies.""" - - # Bind runtime state to current workflow context - self._graph = graph - self._graph_runtime_state = graph_runtime_state - self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) - self._command_channel = command_channel - self._config = config - self._layers: list[GraphEngineLayer] = [] - self._child_engine_builder = child_engine_builder - if child_engine_builder is not None: - self._graph_runtime_state.bind_child_engine_builder(child_engine_builder) - - # Graph execution tracks the overall execution state - self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution) - self._graph_execution.workflow_id = workflow_id - - # === Execution Queues === - self._ready_queue = self._graph_runtime_state.ready_queue - - # Queue for events generated during execution - self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() - - # === State Management === - # Unified state manager handles all node state transitions and queue operations - self._state_manager = GraphStateManager(self._graph, self._ready_queue) - - # === Response Coordination === - # Coordinates response streaming from response nodes - self._response_coordinator = cast("ResponseStreamCoordinator", self._graph_runtime_state.response_coordinator) - - # === Event Management === - # Event manager handles both collection and emission of events - self._event_manager = EventManager() - - # === Error Handling === - # Centralized error handler for graph execution errors - self._error_handler = ErrorHandler(self._graph, self._graph_execution) - - # === Graph Traversal Components === - # Propagates skip status through the graph when conditions aren't met - self._skip_propagator = SkipPropagator( - graph=self._graph, - state_manager=self._state_manager, - ) - - # Processes edges to determine next nodes after execution - # Also handles conditional branching and route selection - self._edge_processor = EdgeProcessor( - graph=self._graph, - state_manager=self._state_manager, - response_coordinator=self._response_coordinator, - skip_propagator=self._skip_propagator, - ) - - # === Command Processing === - # Processes external commands (e.g., abort requests) - self._command_processor = CommandProcessor( - command_channel=self._command_channel, - graph_execution=self._graph_execution, - ) - - # Register command handlers - abort_handler = AbortCommandHandler() - self._command_processor.register_handler(AbortCommand, abort_handler) - - pause_handler = PauseCommandHandler() - self._command_processor.register_handler(PauseCommand, pause_handler) - - update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool) - self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler) - - # === Worker Pool Setup === - # Create worker pool for parallel node execution - self._worker_pool = WorkerPool( - ready_queue=self._ready_queue, - event_queue=self._event_queue, - graph=self._graph, - layers=self._layers, - execution_context=self._graph_runtime_state.execution_context, - config=self._config, - ) - - # === Orchestration === - # Coordinates the overall execution lifecycle - self._execution_coordinator = ExecutionCoordinator( - graph_execution=self._graph_execution, - state_manager=self._state_manager, - command_processor=self._command_processor, - worker_pool=self._worker_pool, - ) - - # === Event Handler Registry === - # Central registry for handling all node execution events - self._event_handler_registry = EventHandler( - graph=self._graph, - graph_runtime_state=self._graph_runtime_state, - graph_execution=self._graph_execution, - response_coordinator=self._response_coordinator, - event_collector=self._event_manager, - edge_processor=self._edge_processor, - state_manager=self._state_manager, - error_handler=self._error_handler, - ) - - # Dispatches events and manages execution flow - self._dispatcher = Dispatcher( - event_queue=self._event_queue, - event_handler=self._event_handler_registry, - execution_coordinator=self._execution_coordinator, - event_emitter=self._event_manager, - ) - - # === Validation === - # Ensure all nodes share the same GraphRuntimeState instance - self._validate_graph_state_consistency() - - def _validate_graph_state_consistency(self) -> None: - """Validate that all nodes share the same GraphRuntimeState.""" - expected_state_id = id(self._graph_runtime_state) - for node in self._graph.nodes.values(): - if id(node.graph_runtime_state) != expected_state_id: - raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance") - - def _bind_layer_context( - self, - layer: GraphEngineLayer, - ) -> None: - layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel) - - def layer(self, layer: GraphEngineLayer) -> GraphEngine: - """Add a layer for extending functionality.""" - self._layers.append(layer) - self._bind_layer_context(layer) - return self - - def request_abort(self, reason: str | None = None) -> None: - """Queue an abort command for this engine.""" - self._command_channel.send_command(AbortCommand(reason=reason or "User requested abort")) - - def create_child_engine( - self, - *, - workflow_id: str, - graph_init_params: GraphInitParams, - root_node_id: str, - variable_pool: VariablePool | None = None, - ) -> GraphEngine: - return self._graph_runtime_state.create_child_engine( - workflow_id=workflow_id, - graph_init_params=graph_init_params, - root_node_id=root_node_id, - variable_pool=variable_pool, - ) - - def run(self) -> Generator[GraphEngineEvent, None, None]: - """ - Execute the graph using the modular architecture. - - Returns: - Generator yielding GraphEngineEvent instances - """ - try: - # Initialize layers - self._initialize_layers() - - is_resume = self._graph_execution.started - if not is_resume: - self._graph_execution.start() - else: - self._graph_execution.paused = False - self._graph_execution.pause_reasons = [] - - start_event = GraphRunStartedEvent( - reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL, - ) - self._event_manager.notify_layers(start_event) - yield start_event - - # Start subsystems - self._start_execution(resume=is_resume) - - # Yield events as they occur - yield from self._event_manager.emit_events() - - # Handle completion - if self._graph_execution.is_paused: - pause_reasons = self._graph_execution.pause_reasons - assert pause_reasons, "pause_reasons should not be empty when execution is paused." - # Ensure we have a valid PauseReason for the event - paused_event = GraphRunPausedEvent( - reasons=pause_reasons, - outputs=self._graph_runtime_state.outputs, - ) - self._event_manager.notify_layers(paused_event) - yield paused_event - elif self._graph_execution.aborted: - abort_reason = "Workflow execution aborted by user command" - if self._graph_execution.error: - abort_reason = str(self._graph_execution.error) - aborted_event = GraphRunAbortedEvent( - reason=abort_reason, - outputs=self._graph_runtime_state.outputs, - ) - self._event_manager.notify_layers(aborted_event) - yield aborted_event - elif self._graph_execution.has_error: - if self._graph_execution.error: - raise self._graph_execution.error - else: - outputs = self._graph_runtime_state.outputs - exceptions_count = self._graph_execution.exceptions_count - if exceptions_count > 0: - partial_event = GraphRunPartialSucceededEvent( - exceptions_count=exceptions_count, - outputs=outputs, - ) - self._event_manager.notify_layers(partial_event) - yield partial_event - else: - succeeded_event = GraphRunSucceededEvent( - outputs=outputs, - ) - self._event_manager.notify_layers(succeeded_event) - yield succeeded_event - - except Exception as e: - failed_event = GraphRunFailedEvent( - error=str(e), - exceptions_count=self._graph_execution.exceptions_count, - ) - self._event_manager.notify_layers(failed_event) - yield failed_event - raise - - finally: - self._stop_execution() - - def _initialize_layers(self) -> None: - """Initialize layers with context.""" - self._event_manager.set_layers(self._layers) - for layer in self._layers: - try: - layer.on_graph_start() - except Exception: - logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__) - - def _start_execution(self, *, resume: bool = False) -> None: - """Start execution subsystems.""" - paused_nodes: list[str] = [] - deferred_nodes: list[str] = [] - if resume: - paused_nodes = self._graph_runtime_state.consume_paused_nodes() - deferred_nodes = self._graph_runtime_state.consume_deferred_nodes() - - # Start worker pool (it calculates initial workers internally) - self._worker_pool.start() - - # Register response nodes - for node in self._graph.nodes.values(): - if node.execution_type == NodeExecutionType.RESPONSE: - self._response_coordinator.register(node.id) - - if not resume: - # Enqueue root node - root_node = self._graph.root_node - self._state_manager.enqueue_node(root_node.id) - self._state_manager.start_execution(root_node.id) - else: - seen_nodes: set[str] = set() - for node_id in paused_nodes + deferred_nodes: - if node_id in seen_nodes: - continue - seen_nodes.add(node_id) - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) - - # Start dispatcher - self._dispatcher.start() - - def _stop_execution(self) -> None: - """Stop execution subsystems.""" - self._dispatcher.stop() - self._worker_pool.stop() - # Don't mark complete here as the dispatcher already does it - - # Notify layers - for layer in self._layers: - try: - layer.on_graph_end(self._graph_execution.error) - except Exception: - logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__) - - # Public property accessors for attributes that need external access - @property - def graph_runtime_state(self) -> GraphRuntimeState: - """Get the graph runtime state.""" - return self._graph_runtime_state diff --git a/api/graphon/graph_engine/graph_state_manager.py b/api/graphon/graph_engine/graph_state_manager.py deleted file mode 100644 index ade8e403a8..0000000000 --- a/api/graphon/graph_engine/graph_state_manager.py +++ /dev/null @@ -1,290 +0,0 @@ -""" -Graph state manager that combines node, edge, and execution tracking. -""" - -import threading -from collections.abc import Sequence -from typing import TypedDict, final - -from graphon.enums import NodeState -from graphon.graph import Edge, Graph - -from .ready_queue import ReadyQueue - - -class EdgeStateAnalysis(TypedDict): - """Analysis result for edge states.""" - - has_unknown: bool - has_taken: bool - all_skipped: bool - - -@final -class GraphStateManager: - def __init__(self, graph: Graph, ready_queue: ReadyQueue) -> None: - """ - Initialize the state manager. - - Args: - graph: The workflow graph - ready_queue: Queue for nodes ready to execute - """ - self._graph = graph - self._ready_queue = ready_queue - self._lock = threading.RLock() - - # Execution tracking state - self._executing_nodes: set[str] = set() - - # ============= Node State Operations ============= - - def enqueue_node(self, node_id: str) -> None: - """ - Mark a node as TAKEN and add it to the ready queue. - - This combines the state transition and enqueueing operations - that always occur together when preparing a node for execution. - - Args: - node_id: The ID of the node to enqueue - """ - with self._lock: - self._graph.nodes[node_id].state = NodeState.TAKEN - self._ready_queue.put(node_id) - - def mark_node_skipped(self, node_id: str) -> None: - """ - Mark a node as SKIPPED. - - Args: - node_id: The ID of the node to skip - """ - with self._lock: - self._graph.nodes[node_id].state = NodeState.SKIPPED - - def is_node_ready(self, node_id: str) -> bool: - """ - Check if a node is ready to be executed. - - A node is ready when all its incoming edges from taken branches - have been satisfied. - - Args: - node_id: The ID of the node to check - - Returns: - True if the node is ready for execution - """ - with self._lock: - # Get all incoming edges to this node - incoming_edges = self._graph.get_incoming_edges(node_id) - - # If no incoming edges, node is always ready - if not incoming_edges: - return True - - # If any edge is UNKNOWN, node is not ready - if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges): - return False - - # Node is ready if at least one edge is TAKEN - return any(edge.state == NodeState.TAKEN for edge in incoming_edges) - - def get_node_state(self, node_id: str) -> NodeState: - """ - Get the current state of a node. - - Args: - node_id: The ID of the node - - Returns: - The current node state - """ - with self._lock: - return self._graph.nodes[node_id].state - - # ============= Edge State Operations ============= - - def mark_edge_taken(self, edge_id: str) -> None: - """ - Mark an edge as TAKEN. - - Args: - edge_id: The ID of the edge to mark - """ - with self._lock: - self._graph.edges[edge_id].state = NodeState.TAKEN - - def mark_edge_skipped(self, edge_id: str) -> None: - """ - Mark an edge as SKIPPED. - - Args: - edge_id: The ID of the edge to mark - """ - with self._lock: - self._graph.edges[edge_id].state = NodeState.SKIPPED - - def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis: - """ - Analyze the states of edges and return summary flags. - - Args: - edges: List of edges to analyze - - Returns: - Analysis result with state flags - """ - with self._lock: - states = {edge.state for edge in edges} - - return EdgeStateAnalysis( - has_unknown=NodeState.UNKNOWN in states, - has_taken=NodeState.TAKEN in states, - all_skipped=states == {NodeState.SKIPPED} if states else True, - ) - - def get_edge_state(self, edge_id: str) -> NodeState: - """ - Get the current state of an edge. - - Args: - edge_id: The ID of the edge - - Returns: - The current edge state - """ - with self._lock: - return self._graph.edges[edge_id].state - - def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]: - """ - Categorize branch edges into selected and unselected. - - Args: - node_id: The ID of the branch node - selected_handle: The handle of the selected edge - - Returns: - A tuple of (selected_edges, unselected_edges) - """ - with self._lock: - outgoing_edges = self._graph.get_outgoing_edges(node_id) - selected_edges: list[Edge] = [] - unselected_edges: list[Edge] = [] - - for edge in outgoing_edges: - if edge.source_handle == selected_handle: - selected_edges.append(edge) - else: - unselected_edges.append(edge) - - return selected_edges, unselected_edges - - # ============= Execution Tracking Operations ============= - - def start_execution(self, node_id: str) -> None: - """ - Mark a node as executing. - - Args: - node_id: The ID of the node starting execution - """ - with self._lock: - self._executing_nodes.add(node_id) - - def finish_execution(self, node_id: str) -> None: - """ - Mark a node as no longer executing. - - Args: - node_id: The ID of the node finishing execution - """ - with self._lock: - self._executing_nodes.discard(node_id) - - def is_executing(self, node_id: str) -> bool: - """ - Check if a node is currently executing. - - Args: - node_id: The ID of the node to check - - Returns: - True if the node is executing - """ - with self._lock: - return node_id in self._executing_nodes - - def get_executing_count(self) -> int: - """ - Get the count of currently executing nodes. - - Returns: - Number of executing nodes - """ - # This count is a best-effort snapshot and can change concurrently. - # Only use it for pause-drain checks where scheduling is already frozen. - with self._lock: - return len(self._executing_nodes) - - def get_executing_nodes(self) -> set[str]: - """ - Get a copy of the set of executing node IDs. - - Returns: - Set of node IDs currently executing - """ - with self._lock: - return self._executing_nodes.copy() - - def clear_executing(self) -> None: - """Clear all executing nodes.""" - with self._lock: - self._executing_nodes.clear() - - # ============= Composite Operations ============= - - def is_execution_complete(self) -> bool: - """ - Check if graph execution is complete. - - Execution is complete when: - - Ready queue is empty - - No nodes are executing - - Returns: - True if execution is complete - """ - with self._lock: - return self._ready_queue.empty() and len(self._executing_nodes) == 0 - - def get_queue_depth(self) -> int: - """ - Get the current depth of the ready queue. - - Returns: - Number of nodes in the ready queue - """ - return self._ready_queue.qsize() - - def get_execution_stats(self) -> dict[str, int]: - """ - Get execution statistics. - - Returns: - Dictionary with execution statistics - """ - with self._lock: - taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN) - skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED) - unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN) - - return { - "queue_depth": self._ready_queue.qsize(), - "executing": len(self._executing_nodes), - "taken_nodes": taken_nodes, - "skipped_nodes": skipped_nodes, - "unknown_nodes": unknown_nodes, - } diff --git a/api/graphon/graph_engine/graph_traversal/__init__.py b/api/graphon/graph_engine/graph_traversal/__init__.py deleted file mode 100644 index d629140d06..0000000000 --- a/api/graphon/graph_engine/graph_traversal/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Graph traversal subsystem for graph engine. - -This package handles graph navigation, edge processing, -and skip propagation logic. -""" - -from .edge_processor import EdgeProcessor -from .skip_propagator import SkipPropagator - -__all__ = [ - "EdgeProcessor", - "SkipPropagator", -] diff --git a/api/graphon/graph_engine/graph_traversal/edge_processor.py b/api/graphon/graph_engine/graph_traversal/edge_processor.py deleted file mode 100644 index e51eee8a69..0000000000 --- a/api/graphon/graph_engine/graph_traversal/edge_processor.py +++ /dev/null @@ -1,201 +0,0 @@ -""" -Edge processing logic for graph traversal. -""" - -from collections.abc import Sequence -from typing import TYPE_CHECKING, final - -from graphon.enums import NodeExecutionType -from graphon.graph import Edge, Graph -from graphon.graph_events import NodeRunStreamChunkEvent - -from ..graph_state_manager import GraphStateManager -from ..response_coordinator import ResponseStreamCoordinator - -if TYPE_CHECKING: - from .skip_propagator import SkipPropagator - - -@final -class EdgeProcessor: - """ - Processes edges during graph execution. - - This handles marking edges as taken or skipped, notifying - the response coordinator, triggering downstream node execution, - and managing branch node logic. - """ - - def __init__( - self, - graph: Graph, - state_manager: GraphStateManager, - response_coordinator: ResponseStreamCoordinator, - skip_propagator: "SkipPropagator", - ) -> None: - """ - Initialize the edge processor. - - Args: - graph: The workflow graph - state_manager: Unified state manager - response_coordinator: Response stream coordinator - skip_propagator: Propagator for skip states - """ - self._graph = graph - self._state_manager = state_manager - self._response_coordinator = response_coordinator - self._skip_propagator = skip_propagator - - def process_node_success( - self, node_id: str, selected_handle: str | None = None - ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Process edges after a node succeeds. - - Args: - node_id: The ID of the succeeded node - selected_handle: For branch nodes, the selected edge handle - - Returns: - Tuple of (list of downstream node IDs that are now ready, list of streaming events) - """ - node = self._graph.nodes[node_id] - - if node.execution_type == NodeExecutionType.BRANCH: - return self._process_branch_node_edges(node_id, selected_handle) - else: - return self._process_non_branch_node_edges(node_id) - - def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Process edges for non-branch nodes (mark all as TAKEN). - - Args: - node_id: The ID of the succeeded node - - Returns: - Tuple of (list of downstream nodes ready for execution, list of streaming events) - """ - ready_nodes: list[str] = [] - all_streaming_events: list[NodeRunStreamChunkEvent] = [] - outgoing_edges = self._graph.get_outgoing_edges(node_id) - - for edge in outgoing_edges: - nodes, events = self._process_taken_edge(edge) - ready_nodes.extend(nodes) - all_streaming_events.extend(events) - - return ready_nodes, all_streaming_events - - def _process_branch_node_edges( - self, node_id: str, selected_handle: str | None - ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Process edges for branch nodes. - - Args: - node_id: The ID of the branch node - selected_handle: The handle of the selected edge - - Returns: - Tuple of (list of downstream nodes ready for execution, list of streaming events) - - Raises: - ValueError: If no edge was selected - """ - if not selected_handle: - raise ValueError(f"Branch node {node_id} did not select any edge") - - ready_nodes: list[str] = [] - all_streaming_events: list[NodeRunStreamChunkEvent] = [] - - # Categorize edges - selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle) - - # Process unselected edges first (mark as skipped) - for edge in unselected_edges: - self._process_skipped_edge(edge) - - # Process selected edges - for edge in selected_edges: - nodes, events = self._process_taken_edge(edge) - ready_nodes.extend(nodes) - all_streaming_events.extend(events) - - return ready_nodes, all_streaming_events - - def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Mark edge as taken and check downstream node. - - Args: - edge: The edge to process - - Returns: - Tuple of (list containing downstream node ID if it's ready, list of streaming events) - """ - # Mark edge as taken - self._state_manager.mark_edge_taken(edge.id) - - # Notify response coordinator and get streaming events - streaming_events = self._response_coordinator.on_edge_taken(edge.id) - - # Check if downstream node is ready - ready_nodes: list[str] = [] - if self._state_manager.is_node_ready(edge.head): - ready_nodes.append(edge.head) - - return ready_nodes, streaming_events - - def _process_skipped_edge(self, edge: Edge) -> None: - """ - Mark edge as skipped. - - Args: - edge: The edge to skip - """ - self._state_manager.mark_edge_skipped(edge.id) - - def handle_branch_completion( - self, node_id: str, selected_handle: str | None - ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Handle completion of a branch node. - - Args: - node_id: The ID of the branch node - selected_handle: The handle of the selected branch - - Returns: - Tuple of (list of downstream nodes ready for execution, list of streaming events) - - Raises: - ValueError: If no branch was selected - """ - if not selected_handle: - raise ValueError(f"Branch node {node_id} completed without selecting a branch") - - # Categorize edges into selected and unselected - _, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle) - - # Skip all unselected paths - self._skip_propagator.skip_branch_paths(unselected_edges) - - # Process selected edges and get ready nodes and streaming events - return self.process_node_success(node_id, selected_handle) - - def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool: - """ - Validate that a branch selection is valid. - - Args: - node_id: The ID of the branch node - selected_handle: The handle to validate - - Returns: - True if the selection is valid - """ - outgoing_edges = self._graph.get_outgoing_edges(node_id) - valid_handles = {edge.source_handle for edge in outgoing_edges} - return selected_handle in valid_handles diff --git a/api/graphon/graph_engine/graph_traversal/skip_propagator.py b/api/graphon/graph_engine/graph_traversal/skip_propagator.py deleted file mode 100644 index bdb83b38ad..0000000000 --- a/api/graphon/graph_engine/graph_traversal/skip_propagator.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -Skip state propagation through the graph. -""" - -from collections.abc import Sequence -from typing import final - -from graphon.graph import Edge, Graph - -from ..graph_state_manager import GraphStateManager - - -@final -class SkipPropagator: - """ - Propagates skip states through the graph. - - When a node is skipped, this ensures all downstream nodes - that depend solely on it are also skipped. - """ - - def __init__( - self, - graph: Graph, - state_manager: GraphStateManager, - ) -> None: - """ - Initialize the skip propagator. - - Args: - graph: The workflow graph - state_manager: Unified state manager - """ - self._graph = graph - self._state_manager = state_manager - - def propagate_skip_from_edge(self, edge_id: str) -> None: - """ - Recursively propagate skip state from a skipped edge. - - Rules: - - If a node has any UNKNOWN incoming edges, stop processing - - If all incoming edges are SKIPPED, skip the node and its edges - - If any incoming edge is TAKEN, the node may still execute - - Args: - edge_id: The ID of the skipped edge to start from - """ - downstream_node_id = self._graph.edges[edge_id].head - incoming_edges = self._graph.get_incoming_edges(downstream_node_id) - - # Analyze edge states - edge_states = self._state_manager.analyze_edge_states(incoming_edges) - - # Stop if there are unknown edges (not yet processed) - if edge_states["has_unknown"]: - return - - # If any edge is taken, node may still execute - if edge_states["has_taken"]: - # Enqueue node - self._state_manager.enqueue_node(downstream_node_id) - self._state_manager.start_execution(downstream_node_id) - return - - # All edges are skipped, propagate skip to this node - if edge_states["all_skipped"]: - self._propagate_skip_to_node(downstream_node_id) - - def _propagate_skip_to_node(self, node_id: str) -> None: - """ - Mark a node and all its outgoing edges as skipped. - - Args: - node_id: The ID of the node to skip - """ - # Mark node as skipped - self._state_manager.mark_node_skipped(node_id) - - # Mark all outgoing edges as skipped and propagate - outgoing_edges = self._graph.get_outgoing_edges(node_id) - for edge in outgoing_edges: - self._state_manager.mark_edge_skipped(edge.id) - # Recursively propagate skip - self.propagate_skip_from_edge(edge.id) - - def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None: - """ - Skip all paths from unselected branch edges. - - Args: - unselected_edges: List of edges not taken by the branch - """ - for edge in unselected_edges: - self._state_manager.mark_edge_skipped(edge.id) - self.propagate_skip_from_edge(edge.id) diff --git a/api/graphon/graph_engine/layers/README.md b/api/graphon/graph_engine/layers/README.md deleted file mode 100644 index b0f295037c..0000000000 --- a/api/graphon/graph_engine/layers/README.md +++ /dev/null @@ -1,55 +0,0 @@ -# Layers - -Pluggable middleware for engine extensions. - -## Components - -### Layer (base) - -Abstract base class for layers. - -- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks) -- `on_graph_start()` - Execution start hook -- `on_event()` - Process all events -- `on_graph_end()` - Execution end hook - -### DebugLoggingLayer - -Comprehensive execution logging. - -- Configurable detail levels -- Tracks execution statistics -- Truncates long values - -## Usage - -```python -debug_layer = DebugLoggingLayer( - level="INFO", - include_outputs=True -) - -engine = GraphEngine(graph) -engine.layer(debug_layer) -engine.run() -``` - -`engine.layer()` binds the read-only runtime state before execution, so -`graph_runtime_state` is always available inside layer hooks. - -## Custom Layers - -```python -class MetricsLayer(Layer): - def on_event(self, event): - if isinstance(event, NodeRunSucceededEvent): - self.metrics[event.node_id] = event.elapsed_time -``` - -## Configuration - -**DebugLoggingLayer Options:** - -- `level` - Log level (INFO, DEBUG, ERROR) -- `include_inputs/outputs` - Log data values -- `max_value_length` - Truncate long values diff --git a/api/graphon/graph_engine/layers/__init__.py b/api/graphon/graph_engine/layers/__init__.py deleted file mode 100644 index 0a29a52993..0000000000 --- a/api/graphon/graph_engine/layers/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Layer system for GraphEngine extensibility. - -This module provides the layer infrastructure for extending GraphEngine functionality -with middleware-like components that can observe events and interact with execution. -""" - -from .base import GraphEngineLayer -from .debug_logging import DebugLoggingLayer -from .execution_limits import ExecutionLimitsLayer - -__all__ = [ - "DebugLoggingLayer", - "ExecutionLimitsLayer", - "GraphEngineLayer", -] diff --git a/api/graphon/graph_engine/layers/base.py b/api/graphon/graph_engine/layers/base.py deleted file mode 100644 index 605615d347..0000000000 --- a/api/graphon/graph_engine/layers/base.py +++ /dev/null @@ -1,128 +0,0 @@ -""" -Base layer class for GraphEngine extensions. - -This module provides the abstract base class for implementing layers that can -intercept and respond to GraphEngine events. -""" - -from abc import ABC, abstractmethod - -from graphon.graph_engine.protocols.command_channel import CommandChannel -from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.runtime import ReadOnlyGraphRuntimeState - - -class GraphEngineLayerNotInitializedError(Exception): - """Raised when a layer's runtime state is accessed before initialization.""" - - def __init__(self, layer_name: str | None = None) -> None: - name = layer_name or "GraphEngineLayer" - super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.") - - -class GraphEngineLayer(ABC): - """ - Abstract base class for GraphEngine layers. - - Layers are middleware-like components that can: - - Observe all events emitted by the GraphEngine - - Access the graph runtime state - - Send commands to control execution - - Subclasses should override the constructor to accept configuration parameters, - then implement the three lifecycle methods. - """ - - def __init__(self) -> None: - """Initialize the layer. Subclasses can override with custom parameters.""" - self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None - self.command_channel: CommandChannel | None = None - - @property - def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState: - if self._graph_runtime_state is None: - raise GraphEngineLayerNotInitializedError(type(self).__name__) - return self._graph_runtime_state - - def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None: - """ - Initialize the layer with engine dependencies. - - Called by GraphEngine to inject the read-only runtime state and command channel. - This is invoked when the layer is registered with a `GraphEngine` instance. - Implementations should be idempotent. - Args: - graph_runtime_state: Read-only view of the runtime state - command_channel: Channel for sending commands to the engine - """ - self._graph_runtime_state = graph_runtime_state - self.command_channel = command_channel - - @abstractmethod - def on_graph_start(self) -> None: - """ - Called when graph execution starts. - - This is called after the engine has been initialized but before any nodes - are executed. Layers can use this to set up resources or log start information. - """ - pass - - @abstractmethod - def on_event(self, event: GraphEngineEvent) -> None: - """ - Called for every event emitted by the engine. - - This method receives all events generated during graph execution, including: - - Graph lifecycle events (start, success, failure) - - Node execution events (start, success, failure, retry) - - Stream events for response nodes - - Container events (iteration, loop) - - Args: - event: The event emitted by the engine - """ - pass - - @abstractmethod - def on_graph_end(self, error: Exception | None) -> None: - """ - Called when graph execution ends. - - This is called after all nodes have been executed or when execution is - aborted. Layers can use this to clean up resources or log final state. - - Args: - error: The exception that caused execution to fail, or None if successful - """ - pass - - def on_node_run_start(self, node: Node) -> None: - """ - Called immediately before a node begins execution. - - Layers can override to inject behavior (e.g., start spans) prior to node execution. - The node's execution ID is available via `node._node_execution_id` and will be - consistent with all events emitted by this node execution. - - Args: - node: The node instance about to be executed - """ - return - - def on_node_run_end( - self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None - ) -> None: - """ - Called after a node finishes execution. - - The node's execution ID is available via `node._node_execution_id` and matches - the `id` field in all events emitted by this node execution. - - Args: - node: The node instance that just finished execution - error: Exception instance if the node failed, otherwise None - result_event: The final result event from node execution (succeeded/failed/paused), if any - """ - return diff --git a/api/graphon/graph_engine/layers/debug_logging.py b/api/graphon/graph_engine/layers/debug_logging.py deleted file mode 100644 index e6585fb3b9..0000000000 --- a/api/graphon/graph_engine/layers/debug_logging.py +++ /dev/null @@ -1,247 +0,0 @@ -""" -Debug logging layer for GraphEngine. - -This module provides a layer that logs all events and state changes during -graph execution for debugging purposes. -""" - -import logging -from collections.abc import Mapping -from typing import Any, final - -from typing_extensions import override - -from graphon.graph_events import ( - GraphEngineEvent, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .base import GraphEngineLayer - - -@final -class DebugLoggingLayer(GraphEngineLayer): - """ - A layer that provides comprehensive logging of GraphEngine execution. - - This layer logs all events with configurable detail levels, helping developers - debug workflow execution and understand the flow of events. - """ - - def __init__( - self, - level: str = "INFO", - include_inputs: bool = False, - include_outputs: bool = True, - include_process_data: bool = False, - logger_name: str = "GraphEngine.Debug", - max_value_length: int = 500, - ) -> None: - """ - Initialize the debug logging layer. - - Args: - level: Logging level (DEBUG, INFO, WARNING, ERROR) - include_inputs: Whether to log node input values - include_outputs: Whether to log node output values - include_process_data: Whether to log node process data - logger_name: Name of the logger to use - max_value_length: Maximum length of logged values (truncated if longer) - """ - super().__init__() - self.level = level - self.include_inputs = include_inputs - self.include_outputs = include_outputs - self.include_process_data = include_process_data - self.max_value_length = max_value_length - - # Set up logger - self.logger = logging.getLogger(logger_name) - log_level = getattr(logging, level.upper(), logging.INFO) - self.logger.setLevel(log_level) - - # Track execution stats - self.node_count = 0 - self.success_count = 0 - self.failure_count = 0 - self.retry_count = 0 - - def _truncate_value(self, value: Any) -> str: - """Truncate long values for logging.""" - str_value = str(value) - if len(str_value) > self.max_value_length: - return str_value[: self.max_value_length] + "... (truncated)" - return str_value - - def _format_dict(self, data: dict[str, Any] | Mapping[str, Any]) -> str: - """Format a dictionary or mapping for logging with truncation.""" - if not data: - return "{}" - - formatted_items: list[str] = [] - for key, value in data.items(): - formatted_value = self._truncate_value(value) - formatted_items.append(f" {key}: {formatted_value}") - - return "{\n" + ",\n".join(formatted_items) + "\n}" - - @override - def on_graph_start(self) -> None: - """Log graph execution start.""" - self.logger.info("=" * 80) - self.logger.info("🚀 GRAPH EXECUTION STARTED") - self.logger.info("=" * 80) - # Log initial state - self.logger.info("Initial State:") - - @override - def on_event(self, event: GraphEngineEvent) -> None: - """Log individual events based on their type.""" - event_class = event.__class__.__name__ - - # Graph-level events - if isinstance(event, GraphRunStartedEvent): - self.logger.debug("Graph run started event") - - elif isinstance(event, GraphRunSucceededEvent): - self.logger.info("✅ Graph run succeeded") - if self.include_outputs and event.outputs: - self.logger.info(" Final outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, GraphRunPartialSucceededEvent): - self.logger.warning("⚠️ Graph run partially succeeded") - if event.exceptions_count > 0: - self.logger.warning(" Total exceptions: %s", event.exceptions_count) - if self.include_outputs and event.outputs: - self.logger.info(" Final outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, GraphRunFailedEvent): - self.logger.error("❌ Graph run failed: %s", event.error) - if event.exceptions_count > 0: - self.logger.error(" Total exceptions: %s", event.exceptions_count) - - elif isinstance(event, GraphRunAbortedEvent): - self.logger.warning("⚠️ Graph run aborted: %s", event.reason) - if event.outputs: - self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs)) - - # Node-level events - # Retry before Started because Retry subclasses Started; - elif isinstance(event, NodeRunRetryEvent): - self.retry_count += 1 - self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index) - self.logger.warning(" Previous error: %s", event.error) - - elif isinstance(event, NodeRunStartedEvent): - self.node_count += 1 - self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type) - - if self.include_inputs and event.node_run_result.inputs: - self.logger.debug(" Inputs: %s", self._format_dict(event.node_run_result.inputs)) - - elif isinstance(event, NodeRunSucceededEvent): - self.success_count += 1 - self.logger.info("✅ Node succeeded: %s", event.node_id) - - if self.include_outputs and event.node_run_result.outputs: - self.logger.debug(" Outputs: %s", self._format_dict(event.node_run_result.outputs)) - - if self.include_process_data and event.node_run_result.process_data: - self.logger.debug(" Process data: %s", self._format_dict(event.node_run_result.process_data)) - - elif isinstance(event, NodeRunFailedEvent): - self.failure_count += 1 - self.logger.error("❌ Node failed: %s", event.node_id) - self.logger.error(" Error: %s", event.error) - - if event.node_run_result.error: - self.logger.error(" Details: %s", event.node_run_result.error) - - elif isinstance(event, NodeRunExceptionEvent): - self.logger.warning("⚠️ Node exception handled: %s", event.node_id) - self.logger.warning(" Error: %s", event.error) - - elif isinstance(event, NodeRunStreamChunkEvent): - # Log stream chunks at debug level to avoid spam - final_indicator = " (FINAL)" if event.is_final else "" - self.logger.debug( - "📝 Stream chunk from %s%s: %s", event.node_id, final_indicator, self._truncate_value(event.chunk) - ) - - # Iteration events - elif isinstance(event, NodeRunIterationStartedEvent): - self.logger.info("🔁 Iteration started: %s", event.node_id) - - elif isinstance(event, NodeRunIterationNextEvent): - self.logger.debug(" Iteration next: %s (index: %s)", event.node_id, event.index) - - elif isinstance(event, NodeRunIterationSucceededEvent): - self.logger.info("✅ Iteration succeeded: %s", event.node_id) - if self.include_outputs and event.outputs: - self.logger.debug(" Outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, NodeRunIterationFailedEvent): - self.logger.error("❌ Iteration failed: %s", event.node_id) - self.logger.error(" Error: %s", event.error) - - # Loop events - elif isinstance(event, NodeRunLoopStartedEvent): - self.logger.info("🔄 Loop started: %s", event.node_id) - - elif isinstance(event, NodeRunLoopNextEvent): - self.logger.debug(" Loop iteration: %s (index: %s)", event.node_id, event.index) - - elif isinstance(event, NodeRunLoopSucceededEvent): - self.logger.info("✅ Loop succeeded: %s", event.node_id) - if self.include_outputs and event.outputs: - self.logger.debug(" Outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, NodeRunLoopFailedEvent): - self.logger.error("❌ Loop failed: %s", event.node_id) - self.logger.error(" Error: %s", event.error) - - else: - # Log unknown events at debug level - self.logger.debug("Event: %s", event_class) - - @override - def on_graph_end(self, error: Exception | None) -> None: - """Log graph execution end with summary statistics.""" - self.logger.info("=" * 80) - - if error: - self.logger.error("🔴 GRAPH EXECUTION FAILED") - self.logger.error(" Error: %s", error) - else: - self.logger.info("🎉 GRAPH EXECUTION COMPLETED SUCCESSFULLY") - - # Log execution statistics - self.logger.info("Execution Statistics:") - self.logger.info(" Total nodes executed: %s", self.node_count) - self.logger.info(" Successful nodes: %s", self.success_count) - self.logger.info(" Failed nodes: %s", self.failure_count) - self.logger.info(" Node retries: %s", self.retry_count) - - # Log final state if available - if self.include_outputs and self.graph_runtime_state.outputs: - self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs)) - - self.logger.info("=" * 80) diff --git a/api/graphon/graph_engine/layers/execution_limits.py b/api/graphon/graph_engine/layers/execution_limits.py deleted file mode 100644 index 2742b3acd3..0000000000 --- a/api/graphon/graph_engine/layers/execution_limits.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -Execution limits layer for GraphEngine. - -This layer monitors workflow execution to enforce limits on: -- Maximum execution steps -- Maximum execution time - -When limits are exceeded, the layer automatically aborts execution. -""" - -import logging -import time -from enum import StrEnum -from typing import final - -from typing_extensions import override - -from graphon.graph_engine.entities.commands import AbortCommand, CommandType -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import ( - GraphEngineEvent, - NodeRunStartedEvent, -) -from graphon.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent - - -class LimitType(StrEnum): - """Types of execution limits that can be exceeded.""" - - STEP_LIMIT = "step_limit" - TIME_LIMIT = "time_limit" - - -@final -class ExecutionLimitsLayer(GraphEngineLayer): - """ - Layer that enforces execution limits for workflows. - - Monitors: - - Step count: Tracks number of node executions - - Time limit: Monitors total execution time - - Automatically aborts execution when limits are exceeded. - """ - - def __init__(self, max_steps: int, max_time: int) -> None: - """ - Initialize the execution limits layer. - - Args: - max_steps: Maximum number of execution steps allowed - max_time: Maximum execution time in seconds allowed - """ - super().__init__() - self.max_steps = max_steps - self.max_time = max_time - - # Runtime tracking - self.start_time: float | None = None - self.step_count = 0 - self.logger = logging.getLogger(__name__) - - # State tracking - self._execution_started = False - self._execution_ended = False - self._abort_sent = False # Track if abort command has been sent - - @override - def on_graph_start(self) -> None: - """Called when graph execution starts.""" - self.start_time = time.time() - self.step_count = 0 - self._execution_started = True - self._execution_ended = False - self._abort_sent = False - - self.logger.debug("Execution limits monitoring started") - - @override - def on_event(self, event: GraphEngineEvent) -> None: - """ - Called for every event emitted by the engine. - - Monitors execution progress and enforces limits. - """ - if not self._execution_started or self._execution_ended or self._abort_sent: - return - - # Track step count for node execution events - if isinstance(event, NodeRunStartedEvent): - self.step_count += 1 - self.logger.debug("Step %d started: %s", self.step_count, event.node_id) - - # Check step limit when node execution completes - if isinstance(event, NodeRunSucceededEvent | NodeRunFailedEvent): - if self._reached_step_limitation(): - self._send_abort_command(LimitType.STEP_LIMIT) - - if self._reached_time_limitation(): - self._send_abort_command(LimitType.TIME_LIMIT) - - @override - def on_graph_end(self, error: Exception | None) -> None: - """Called when graph execution ends.""" - if self._execution_started and not self._execution_ended: - self._execution_ended = True - - if self.start_time: - total_time = time.time() - self.start_time - self.logger.debug("Execution completed: %d steps in %.2f seconds", self.step_count, total_time) - - def _reached_step_limitation(self) -> bool: - """Check if step count limit has been exceeded.""" - return self.step_count > self.max_steps - - def _reached_time_limitation(self) -> bool: - """Check if time limit has been exceeded.""" - return self.start_time is not None and (time.time() - self.start_time) > self.max_time - - def _send_abort_command(self, limit_type: LimitType) -> None: - """ - Send abort command due to limit violation. - - Args: - limit_type: Type of limit exceeded - """ - if not self.command_channel or not self._execution_started or self._execution_ended or self._abort_sent: - return - - # Format detailed reason message - if limit_type == LimitType.STEP_LIMIT: - reason = f"Maximum execution steps exceeded: {self.step_count} > {self.max_steps}" - elif limit_type == LimitType.TIME_LIMIT: - elapsed_time = time.time() - self.start_time if self.start_time else 0 - reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s" - - self.logger.warning("Execution limit exceeded: %s", reason) - - try: - # Send abort command to the engine - abort_command = AbortCommand(command_type=CommandType.ABORT, reason=reason) - self.command_channel.send_command(abort_command) - - # Mark that abort has been sent to prevent duplicate commands - self._abort_sent = True - - self.logger.debug("Abort command sent to engine") - - except Exception: - self.logger.exception("Failed to send abort command") diff --git a/api/graphon/graph_engine/manager.py b/api/graphon/graph_engine/manager.py deleted file mode 100644 index c728ff6986..0000000000 --- a/api/graphon/graph_engine/manager.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -GraphEngine Manager for sending control commands via Redis channel. - -This module provides a simplified interface for controlling workflow executions -using the new Redis command channel, without requiring user permission checks. -Callers must provide a Redis client dependency from outside the workflow package. -""" - -import logging -from collections.abc import Sequence -from typing import final - -from graphon.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol -from graphon.graph_engine.entities.commands import ( - AbortCommand, - GraphEngineCommand, - PauseCommand, - UpdateVariablesCommand, - VariableUpdate, -) - -logger = logging.getLogger(__name__) - - -@final -class GraphEngineManager: - """ - Manager for sending control commands to GraphEngine instances. - - This class provides a simple interface for controlling workflow executions - by sending commands through Redis channels, without user validation. - """ - - _redis_client: RedisClientProtocol - - def __init__(self, redis_client: RedisClientProtocol) -> None: - self._redis_client = redis_client - - def send_stop_command(self, task_id: str, reason: str | None = None) -> None: - """ - Send a stop command to a running workflow. - - Args: - task_id: The task ID of the workflow to stop - reason: Optional reason for stopping (defaults to "User requested stop") - """ - abort_command = AbortCommand(reason=reason or "User requested stop") - self._send_command(task_id, abort_command) - - def send_pause_command(self, task_id: str, reason: str | None = None) -> None: - """Send a pause command to a running workflow.""" - - pause_command = PauseCommand(reason=reason or "User requested pause") - self._send_command(task_id, pause_command) - - def send_update_variables_command(self, task_id: str, updates: Sequence[VariableUpdate]) -> None: - """Send a command to update variables in a running workflow.""" - - if not updates: - return - - update_command = UpdateVariablesCommand(updates=updates) - self._send_command(task_id, update_command) - - def _send_command(self, task_id: str, command: GraphEngineCommand) -> None: - """Send a command to the workflow-specific Redis channel.""" - - if not task_id: - return - - channel_key = f"workflow:{task_id}:commands" - channel = RedisChannel(self._redis_client, channel_key) - - try: - channel.send_command(command) - except Exception: - # Silently fail if Redis is unavailable - # The legacy control mechanisms will still work - logger.exception("Failed to send graph engine command %s for task %s", command.__class__.__name__, task_id) diff --git a/api/graphon/graph_engine/orchestration/__init__.py b/api/graphon/graph_engine/orchestration/__init__.py deleted file mode 100644 index de08e942fb..0000000000 --- a/api/graphon/graph_engine/orchestration/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Orchestration subsystem for graph engine. - -This package coordinates the overall execution flow between -different subsystems. -""" - -from .dispatcher import Dispatcher -from .execution_coordinator import ExecutionCoordinator - -__all__ = [ - "Dispatcher", - "ExecutionCoordinator", -] diff --git a/api/graphon/graph_engine/orchestration/dispatcher.py b/api/graphon/graph_engine/orchestration/dispatcher.py deleted file mode 100644 index f75bbee08e..0000000000 --- a/api/graphon/graph_engine/orchestration/dispatcher.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Main dispatcher for processing events from workers. -""" - -import logging -import queue -import threading -import time -from typing import TYPE_CHECKING, final - -from graphon.graph_events import ( - GraphNodeEventBase, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunSucceededEvent, -) - -from ..event_management import EventManager -from .execution_coordinator import ExecutionCoordinator - -if TYPE_CHECKING: - from ..event_management import EventHandler - -logger = logging.getLogger(__name__) - - -@final -class Dispatcher: - """ - Main dispatcher that processes events from the event queue. - - This runs in a separate thread and coordinates event processing - with timeout and completion detection. - """ - - _COMMAND_TRIGGER_EVENTS = ( - NodeRunSucceededEvent, - NodeRunFailedEvent, - NodeRunExceptionEvent, - ) - - def __init__( - self, - event_queue: queue.Queue[GraphNodeEventBase], - event_handler: "EventHandler", - execution_coordinator: ExecutionCoordinator, - event_emitter: EventManager | None = None, - ) -> None: - """ - Initialize the dispatcher. - - Args: - event_queue: Queue of events from workers - event_handler: Event handler registry for processing events - execution_coordinator: Coordinator for execution flow - event_emitter: Optional event manager to signal completion - """ - self._event_queue = event_queue - self._event_handler = event_handler - self._execution_coordinator = execution_coordinator - self._event_emitter = event_emitter - - self._thread: threading.Thread | None = None - self._stop_event = threading.Event() - self._start_time: float | None = None - - def start(self) -> None: - """Start the dispatcher thread.""" - if self._thread and self._thread.is_alive(): - return - - self._stop_event.clear() - self._start_time = time.time() - self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True) - self._thread.start() - - def stop(self) -> None: - """Stop the dispatcher thread.""" - self._stop_event.set() - if self._thread and self._thread.is_alive(): - self._thread.join(timeout=2.0) - - def _dispatcher_loop(self) -> None: - """Main dispatcher loop.""" - try: - self._process_commands() - paused = False - while not self._stop_event.is_set(): - if self._execution_coordinator.aborted or self._execution_coordinator.execution_complete: - break - if self._execution_coordinator.paused: - paused = True - break - - self._execution_coordinator.check_scaling() - try: - event = self._event_queue.get(timeout=0.1) - self._event_handler.dispatch(event) - self._event_queue.task_done() - self._process_commands(event) - except queue.Empty: - time.sleep(0.1) - - self._process_commands() - if paused: - self._drain_events_until_idle() - else: - self._drain_event_queue() - - except Exception as e: - logger.exception("Dispatcher error") - self._execution_coordinator.mark_failed(e) - - finally: - self._execution_coordinator.mark_complete() - # Signal the event emitter that execution is complete - if self._event_emitter: - self._event_emitter.mark_complete() - - def _process_commands(self, event: GraphNodeEventBase | None = None): - if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS): - self._execution_coordinator.process_commands() - - def _drain_event_queue(self) -> None: - while True: - try: - event = self._event_queue.get(block=False) - self._event_handler.dispatch(event) - self._event_queue.task_done() - except queue.Empty: - break - - def _drain_events_until_idle(self) -> None: - while not self._stop_event.is_set(): - try: - event = self._event_queue.get(timeout=0.1) - self._event_handler.dispatch(event) - self._event_queue.task_done() - self._process_commands(event) - except queue.Empty: - if not self._execution_coordinator.has_executing_nodes(): - break - self._drain_event_queue() diff --git a/api/graphon/graph_engine/orchestration/execution_coordinator.py b/api/graphon/graph_engine/orchestration/execution_coordinator.py deleted file mode 100644 index 0f8550eb12..0000000000 --- a/api/graphon/graph_engine/orchestration/execution_coordinator.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -Execution coordinator for managing overall workflow execution. -""" - -from typing import final - -from ..command_processing import CommandProcessor -from ..domain import GraphExecution -from ..graph_state_manager import GraphStateManager -from ..worker_management import WorkerPool - - -@final -class ExecutionCoordinator: - """ - Coordinates overall execution flow between subsystems. - - This provides high-level coordination methods used by the - dispatcher to manage execution state. - """ - - def __init__( - self, - graph_execution: GraphExecution, - state_manager: GraphStateManager, - command_processor: CommandProcessor, - worker_pool: WorkerPool, - ) -> None: - """ - Initialize the execution coordinator. - - Args: - graph_execution: Graph execution aggregate - state_manager: Unified state manager - command_processor: Processor for commands - worker_pool: Pool of workers - """ - self._graph_execution = graph_execution - self._state_manager = state_manager - self._command_processor = command_processor - self._worker_pool = worker_pool - - def process_commands(self) -> None: - """Process any pending commands.""" - self._command_processor.process_commands() - - def check_scaling(self) -> None: - """Check and perform worker scaling if needed.""" - self._worker_pool.check_and_scale() - - @property - def execution_complete(self): - return self._state_manager.is_execution_complete() - - @property - def aborted(self): - return self._graph_execution.aborted or self._graph_execution.has_error - - @property - def paused(self) -> bool: - """Expose whether the underlying graph execution is paused.""" - return self._graph_execution.is_paused - - def mark_complete(self) -> None: - """Mark execution as complete.""" - if self._graph_execution.is_paused: - return - if not self._graph_execution.completed: - self._graph_execution.complete() - - def mark_failed(self, error: Exception) -> None: - """ - Mark execution as failed. - - Args: - error: The error that caused failure - """ - self._graph_execution.fail(error) - - def handle_pause_if_needed(self) -> None: - """If the execution has been paused, stop workers immediately.""" - - if not self._graph_execution.is_paused: - return - - self._worker_pool.stop() - self._state_manager.clear_executing() - - def handle_abort_if_needed(self) -> None: - """If the execution has been aborted, stop workers immediately.""" - - if not self._graph_execution.aborted: - return - - self._worker_pool.stop() - self._state_manager.clear_executing() - - def has_executing_nodes(self) -> bool: - """Return True if any nodes are currently marked as executing.""" - # This check is only safe once execution has already paused. - # Before pause, executing state can change concurrently, which makes the result unreliable. - if not self._graph_execution.is_paused: - raise AssertionError("has_executing_nodes should only be called after execution is paused") - return self._state_manager.get_executing_count() > 0 diff --git a/api/graphon/graph_engine/protocols/command_channel.py b/api/graphon/graph_engine/protocols/command_channel.py deleted file mode 100644 index fabd8634c8..0000000000 --- a/api/graphon/graph_engine/protocols/command_channel.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -CommandChannel protocol for GraphEngine command communication. - -This protocol defines the interface for sending and receiving commands -to/from a GraphEngine instance, supporting both local and distributed scenarios. -""" - -from typing import Protocol - -from ..entities.commands import GraphEngineCommand - - -class CommandChannel(Protocol): - """ - Protocol for bidirectional command communication with GraphEngine. - - Since each GraphEngine instance processes only one workflow execution, - this channel is dedicated to that single execution. - """ - - def fetch_commands(self) -> list[GraphEngineCommand]: - """ - Fetch pending commands for this GraphEngine instance. - - Called by GraphEngine to poll for commands that need to be processed. - - Returns: - List of pending commands (may be empty) - """ - ... - - def send_command(self, command: GraphEngineCommand) -> None: - """ - Send a command to be processed by this GraphEngine instance. - - Called by external systems to send control commands to the running workflow. - - Args: - command: The command to send - """ - ... diff --git a/api/graphon/graph_engine/ready_queue/__init__.py b/api/graphon/graph_engine/ready_queue/__init__.py deleted file mode 100644 index acba0e961c..0000000000 --- a/api/graphon/graph_engine/ready_queue/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Ready queue implementations for GraphEngine. - -This package contains the protocol and implementations for managing -the queue of nodes ready for execution. -""" - -from .factory import create_ready_queue_from_state -from .in_memory import InMemoryReadyQueue -from .protocol import ReadyQueue, ReadyQueueState - -__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"] diff --git a/api/graphon/graph_engine/ready_queue/factory.py b/api/graphon/graph_engine/ready_queue/factory.py deleted file mode 100644 index a9d4f470e5..0000000000 --- a/api/graphon/graph_engine/ready_queue/factory.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -Factory for creating ReadyQueue instances from serialized state. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from .in_memory import InMemoryReadyQueue -from .protocol import ReadyQueueState - -if TYPE_CHECKING: - from .protocol import ReadyQueue - - -def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue: - """ - Create a ReadyQueue instance from a serialized state. - - Args: - state: The serialized queue state (Pydantic model, dict, or JSON string), or None for a new empty queue - - Returns: - A ReadyQueue instance initialized with the given state - - Raises: - ValueError: If the queue type is unknown or version is unsupported - """ - if state.type == "InMemoryReadyQueue": - if state.version != "1.0": - raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}") - queue = InMemoryReadyQueue() - # Always pass as JSON string to loads() - queue.loads(state.model_dump_json()) - return queue - else: - raise ValueError(f"Unknown ready queue type: {state.type}") diff --git a/api/graphon/graph_engine/ready_queue/in_memory.py b/api/graphon/graph_engine/ready_queue/in_memory.py deleted file mode 100644 index f2c265ece0..0000000000 --- a/api/graphon/graph_engine/ready_queue/in_memory.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -In-memory implementation of the ReadyQueue protocol. - -This implementation wraps Python's standard queue.Queue and adds -serialization capabilities for state storage. -""" - -import queue -from typing import final - -from .protocol import ReadyQueue, ReadyQueueState - - -@final -class InMemoryReadyQueue(ReadyQueue): - """ - In-memory ready queue implementation with serialization support. - - This implementation uses Python's queue.Queue internally and provides - methods to serialize and restore the queue state. - """ - - def __init__(self, maxsize: int = 0) -> None: - """ - Initialize the in-memory ready queue. - - Args: - maxsize: Maximum size of the queue (0 for unlimited) - """ - self._queue: queue.Queue[str] = queue.Queue(maxsize=maxsize) - - def put(self, item: str) -> None: - """ - Add a node ID to the ready queue. - - Args: - item: The node ID to add to the queue - """ - self._queue.put(item) - - def get(self, timeout: float | None = None) -> str: - """ - Retrieve and remove a node ID from the queue. - - Args: - timeout: Maximum time to wait for an item (None for blocking) - - Returns: - The node ID retrieved from the queue - - Raises: - queue.Empty: If timeout expires and no item is available - """ - if timeout is None: - return self._queue.get(block=True) - return self._queue.get(timeout=timeout) - - def task_done(self) -> None: - """ - Indicate that a previously retrieved task is complete. - - Used by worker threads to signal task completion for - join() synchronization. - """ - self._queue.task_done() - - def empty(self) -> bool: - """ - Check if the queue is empty. - - Returns: - True if the queue has no items, False otherwise - """ - return self._queue.empty() - - def qsize(self) -> int: - """ - Get the approximate size of the queue. - - Returns: - The approximate number of items in the queue - """ - return self._queue.qsize() - - def dumps(self) -> str: - """ - Serialize the queue state to a JSON string for storage. - - Returns: - A JSON string containing the serialized queue state - """ - # Extract all items from the queue without removing them - items: list[str] = [] - temp_items: list[str] = [] - - # Drain the queue temporarily to get all items - while not self._queue.empty(): - try: - item = self._queue.get_nowait() - temp_items.append(item) - items.append(item) - except queue.Empty: - break - - # Put items back in the same order - for item in temp_items: - self._queue.put(item) - - state = ReadyQueueState( - type="InMemoryReadyQueue", - version="1.0", - items=items, - ) - return state.model_dump_json() - - def loads(self, data: str) -> None: - """ - Restore the queue state from a JSON string. - - Args: - data: The JSON string containing the serialized queue state to restore - """ - state = ReadyQueueState.model_validate_json(data) - - if state.type != "InMemoryReadyQueue": - raise ValueError(f"Invalid serialized data type: {state.type}") - - if state.version != "1.0": - raise ValueError(f"Unsupported version: {state.version}") - - # Clear the current queue - while not self._queue.empty(): - try: - self._queue.get_nowait() - except queue.Empty: - break - - # Restore items - for item in state.items: - self._queue.put(item) diff --git a/api/graphon/graph_engine/ready_queue/protocol.py b/api/graphon/graph_engine/ready_queue/protocol.py deleted file mode 100644 index 97d3ea6dd2..0000000000 --- a/api/graphon/graph_engine/ready_queue/protocol.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -ReadyQueue protocol for GraphEngine node execution queue. - -This protocol defines the interface for managing the queue of nodes ready -for execution, supporting both in-memory and persistent storage scenarios. -""" - -from collections.abc import Sequence -from typing import Protocol - -from pydantic import BaseModel, Field - - -class ReadyQueueState(BaseModel): - """ - Pydantic model for serialized ready queue state. - - This defines the structure of the data returned by dumps() - and expected by loads() for ready queue serialization. - """ - - type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')") - version: str = Field(description="Serialization format version") - items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue") - - -class ReadyQueue(Protocol): - """ - Protocol for managing nodes ready for execution in GraphEngine. - - This protocol defines the interface that any ready queue implementation - must provide, enabling both in-memory queues and persistent queues - that can be serialized for state storage. - """ - - def put(self, item: str) -> None: - """ - Add a node ID to the ready queue. - - Args: - item: The node ID to add to the queue - """ - ... - - def get(self, timeout: float | None = None) -> str: - """ - Retrieve and remove a node ID from the queue. - - Args: - timeout: Maximum time to wait for an item (None for blocking) - - Returns: - The node ID retrieved from the queue - - Raises: - queue.Empty: If timeout expires and no item is available - """ - ... - - def task_done(self) -> None: - """ - Indicate that a previously retrieved task is complete. - - Used by worker threads to signal task completion for - join() synchronization. - """ - ... - - def empty(self) -> bool: - """ - Check if the queue is empty. - - Returns: - True if the queue has no items, False otherwise - """ - ... - - def qsize(self) -> int: - """ - Get the approximate size of the queue. - - Returns: - The approximate number of items in the queue - """ - ... - - def dumps(self) -> str: - """ - Serialize the queue state to a JSON string for storage. - - Returns: - A JSON string containing the serialized queue state - that can be persisted and later restored - """ - ... - - def loads(self, data: str) -> None: - """ - Restore the queue state from a JSON string. - - Args: - data: The JSON string containing the serialized queue state to restore - """ - ... diff --git a/api/graphon/graph_engine/response_coordinator/__init__.py b/api/graphon/graph_engine/response_coordinator/__init__.py deleted file mode 100644 index e11d31199c..0000000000 --- a/api/graphon/graph_engine/response_coordinator/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -ResponseStreamCoordinator - Coordinates streaming output from response nodes - -This component manages response streaming sessions and ensures ordered streaming -of responses based on upstream node outputs and constants. -""" - -from .coordinator import ResponseStreamCoordinator - -__all__ = ["ResponseStreamCoordinator"] diff --git a/api/graphon/graph_engine/response_coordinator/coordinator.py b/api/graphon/graph_engine/response_coordinator/coordinator.py deleted file mode 100644 index a6562f0223..0000000000 --- a/api/graphon/graph_engine/response_coordinator/coordinator.py +++ /dev/null @@ -1,697 +0,0 @@ -""" -Main ResponseStreamCoordinator implementation. - -This module contains the public ResponseStreamCoordinator class that manages -response streaming sessions and ensures ordered streaming of responses. -""" - -import logging -from collections import deque -from collections.abc import Sequence -from threading import RLock -from typing import Literal, TypeAlias, final -from uuid import uuid4 - -from pydantic import BaseModel, Field - -from graphon.enums import NodeExecutionType, NodeState -from graphon.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent -from graphon.nodes.base.template import TextSegment, VariableSegment -from graphon.runtime import VariablePool -from graphon.runtime.graph_runtime_state import GraphProtocol - -from .path import Path -from .session import ResponseSession - -logger = logging.getLogger(__name__) - -# Type definitions -NodeID: TypeAlias = str -EdgeID: TypeAlias = str - - -class ResponseSessionState(BaseModel): - """Serializable representation of a response session.""" - - node_id: str - index: int = Field(default=0, ge=0) - - -class StreamBufferState(BaseModel): - """Serializable representation of buffered stream chunks.""" - - selector: tuple[str, ...] - events: list[NodeRunStreamChunkEvent] = Field(default_factory=list) - - -class StreamPositionState(BaseModel): - """Serializable representation for stream read positions.""" - - selector: tuple[str, ...] - position: int = Field(default=0, ge=0) - - -class ResponseStreamCoordinatorState(BaseModel): - """Serialized snapshot of ResponseStreamCoordinator.""" - - type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator") - version: str = Field(default="1.0") - response_nodes: Sequence[str] = Field(default_factory=list) - active_session: ResponseSessionState | None = None - waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list) - pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list) - node_execution_ids: dict[str, str] = Field(default_factory=dict) - paths_map: dict[str, list[list[str]]] = Field(default_factory=dict) - stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list) - stream_positions: Sequence[StreamPositionState] = Field(default_factory=list) - closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list) - - -@final -class ResponseStreamCoordinator: - """ - Manages response streaming sessions without relying on global state. - - Ensures ordered streaming of responses based on upstream node outputs and constants. - """ - - def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None: - """ - Initialize coordinator with variable pool. - - Args: - variable_pool: VariablePool instance for accessing node variables - graph: Graph instance for looking up node information - """ - self._variable_pool = variable_pool - self._graph = graph - self._active_session: ResponseSession | None = None - self._waiting_sessions: deque[ResponseSession] = deque() - self._lock = RLock() - - # Internal stream management (replacing OutputRegistry) - self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {} - self._stream_positions: dict[tuple[str, ...], int] = {} - self._closed_streams: set[tuple[str, ...]] = set() - - # Track response nodes - self._response_nodes: set[NodeID] = set() - - # Store paths for each response node - self._paths_maps: dict[NodeID, list[Path]] = {} - - # Track node execution IDs and types for proper event forwarding - self._node_execution_ids: dict[NodeID, str] = {} # node_id -> execution_id - - # Track response sessions to ensure only one per node - self._response_sessions: dict[NodeID, ResponseSession] = {} # node_id -> session - - def register(self, response_node_id: NodeID) -> None: - with self._lock: - if response_node_id in self._response_nodes: - return - self._response_nodes.add(response_node_id) - - # Build and save paths map for this response node - paths_map = self._build_paths_map(response_node_id) - self._paths_maps[response_node_id] = paths_map - - # Create and store response session for this node - response_node = self._graph.nodes[response_node_id] - session = ResponseSession.from_node(response_node) - self._response_sessions[response_node_id] = session - - def track_node_execution(self, node_id: NodeID, execution_id: str) -> None: - """Track the execution ID for a node when it starts executing. - - Args: - node_id: The ID of the node - execution_id: The execution ID from NodeRunStartedEvent - """ - with self._lock: - self._node_execution_ids[node_id] = execution_id - - def _get_or_create_execution_id(self, node_id: NodeID) -> str: - """Get the execution ID for a node, creating one if it doesn't exist. - - Args: - node_id: The ID of the node - - Returns: - The execution ID for the node - """ - with self._lock: - if node_id not in self._node_execution_ids: - self._node_execution_ids[node_id] = str(uuid4()) - return self._node_execution_ids[node_id] - - def _build_paths_map(self, response_node_id: NodeID) -> list[Path]: - """ - Build a paths map for a response node by finding all paths from root node - to the response node, recording branch edges along each path. - - Args: - response_node_id: ID of the response node to analyze - - Returns: - List of Path objects, where each path contains branch edge IDs - """ - # Get root node ID - root_node_id = self._graph.root_node.id - - # If root is the response node, return empty path - if root_node_id == response_node_id: - return [Path()] - - # Extract variable selectors from the response node's template - response_node = self._graph.nodes[response_node_id] - response_session = ResponseSession.from_node(response_node) - template = response_session.template - - # Collect all variable selectors from the template - variable_selectors: set[tuple[str, ...]] = set() - for segment in template.segments: - if isinstance(segment, VariableSegment): - variable_selectors.add(tuple(segment.selector[:2])) - - # Step 1: Find all complete paths from root to response node - all_complete_paths: list[list[EdgeID]] = [] - - def find_paths( - current_node_id: NodeID, target_node_id: NodeID, current_path: list[EdgeID], visited: set[NodeID] - ) -> None: - """Recursively find all paths from current node to target node.""" - if current_node_id == target_node_id: - # Found a complete path, store it - all_complete_paths.append(current_path.copy()) - return - - # Mark as visited to avoid cycles - visited.add(current_node_id) - - # Explore outgoing edges - outgoing_edges = self._graph.get_outgoing_edges(current_node_id) - for edge in outgoing_edges: - edge_id = edge.id - next_node_id = edge.head - - # Skip if already visited in this path - if next_node_id not in visited: - # Add edge to path and recurse - new_path = current_path + [edge_id] - find_paths(next_node_id, target_node_id, new_path, visited.copy()) - - # Start searching from root node - find_paths(root_node_id, response_node_id, [], set()) - - # Step 2: For each complete path, filter edges based on node blocking behavior - filtered_paths: list[Path] = [] - for path in all_complete_paths: - blocking_edges: list[str] = [] - for edge_id in path: - edge = self._graph.edges[edge_id] - source_node = self._graph.nodes[edge.tail] - - # Check if node is a branch, container, or response node - if source_node.execution_type in { - NodeExecutionType.BRANCH, - NodeExecutionType.CONTAINER, - NodeExecutionType.RESPONSE, - } or source_node.blocks_variable_output(variable_selectors): - blocking_edges.append(edge_id) - - # Keep the path even if it's empty - filtered_paths.append(Path(edges=blocking_edges)) - - return filtered_paths - - def on_edge_taken(self, edge_id: str) -> Sequence[NodeRunStreamChunkEvent]: - """ - Handle when an edge is taken (selected by a branch node). - - This method updates the paths for all response nodes by removing - the taken edge. If any response node has an empty path after removal, - it means the node is now deterministically reachable and should start. - - Args: - edge_id: The ID of the edge that was taken - - Returns: - List of events to emit from starting new sessions - """ - events: list[NodeRunStreamChunkEvent] = [] - - with self._lock: - # Check each response node in order - for response_node_id in self._response_nodes: - if response_node_id not in self._paths_maps: - continue - - paths = self._paths_maps[response_node_id] - has_reachable_path = False - - # Update each path by removing the taken edge - for path in paths: - # Remove the taken edge from this path - path.remove_edge(edge_id) - - # Check if this path is now empty (node is reachable) - if path.is_empty(): - has_reachable_path = True - - # If node is now reachable (has empty path), start/queue session - if has_reachable_path: - # Pass the node_id to the activation method - # The method will handle checking and removing from map - events.extend(self._active_or_queue_session(response_node_id)) - return events - - def _active_or_queue_session(self, node_id: str) -> Sequence[NodeRunStreamChunkEvent]: - """ - Start a session immediately if no active session, otherwise queue it. - Only activates sessions that exist in the _response_sessions map. - - Args: - node_id: The ID of the response node to activate - - Returns: - List of events from flush attempt if session started immediately - """ - events: list[NodeRunStreamChunkEvent] = [] - - # Get the session from our map (only activate if it exists) - session = self._response_sessions.get(node_id) - if not session: - return events - - # Remove from map to ensure it won't be activated again - del self._response_sessions[node_id] - - if self._active_session is None: - self._active_session = session - - # Try to flush immediately - events.extend(self.try_flush()) - else: - # Queue the session if another is active - self._waiting_sessions.append(session) - - return events - - def intercept_event( - self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent - ) -> Sequence[NodeRunStreamChunkEvent]: - with self._lock: - if isinstance(event, NodeRunStreamChunkEvent): - self._append_stream_chunk(event.selector, event) - if event.is_final: - self._close_stream(event.selector) - return self.try_flush() - else: - # Skip cause we share the same variable pool. - # - # for variable_name, variable_value in event.node_run_result.outputs.items(): - # self._variable_pool.add((event.node_id, variable_name), variable_value) - return self.try_flush() - - def _create_stream_chunk_event( - self, - node_id: str, - execution_id: str, - selector: Sequence[str], - chunk: str, - is_final: bool = False, - ) -> NodeRunStreamChunkEvent: - """Create a stream chunk event with consistent structure. - - For selectors with special prefixes (sys, env, conversation), we use the - active response node's information since these are not actual node IDs. - """ - # Check if this is a special selector that doesn't correspond to a node - if selector and selector[0] not in self._graph.nodes and self._active_session: - # Use the active response node for special selectors - response_node = self._graph.nodes[self._active_session.node_id] - return NodeRunStreamChunkEvent( - id=execution_id, - node_id=response_node.id, - node_type=response_node.node_type, - selector=selector, - chunk=chunk, - is_final=is_final, - ) - - # Standard case: selector refers to an actual node - node = self._graph.nodes[node_id] - return NodeRunStreamChunkEvent( - id=execution_id, - node_id=node.id, - node_type=node.node_type, - selector=selector, - chunk=chunk, - is_final=is_final, - ) - - def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]: - """Process a variable segment. Returns (events, is_complete). - - Handles both regular node selectors and special system selectors (sys, env, conversation). - For special selectors, we attribute the output to the active response node. - """ - events: list[NodeRunStreamChunkEvent] = [] - source_selector_prefix = segment.selector[0] if segment.selector else "" - is_complete = False - - # Determine which node to attribute the output to - # For special selectors (sys, env, conversation), use the active response node - # For regular selectors, use the source node - if self._active_session and source_selector_prefix not in self._graph.nodes: - # Special selector - use active response node - output_node_id = self._active_session.node_id - else: - # Regular node selector - output_node_id = source_selector_prefix - execution_id = self._get_or_create_execution_id(output_node_id) - - # Stream all available chunks - while self._has_unread_stream(segment.selector): - if event := self._pop_stream_chunk(segment.selector): - # For special selectors, we need to update the event to use - # the active response node's information - if self._active_session and source_selector_prefix not in self._graph.nodes: - response_node = self._graph.nodes[self._active_session.node_id] - # Create a new event with the response node's information - # but keep the original selector - updated_event = NodeRunStreamChunkEvent( - id=execution_id, - node_id=response_node.id, - node_type=response_node.node_type, - selector=event.selector, # Keep original selector - chunk=event.chunk, - is_final=event.is_final, - ) - events.append(updated_event) - else: - # Regular node selector - use event as is - events.append(event) - - # Check if this is the last chunk by looking ahead - stream_closed = self._is_stream_closed(segment.selector) - # Check if stream is closed to determine if segment is complete - if stream_closed: - is_complete = True - - elif value := self._variable_pool.get(segment.selector): - # Process scalar value - is_last_segment = bool( - self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1 - ) - events.append( - self._create_stream_chunk_event( - node_id=output_node_id, - execution_id=execution_id, - selector=segment.selector, - chunk=value.markdown, - is_final=is_last_segment, - ) - ) - is_complete = True - - return events, is_complete - - def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]: - """Process a text segment. Returns (events, is_complete).""" - assert self._active_session is not None - current_response_node = self._graph.nodes[self._active_session.node_id] - - # Use get_or_create_execution_id to ensure we have a consistent ID - execution_id = self._get_or_create_execution_id(current_response_node.id) - - is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1 - event = self._create_stream_chunk_event( - node_id=current_response_node.id, - execution_id=execution_id, - selector=[current_response_node.id, "answer"], # FIXME(-LAN-) - chunk=segment.text, - is_final=is_last_segment, - ) - return [event] - - def try_flush(self) -> list[NodeRunStreamChunkEvent]: - with self._lock: - if not self._active_session: - return [] - - template = self._active_session.template - response_node_id = self._active_session.node_id - - events: list[NodeRunStreamChunkEvent] = [] - - # Process segments sequentially from current index - while self._active_session.index < len(template.segments): - segment = template.segments[self._active_session.index] - - if isinstance(segment, VariableSegment): - # Check if the source node for this variable is skipped - # Only check for actual nodes, not special selectors (sys, env, conversation) - source_selector_prefix = segment.selector[0] if segment.selector else "" - if source_selector_prefix in self._graph.nodes: - source_node = self._graph.nodes[source_selector_prefix] - - if source_node.state == NodeState.SKIPPED: - # Skip this variable segment if the source node is skipped - self._active_session.index += 1 - continue - - segment_events, is_complete = self._process_variable_segment(segment) - events.extend(segment_events) - - # Only advance index if this variable segment is complete - if is_complete: - self._active_session.index += 1 - else: - # Wait for more data - break - - else: - segment_events = self._process_text_segment(segment) - events.extend(segment_events) - self._active_session.index += 1 - - if self._active_session.is_complete(): - # End current session and get events from starting next session - next_session_events = self.end_session(response_node_id) - events.extend(next_session_events) - - return events - - def end_session(self, node_id: str) -> list[NodeRunStreamChunkEvent]: - """ - End the active session for a response node. - Automatically starts the next waiting session if available. - - Args: - node_id: ID of the response node ending its session - - Returns: - List of events from starting the next session - """ - with self._lock: - events: list[NodeRunStreamChunkEvent] = [] - - if self._active_session and self._active_session.node_id == node_id: - self._active_session = None - - # Try to start next waiting session - if self._waiting_sessions: - next_session = self._waiting_sessions.popleft() - self._active_session = next_session - - # Immediately try to flush any available segments - events = self.try_flush() - - return events - - # ============= Internal Stream Management Methods ============= - - def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None: - """ - Append a stream chunk to the internal buffer. - - Args: - selector: List of strings identifying the stream location - event: The NodeRunStreamChunkEvent to append - - Raises: - ValueError: If the stream is already closed - """ - key = tuple(selector) - - if key in self._closed_streams: - raise ValueError(f"Stream {'.'.join(selector)} is already closed") - - if key not in self._stream_buffers: - self._stream_buffers[key] = [] - self._stream_positions[key] = 0 - - self._stream_buffers[key].append(event) - - def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None: - """ - Pop the next unread stream chunk from the buffer. - - Args: - selector: List of strings identifying the stream location - - Returns: - The next event, or None if no unread events available - """ - key = tuple(selector) - - if key not in self._stream_buffers: - return None - - position = self._stream_positions.get(key, 0) - buffer = self._stream_buffers[key] - - if position >= len(buffer): - return None - - event = buffer[position] - self._stream_positions[key] = position + 1 - return event - - def _has_unread_stream(self, selector: Sequence[str]) -> bool: - """ - Check if the stream has unread events. - - Args: - selector: List of strings identifying the stream location - - Returns: - True if there are unread events, False otherwise - """ - key = tuple(selector) - - if key not in self._stream_buffers: - return False - - position = self._stream_positions.get(key, 0) - return position < len(self._stream_buffers[key]) - - def _close_stream(self, selector: Sequence[str]) -> None: - """ - Mark a stream as closed (no more chunks can be appended). - - Args: - selector: List of strings identifying the stream location - """ - key = tuple(selector) - self._closed_streams.add(key) - - def _is_stream_closed(self, selector: Sequence[str]) -> bool: - """ - Check if a stream is closed. - - Args: - selector: List of strings identifying the stream location - - Returns: - True if the stream is closed, False otherwise - """ - key = tuple(selector) - return key in self._closed_streams - - def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None: - """Convert an in-memory session into its serializable form.""" - - if session is None: - return None - return ResponseSessionState(node_id=session.node_id, index=session.index) - - def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession: - """Rebuild a response session from serialized data.""" - - node = self._graph.nodes.get(session_state.node_id) - if node is None: - raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state") - - session = ResponseSession.from_node(node) - session.index = session_state.index - return session - - def dumps(self) -> str: - """Serialize coordinator state to JSON.""" - - with self._lock: - state = ResponseStreamCoordinatorState( - response_nodes=sorted(self._response_nodes), - active_session=self._serialize_session(self._active_session), - waiting_sessions=[ - session_state - for session in list(self._waiting_sessions) - if (session_state := self._serialize_session(session)) is not None - ], - pending_sessions=[ - session_state - for _, session in sorted(self._response_sessions.items()) - if (session_state := self._serialize_session(session)) is not None - ], - node_execution_ids=dict(sorted(self._node_execution_ids.items())), - paths_map={ - node_id: [path.edges.copy() for path in paths] - for node_id, paths in sorted(self._paths_maps.items()) - }, - stream_buffers=[ - StreamBufferState( - selector=selector, - events=[event.model_copy(deep=True) for event in events], - ) - for selector, events in sorted(self._stream_buffers.items()) - ], - stream_positions=[ - StreamPositionState(selector=selector, position=position) - for selector, position in sorted(self._stream_positions.items()) - ], - closed_streams=sorted(self._closed_streams), - ) - return state.model_dump_json() - - def loads(self, data: str) -> None: - """Restore coordinator state from JSON.""" - - state = ResponseStreamCoordinatorState.model_validate_json(data) - - if state.type != "ResponseStreamCoordinator": - raise ValueError(f"Invalid serialized data type: {state.type}") - - if state.version != "1.0": - raise ValueError(f"Unsupported serialized version: {state.version}") - - with self._lock: - self._response_nodes = set(state.response_nodes) - self._paths_maps = { - node_id: [Path(edges=list(path_edges)) for path_edges in paths] - for node_id, paths in state.paths_map.items() - } - self._node_execution_ids = dict(state.node_execution_ids) - - self._stream_buffers = { - tuple(buffer.selector): [event.model_copy(deep=True) for event in buffer.events] - for buffer in state.stream_buffers - } - self._stream_positions = { - tuple(position.selector): position.position for position in state.stream_positions - } - for selector in self._stream_buffers: - self._stream_positions.setdefault(selector, 0) - - self._closed_streams = {tuple(selector) for selector in state.closed_streams} - - self._waiting_sessions = deque( - self._session_from_state(session_state) for session_state in state.waiting_sessions - ) - self._response_sessions = { - session_state.node_id: self._session_from_state(session_state) - for session_state in state.pending_sessions - } - self._active_session = self._session_from_state(state.active_session) if state.active_session else None diff --git a/api/graphon/graph_engine/response_coordinator/path.py b/api/graphon/graph_engine/response_coordinator/path.py deleted file mode 100644 index 50f2f4eb21..0000000000 --- a/api/graphon/graph_engine/response_coordinator/path.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -Internal path representation for response coordinator. - -This module contains the private Path class used internally by ResponseStreamCoordinator -to track execution paths to response nodes. -""" - -from dataclasses import dataclass, field -from typing import TypeAlias - -EdgeID: TypeAlias = str - - -@dataclass -class Path: - """ - Represents a path of branch edges that must be taken to reach a response node. - - Note: This is an internal class not exposed in the public API. - """ - - edges: list[EdgeID] = field(default_factory=list[EdgeID]) - - def contains_edge(self, edge_id: EdgeID) -> bool: - """Check if this path contains the given edge.""" - return edge_id in self.edges - - def remove_edge(self, edge_id: EdgeID) -> None: - """Remove the given edge from this path in place.""" - if self.contains_edge(edge_id): - self.edges.remove(edge_id) - - def is_empty(self) -> bool: - """Check if the path has no edges (node is reachable).""" - return len(self.edges) == 0 diff --git a/api/graphon/graph_engine/response_coordinator/session.py b/api/graphon/graph_engine/response_coordinator/session.py deleted file mode 100644 index cb877f1504..0000000000 --- a/api/graphon/graph_engine/response_coordinator/session.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Internal response session management for response coordinator. - -This module contains the private ResponseSession class used internally -by ResponseStreamCoordinator to manage streaming sessions. -""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Protocol, cast - -from graphon.nodes.base.template import Template -from graphon.runtime.graph_runtime_state import NodeProtocol - - -class _ResponseSessionNodeProtocol(NodeProtocol, Protocol): - """Structural contract required from nodes that can open a response session.""" - - def get_streaming_template(self) -> Template: ... - - -@dataclass -class ResponseSession: - """ - Represents an active response streaming session. - - Note: This is an internal class not exposed in the public API. - """ - - node_id: str - template: Template # Template object from the response node - index: int = 0 # Current position in the template segments - - @classmethod - def from_node(cls, node: NodeProtocol) -> ResponseSession: - """ - Create a ResponseSession from a response-capable node. - - The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer. - At runtime this must be a node that implements `get_streaming_template()`. The coordinator decides which - graph nodes should be treated as response-capable before they reach this factory. - - Args: - node: Node from the materialized workflow graph. - - Returns: - ResponseSession configured with the node's streaming template - - Raises: - TypeError: If node does not implement the response-session streaming contract. - """ - response_node = cast(_ResponseSessionNodeProtocol, node) - try: - template = response_node.get_streaming_template() - except AttributeError as exc: - raise TypeError("ResponseSession.from_node requires get_streaming_template() on response nodes") from exc - - return cls( - node_id=node.id, - template=template, - ) - - def is_complete(self) -> bool: - """Check if all segments in the template have been processed.""" - return self.index >= len(self.template.segments) diff --git a/api/graphon/graph_engine/worker.py b/api/graphon/graph_engine/worker.py deleted file mode 100644 index a0844ee48e..0000000000 --- a/api/graphon/graph_engine/worker.py +++ /dev/null @@ -1,204 +0,0 @@ -""" -Worker - Thread implementation for queue-based node execution - -Workers pull node IDs from the ready_queue, execute nodes, and push events -to the event_queue for the dispatcher to process. -""" - -import queue -import threading -import time -from collections.abc import Sequence -from contextlib import AbstractContextManager -from datetime import UTC, datetime -from typing import TYPE_CHECKING, final - -from typing_extensions import override - -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node - -from .ready_queue import ReadyQueue - -if TYPE_CHECKING: - pass - - -@final -class Worker(threading.Thread): - """ - Worker thread that executes nodes from the ready queue. - - Workers continuously pull node IDs from the ready_queue, execute the - corresponding nodes, and push the resulting events to the event_queue - for the dispatcher to process. - """ - - def __init__( - self, - ready_queue: ReadyQueue, - event_queue: queue.Queue[GraphNodeEventBase], - graph: Graph, - layers: Sequence[GraphEngineLayer], - worker_id: int = 0, - execution_context: AbstractContextManager[object] | None = None, - ) -> None: - """ - Initialize worker thread. - - Args: - ready_queue: Ready queue containing node IDs ready for execution - event_queue: Queue for pushing execution events - graph: Graph containing nodes to execute - layers: Graph engine layers for node execution hooks - worker_id: Unique identifier for this worker - execution_context: Optional execution context for context preservation - """ - super().__init__(name=f"GraphWorker-{worker_id}", daemon=True) - self._ready_queue = ready_queue - self._event_queue = event_queue - self._graph = graph - self._worker_id = worker_id - self._execution_context = execution_context - self._stop_event = threading.Event() - self._layers = layers if layers is not None else [] - self._last_task_time = time.time() - self._current_node_started_at: datetime | None = None - - def stop(self) -> None: - """Signal the worker to stop processing.""" - self._stop_event.set() - - @property - def is_idle(self) -> bool: - """Check if the worker is currently idle.""" - # Worker is idle if it hasn't processed a task recently (within 0.2 seconds) - return (time.time() - self._last_task_time) > 0.2 - - @property - def idle_duration(self) -> float: - """Get the duration in seconds since the worker last processed a task.""" - return time.time() - self._last_task_time - - @property - def worker_id(self) -> int: - """Get the worker's ID.""" - return self._worker_id - - @override - def run(self) -> None: - """ - Main worker loop. - - Continuously pulls node IDs from ready_queue, executes them, - and pushes events to event_queue until stopped. - """ - while not self._stop_event.is_set(): - # Try to get a node ID from the ready queue (with timeout) - try: - node_id = self._ready_queue.get(timeout=0.1) - except queue.Empty: - continue - - self._last_task_time = time.time() - node = self._graph.nodes[node_id] - try: - self._current_node_started_at = None - self._execute_node(node) - self._ready_queue.task_done() - except Exception as e: - self._event_queue.put( - self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at) - ) - finally: - self._current_node_started_at = None - - def _execute_node(self, node: Node) -> None: - """ - Execute a single node and handle its events. - - Args: - node: The node instance to execute - """ - node.ensure_execution_id() - - error: Exception | None = None - result_event: GraphNodeEventBase | None = None - - # Execute the node with preserved context if execution context is provided - if self._execution_context is not None: - with self._execution_context: - self._invoke_node_run_start_hooks(node) - try: - node_events = node.run() - for event in node_events: - if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id: - self._current_node_started_at = event.start_at - self._event_queue.put(event) - if is_node_result_event(event): - result_event = event - except Exception as exc: - error = exc - raise - finally: - self._invoke_node_run_end_hooks(node, error, result_event) - else: - self._invoke_node_run_start_hooks(node) - try: - node_events = node.run() - for event in node_events: - if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id: - self._current_node_started_at = event.start_at - self._event_queue.put(event) - if is_node_result_event(event): - result_event = event - except Exception as exc: - error = exc - raise - finally: - self._invoke_node_run_end_hooks(node, error, result_event) - - def _invoke_node_run_start_hooks(self, node: Node) -> None: - """Invoke on_node_run_start hooks for all layers.""" - for layer in self._layers: - try: - layer.on_node_run_start(node) - except Exception: - # Silently ignore layer errors to prevent disrupting node execution - continue - - def _invoke_node_run_end_hooks( - self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None - ) -> None: - """Invoke on_node_run_end hooks for all layers.""" - for layer in self._layers: - try: - layer.on_node_run_end(node, error, result_event) - except Exception: - # Silently ignore layer errors to prevent disrupting node execution - continue - - def _build_fallback_failure_event( - self, node: Node, error: Exception, *, started_at: datetime | None = None - ) -> NodeRunFailedEvent: - """Build a failed event when worker-level execution aborts before a node emits its own result event.""" - failure_time = datetime.now(UTC).replace(tzinfo=None) - error_message = str(error) - return NodeRunFailedEvent( - id=node.execution_id, - node_id=node.id, - node_type=node.node_type, - in_iteration_id=None, - error=error_message, - start_at=started_at or failure_time, - finished_at=failure_time, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=error_message, - error_type=type(error).__name__, - ), - ) diff --git a/api/graphon/graph_engine/worker_management/__init__.py b/api/graphon/graph_engine/worker_management/__init__.py deleted file mode 100644 index 03de1f6daa..0000000000 --- a/api/graphon/graph_engine/worker_management/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Worker management subsystem for graph engine. - -This package manages the worker pool, including creation, -scaling, and activity tracking. -""" - -from .worker_pool import WorkerPool - -__all__ = [ - "WorkerPool", -] diff --git a/api/graphon/graph_engine/worker_management/worker_pool.py b/api/graphon/graph_engine/worker_management/worker_pool.py deleted file mode 100644 index 85cdf1ca21..0000000000 --- a/api/graphon/graph_engine/worker_management/worker_pool.py +++ /dev/null @@ -1,277 +0,0 @@ -""" -Simple worker pool that consolidates functionality. - -This is a simpler implementation that merges WorkerPool, ActivityTracker, -DynamicScaler, and WorkerFactory into a single class. -""" - -import logging -import queue -import threading -from contextlib import AbstractContextManager -from typing import final - -from graphon.graph import Graph -from graphon.graph_events import GraphNodeEventBase - -from ..config import GraphEngineConfig -from ..layers.base import GraphEngineLayer -from ..ready_queue import ReadyQueue -from ..worker import Worker - -logger = logging.getLogger(__name__) - - -@final -class WorkerPool: - """ - Simple worker pool with integrated management. - - This class consolidates all worker management functionality into - a single, simpler implementation without excessive abstraction. - """ - - def __init__( - self, - ready_queue: ReadyQueue, - event_queue: queue.Queue[GraphNodeEventBase], - graph: Graph, - layers: list[GraphEngineLayer], - config: GraphEngineConfig, - execution_context: AbstractContextManager[object] | None = None, - ) -> None: - """ - Initialize the simple worker pool. - - Args: - ready_queue: Ready queue for nodes ready for execution - event_queue: Queue for worker events - graph: The workflow graph - layers: Graph engine layers for node execution hooks - config: GraphEngine worker pool configuration - execution_context: Optional execution context for context preservation - """ - self._ready_queue = ready_queue - self._event_queue = event_queue - self._graph = graph - self._execution_context = execution_context - self._layers = layers - self._config = config - - # Worker management - self._workers: list[Worker] = [] - self._worker_counter = 0 - self._lock = threading.RLock() - self._running = False - - # No longer tracking worker states with callbacks to avoid lock contention - - def start(self, initial_count: int | None = None) -> None: - """ - Start the worker pool. - - Args: - initial_count: Number of workers to start with (auto-calculated if None) - """ - with self._lock: - if self._running: - return - - self._running = True - - # Calculate initial worker count - if initial_count is None: - node_count = len(self._graph.nodes) - if node_count < 10: - initial_count = self._config.min_workers - elif node_count < 50: - initial_count = min(self._config.min_workers + 1, self._config.max_workers) - else: - initial_count = min(self._config.min_workers + 2, self._config.max_workers) - - logger.debug( - "Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)", - initial_count, - node_count, - self._config.min_workers, - self._config.max_workers, - ) - - # Create initial workers - for _ in range(initial_count): - self._create_worker() - - def stop(self) -> None: - """Stop all workers in the pool.""" - with self._lock: - self._running = False - worker_count = len(self._workers) - - if worker_count > 0: - logger.debug("Stopping worker pool: %d workers", worker_count) - - # Stop all workers - for worker in self._workers: - worker.stop() - - # Wait for workers to finish - for worker in self._workers: - if worker.is_alive(): - worker.join(timeout=2.0) - - self._workers.clear() - - def _create_worker(self) -> None: - """Create and start a new worker.""" - worker_id = self._worker_counter - self._worker_counter += 1 - - worker = Worker( - ready_queue=self._ready_queue, - event_queue=self._event_queue, - graph=self._graph, - layers=self._layers, - worker_id=worker_id, - execution_context=self._execution_context, - ) - - worker.start() - self._workers.append(worker) - - def _remove_worker(self, worker: Worker, worker_id: int) -> None: - """Remove a specific worker from the pool.""" - # Stop the worker - worker.stop() - - # Wait for it to finish - if worker.is_alive(): - worker.join(timeout=2.0) - - # Remove from list - if worker in self._workers: - self._workers.remove(worker) - - def _try_scale_up(self, queue_depth: int, current_count: int) -> bool: - """ - Try to scale up workers if needed. - - Args: - queue_depth: Current queue depth - current_count: Current number of workers - - Returns: - True if scaled up, False otherwise - """ - if queue_depth > self._config.scale_up_threshold and current_count < self._config.max_workers: - old_count = current_count - self._create_worker() - - logger.debug( - "Scaled up workers: %d -> %d (queue_depth=%d exceeded threshold=%d)", - old_count, - len(self._workers), - queue_depth, - self._config.scale_up_threshold, - ) - return True - return False - - def _try_scale_down(self, queue_depth: int, current_count: int, active_count: int, idle_count: int) -> bool: - """ - Try to scale down workers if we have excess capacity. - - Args: - queue_depth: Current queue depth - current_count: Current number of workers - active_count: Number of active workers - idle_count: Number of idle workers - - Returns: - True if scaled down, False otherwise - """ - # Skip if we're at minimum or have no idle workers - if current_count <= self._config.min_workers or idle_count == 0: - return False - - # Check if we have excess capacity - has_excess_capacity = ( - queue_depth <= active_count # Active workers can handle current queue - or idle_count > active_count # More idle than active workers - or (queue_depth == 0 and idle_count > 0) # No work and have idle workers - ) - - if not has_excess_capacity: - return False - - # Find and remove idle workers that have been idle long enough - workers_to_remove: list[tuple[Worker, int]] = [] - - for worker in self._workers: - # Check if worker is idle and has exceeded idle time threshold - if worker.is_idle and worker.idle_duration >= self._config.scale_down_idle_time: - # Don't remove if it would leave us unable to handle the queue - remaining_workers = current_count - len(workers_to_remove) - 1 - if remaining_workers >= self._config.min_workers and remaining_workers >= max(1, queue_depth // 2): - workers_to_remove.append((worker, worker.worker_id)) - # Only remove one worker per check to avoid aggressive scaling - break - - # Remove idle workers if any found - if workers_to_remove: - old_count = current_count - for worker, worker_id in workers_to_remove: - self._remove_worker(worker, worker_id) - - logger.debug( - "Scaled down workers: %d -> %d (removed %d idle workers after %.1fs, " - "queue_depth=%d, active=%d, idle=%d)", - old_count, - len(self._workers), - len(workers_to_remove), - self._config.scale_down_idle_time, - queue_depth, - active_count, - idle_count - len(workers_to_remove), - ) - return True - - return False - - def check_and_scale(self) -> None: - """Check and perform scaling if needed.""" - with self._lock: - if not self._running: - return - - current_count = len(self._workers) - queue_depth = self._ready_queue.qsize() - - # Count active vs idle workers by querying their state directly - idle_count = sum(1 for worker in self._workers if worker.is_idle) - active_count = current_count - idle_count - - # Try to scale up if queue is backing up - self._try_scale_up(queue_depth, current_count) - - # Try to scale down if we have excess capacity - self._try_scale_down(queue_depth, current_count, active_count, idle_count) - - def get_worker_count(self) -> int: - """Get current number of workers.""" - with self._lock: - return len(self._workers) - - def get_status(self) -> dict[str, int]: - """ - Get pool status information. - - Returns: - Dictionary with status information - """ - with self._lock: - return { - "total_workers": len(self._workers), - "queue_depth": self._ready_queue.qsize(), - "min_workers": self._config.min_workers, - "max_workers": self._config.max_workers, - } diff --git a/api/graphon/graph_events/__init__.py b/api/graphon/graph_events/__init__.py deleted file mode 100644 index 7cec587a05..0000000000 --- a/api/graphon/graph_events/__init__.py +++ /dev/null @@ -1,84 +0,0 @@ -# Agent events -from .agent import NodeRunAgentLogEvent - -# Base events -from .base import ( - BaseGraphEvent, - GraphEngineEvent, - GraphNodeEventBase, -) - -# Graph events -from .graph import ( - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) - -# Iteration events -from .iteration import ( - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, -) - -# Loop events -from .loop import ( - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, -) - -# Node events -from .node import ( - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunPauseRequestedEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, - is_node_result_event, -) - -__all__ = [ - "BaseGraphEvent", - "GraphEngineEvent", - "GraphNodeEventBase", - "GraphRunAbortedEvent", - "GraphRunFailedEvent", - "GraphRunPartialSucceededEvent", - "GraphRunPausedEvent", - "GraphRunStartedEvent", - "GraphRunSucceededEvent", - "NodeRunAgentLogEvent", - "NodeRunExceptionEvent", - "NodeRunFailedEvent", - "NodeRunHumanInputFormFilledEvent", - "NodeRunHumanInputFormTimeoutEvent", - "NodeRunIterationFailedEvent", - "NodeRunIterationNextEvent", - "NodeRunIterationStartedEvent", - "NodeRunIterationSucceededEvent", - "NodeRunLoopFailedEvent", - "NodeRunLoopNextEvent", - "NodeRunLoopStartedEvent", - "NodeRunLoopSucceededEvent", - "NodeRunPauseRequestedEvent", - "NodeRunRetrieverResourceEvent", - "NodeRunRetryEvent", - "NodeRunStartedEvent", - "NodeRunStreamChunkEvent", - "NodeRunSucceededEvent", - "NodeRunVariableUpdatedEvent", - "is_node_result_event", -] diff --git a/api/graphon/graph_events/agent.py b/api/graphon/graph_events/agent.py deleted file mode 100644 index 759fe3a71c..0000000000 --- a/api/graphon/graph_events/agent.py +++ /dev/null @@ -1,17 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import Field - -from .base import GraphAgentNodeEventBase - - -class NodeRunAgentLogEvent(GraphAgentNodeEventBase): - message_id: str = Field(..., description="message id") - label: str = Field(..., description="label") - node_execution_id: str = Field(..., description="node execution id") - parent_id: str | None = Field(..., description="parent id") - error: str | None = Field(..., description="error") - status: str = Field(..., description="status") - data: Mapping[str, Any] = Field(..., description="data") - metadata: Mapping[str, object] = Field(default_factory=dict) diff --git a/api/graphon/graph_events/base.py b/api/graphon/graph_events/base.py deleted file mode 100644 index 4ea9787b9a..0000000000 --- a/api/graphon/graph_events/base.py +++ /dev/null @@ -1,31 +0,0 @@ -from pydantic import BaseModel, Field - -from graphon.enums import NodeType -from graphon.node_events import NodeRunResult - - -class GraphEngineEvent(BaseModel): - pass - - -class BaseGraphEvent(GraphEngineEvent): - pass - - -class GraphNodeEventBase(GraphEngineEvent): - id: str = Field(..., description="node execution id") - node_id: str - node_type: NodeType - - in_iteration_id: str | None = None - """iteration id if node is in iteration""" - in_loop_id: str | None = None - """loop id if node is in loop""" - - # The version of the node, or "1" if not specified. - node_version: str = "1" - node_run_result: NodeRunResult = Field(default_factory=NodeRunResult) - - -class GraphAgentNodeEventBase(GraphNodeEventBase): - pass diff --git a/api/graphon/graph_events/graph.py b/api/graphon/graph_events/graph.py deleted file mode 100644 index 3782cb49bc..0000000000 --- a/api/graphon/graph_events/graph.py +++ /dev/null @@ -1,57 +0,0 @@ -from pydantic import Field - -from graphon.entities.pause_reason import PauseReason -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.graph_events import BaseGraphEvent - - -class GraphRunStartedEvent(BaseGraphEvent): - # Reason is emitted for workflow start events and is always set. - reason: WorkflowStartReason = Field( - default=WorkflowStartReason.INITIAL, - description="reason for workflow start", - ) - - -class GraphRunSucceededEvent(BaseGraphEvent): - """Event emitted when a run completes successfully with final outputs.""" - - outputs: dict[str, object] = Field( - default_factory=dict, - description="Final workflow outputs keyed by output selector.", - ) - - -class GraphRunFailedEvent(BaseGraphEvent): - error: str = Field(..., description="failed reason") - exceptions_count: int = Field(description="exception count", default=0) - - -class GraphRunPartialSucceededEvent(BaseGraphEvent): - """Event emitted when a run finishes with partial success and failures.""" - - exceptions_count: int = Field(..., description="exception count") - outputs: dict[str, object] = Field( - default_factory=dict, - description="Outputs that were materialised before failures occurred.", - ) - - -class GraphRunAbortedEvent(BaseGraphEvent): - """Event emitted when a graph run is aborted by user command.""" - - reason: str | None = Field(default=None, description="reason for abort") - outputs: dict[str, object] = Field( - default_factory=dict, - description="Outputs produced before the abort was requested.", - ) - - -class GraphRunPausedEvent(BaseGraphEvent): - """Event emitted when a graph run is paused by user command.""" - - reasons: list[PauseReason] = Field(description="reason for pause", default_factory=list) - outputs: dict[str, object] = Field( - default_factory=dict, - description="Outputs available to the client while the run is paused.", - ) diff --git a/api/graphon/graph_events/human_input.py b/api/graphon/graph_events/human_input.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/graph_events/iteration.py b/api/graphon/graph_events/iteration.py deleted file mode 100644 index 28627395fd..0000000000 --- a/api/graphon/graph_events/iteration.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import GraphNodeEventBase - - -class NodeRunIterationStartedEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class NodeRunIterationNextEvent(GraphNodeEventBase): - node_title: str - index: int = Field(..., description="index") - pre_iteration_output: Any = None - - -class NodeRunIterationSucceededEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class NodeRunIterationFailedEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/graphon/graph_events/loop.py b/api/graphon/graph_events/loop.py deleted file mode 100644 index 7cdc5427e2..0000000000 --- a/api/graphon/graph_events/loop.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import GraphNodeEventBase - - -class NodeRunLoopStartedEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class NodeRunLoopNextEvent(GraphNodeEventBase): - node_title: str - index: int = Field(..., description="index") - pre_loop_output: Any = None - - -class NodeRunLoopSucceededEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class NodeRunLoopFailedEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/graphon/graph_events/node.py b/api/graphon/graph_events/node.py deleted file mode 100644 index 471ae08ee7..0000000000 --- a/api/graphon/graph_events/node.py +++ /dev/null @@ -1,106 +0,0 @@ -from collections.abc import Mapping, Sequence -from datetime import datetime -from typing import Any - -from pydantic import Field - -from graphon.entities.pause_reason import PauseReason -from graphon.variables.variables import Variable - -from .base import GraphNodeEventBase - - -class NodeRunStartedEvent(GraphNodeEventBase): - node_title: str - predecessor_node_id: str | None = None - start_at: datetime = Field(..., description="node start time") - extras: dict[str, object] = Field(default_factory=dict) - - # FIXME(-LAN-): only for ToolNode - provider_type: str = "" - provider_id: str = "" - - -class NodeRunStreamChunkEvent(GraphNodeEventBase): - # Spec-compliant fields - selector: Sequence[str] = Field( - ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])" - ) - chunk: str = Field(..., description="the actual chunk content") - is_final: bool = Field(default=False, description="indicates if this is the last chunk") - - -class NodeRunRetrieverResourceEvent(GraphNodeEventBase): - retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources") - context: str = Field(..., description="context") - - -class NodeRunSucceededEvent(GraphNodeEventBase): - start_at: datetime = Field(..., description="node start time") - finished_at: datetime | None = Field(default=None, description="node finish time") - - -class NodeRunVariableUpdatedEvent(GraphNodeEventBase): - """Request that the engine apply a variable update before downstream observers continue.""" - - variable: Variable = Field(..., description="Updated variable payload to apply.") - - -class NodeRunFailedEvent(GraphNodeEventBase): - error: str = Field(..., description="error") - start_at: datetime = Field(..., description="node start time") - finished_at: datetime | None = Field(default=None, description="node finish time") - - -class NodeRunExceptionEvent(GraphNodeEventBase): - error: str = Field(..., description="error") - start_at: datetime = Field(..., description="node start time") - finished_at: datetime | None = Field(default=None, description="node finish time") - - -class NodeRunRetryEvent(NodeRunStartedEvent): - error: str = Field(..., description="error") - retry_index: int = Field(..., description="which retry attempt is about to be performed") - - -class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase): - """Emitted when a HumanInput form is submitted and before the node finishes.""" - - node_title: str = Field(..., description="HumanInput node title") - rendered_content: str = Field(..., description="Markdown content rendered with user inputs.") - action_id: str = Field(..., description="User action identifier chosen in the form.") - action_text: str = Field(..., description="Display text of the chosen action button.") - - -class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase): - """Emitted when a HumanInput form times out.""" - - node_title: str = Field(..., description="HumanInput node title") - expiration_time: datetime = Field(..., description="Form expiration time") - - -class NodeRunPauseRequestedEvent(GraphNodeEventBase): - reason: PauseReason = Field(..., description="pause reason") - - -def is_node_result_event(event: GraphNodeEventBase) -> bool: - """ - Check if an event is a final result event from node execution. - - A result event indicates the completion of a node execution and contains - runtime information such as inputs, outputs, or error details. - - Args: - event: The event to check - - Returns: - True if the event is a node result event (succeeded/failed/paused), False otherwise - """ - return isinstance( - event, - ( - NodeRunSucceededEvent, - NodeRunFailedEvent, - NodeRunPauseRequestedEvent, - ), - ) diff --git a/api/graphon/model_runtime/README.md b/api/graphon/model_runtime/README.md deleted file mode 100644 index b9d2c55210..0000000000 --- a/api/graphon/model_runtime/README.md +++ /dev/null @@ -1,51 +0,0 @@ -# Model Runtime - -This module provides the interface for invoking and authenticating various models, and offers Dify a unified information and credentials form rule for model providers. - -- On one hand, it decouples models from upstream and downstream processes, facilitating horizontal expansion for developers, -- On the other hand, it allows for direct display of providers and models in the frontend interface by simply defining them in the backend, eliminating the need to modify frontend logic. - -## Features - -- Supports capability invocation for 6 types of models - - - `LLM` - LLM text completion, dialogue, pre-computed tokens capability - - `Text Embedding Model` - Text Embedding, pre-computed tokens capability - - `Rerank Model` - Segment Rerank capability - - `Speech-to-text Model` - Speech to text capability - - `Text-to-speech Model` - Text to speech capability - - `Moderation` - Moderation capability - -- Model provider display - - Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. - -- Selectable model list display - - After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models. - - In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models. - -- Provider/model credential authentication - - The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. - -## Structure - -Model Runtime is divided into three layers: - -- The outermost layer is the factory method - - It provides methods for obtaining all providers, all model lists, getting provider instances, and authenticating provider/model credentials. - -- The second layer is the provider layer - - It provides the current provider's model list, model instance obtaining, provider credential authentication, and provider configuration rule information, **allowing horizontal expansion** to support different providers. - -- The bottom layer is the model layer - - It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types). - -## Documentation - -For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/). diff --git a/api/graphon/model_runtime/README_CN.md b/api/graphon/model_runtime/README_CN.md deleted file mode 100644 index 0a8b56b3fe..0000000000 --- a/api/graphon/model_runtime/README_CN.md +++ /dev/null @@ -1,64 +0,0 @@ -# Model Runtime - -该模块提供了各模型的调用、鉴权接口,并为 Dify 提供了统一的模型供应商的信息和凭据表单规则。 - -- 一方面将模型和上下游解耦,方便开发者对模型横向扩展, -- 另一方面提供了只需在后端定义供应商和模型,即可在前端页面直接展示,无需修改前端逻辑。 - -## 功能介绍 - -- 支持 6 种模型类型的能力调用 - - - `LLM` - LLM 文本补全、对话,预计算 tokens 能力 - - `Text Embedding Model` - 文本 Embedding,预计算 tokens 能力 - - `Rerank Model` - 分段 Rerank 能力 - - `Speech-to-text Model` - 语音转文本能力 - - `Text-to-speech Model` - 文本转语音能力 - - `Moderation` - Moderation 能力 - -- 模型供应商展示 - - 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。 - -- 可选择的模型列表展示 - - 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。 - - 除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。 - -- 供应商/模型凭据鉴权 - - 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。 - -## 结构 - -Model Runtime 分三层: - -- 最外层为工厂方法 - - 提供获取所有供应商、所有模型列表、获取供应商实例、供应商/模型凭据鉴权方法。 - -- 第二层为供应商层 - - 提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。 - - 对于供应商/模型凭据,有两种情况 - - - 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据 - - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 - - 当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。 - -- 最底层为模型层 - - 提供各种模型类型的直接调用、预定义模型配置信息、获取预定义/远程模型列表、模型凭据鉴权方法,不同模型额外提供了特殊方法,如 LLM 提供预计算 tokens 方法、获取费用信息方法等,**可横向扩展**同供应商下不同的模型(支持的模型类型下)。 - - 在这里我们需要先区分模型参数与模型凭据。 - - - 模型参数 (**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等,这些参数是由用户在前端页面上进行调整的,因此需要在后端定义参数的规则,以便前端页面进行展示和调整。在 DifyRuntime 中,他们的参数名一般为**model_parameters: dict[str, any]**。 - - - 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**,Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。 - -## 文档 - -有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。 diff --git a/api/graphon/model_runtime/__init__.py b/api/graphon/model_runtime/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/model_runtime/callbacks/__init__.py b/api/graphon/model_runtime/callbacks/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/model_runtime/callbacks/base_callback.py b/api/graphon/model_runtime/callbacks/base_callback.py deleted file mode 100644 index cd85cf6301..0000000000 --- a/api/graphon/model_runtime/callbacks/base_callback.py +++ /dev/null @@ -1,159 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Mapping, Sequence - -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - -_TEXT_COLOR_MAPPING = { - "blue": "36;1", - "yellow": "33;1", - "pink": "38;5;200", - "green": "32;1", - "red": "31;1", -} - - -class Callback(ABC): - """ - Base class for callbacks. - Only for LLM. - """ - - raise_error: bool = False - - @abstractmethod - def on_before_invoke( - self, - llm_instance: AIModel, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - invocation_context: Mapping[str, object] | None = None, - ): - """ - Before invoke callback - - :param llm_instance: LLM instance - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: optional end-user identifier for the invocation - :param invocation_context: opaque request metadata for the current invocation - """ - raise NotImplementedError() - - @abstractmethod - def on_new_chunk( - self, - llm_instance: AIModel, - chunk: LLMResultChunk, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - invocation_context: Mapping[str, object] | None = None, - ): - """ - On new chunk callback - - :param llm_instance: LLM instance - :param chunk: chunk - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: optional end-user identifier for the invocation - :param invocation_context: opaque request metadata for the current invocation - """ - raise NotImplementedError() - - @abstractmethod - def on_after_invoke( - self, - llm_instance: AIModel, - result: LLMResult, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - invocation_context: Mapping[str, object] | None = None, - ): - """ - After invoke callback - - :param llm_instance: LLM instance - :param result: result - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: optional end-user identifier for the invocation - :param invocation_context: opaque request metadata for the current invocation - """ - raise NotImplementedError() - - @abstractmethod - def on_invoke_error( - self, - llm_instance: AIModel, - ex: Exception, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - invocation_context: Mapping[str, object] | None = None, - ): - """ - Invoke error callback - - :param llm_instance: LLM instance - :param ex: exception - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: optional end-user identifier for the invocation - :param invocation_context: opaque request metadata for the current invocation - """ - raise NotImplementedError() - - def print_text(self, text: str, color: str | None = None, end: str = ""): - """Print text with highlighting and no end characters.""" - text_to_print = self._get_colored_text(text, color) if color else text - print(text_to_print, end=end) - - def _get_colored_text(self, text: str, color: str) -> str: - """Get colored text.""" - color_str = _TEXT_COLOR_MAPPING[color] - return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" diff --git a/api/graphon/model_runtime/callbacks/logging_callback.py b/api/graphon/model_runtime/callbacks/logging_callback.py deleted file mode 100644 index f96eb446fc..0000000000 --- a/api/graphon/model_runtime/callbacks/logging_callback.py +++ /dev/null @@ -1,180 +0,0 @@ -import json -import logging -import sys -from collections.abc import Mapping, Sequence -from typing import cast - -from graphon.model_runtime.callbacks.base_callback import Callback -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - -logger = logging.getLogger(__name__) - - -class LoggingCallback(Callback): - def on_before_invoke( - self, - llm_instance: AIModel, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - invocation_context: Mapping[str, object] | None = None, - ): - """ - Before invoke callback - - :param llm_instance: LLM instance - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: optional end-user identifier for the invocation - :param invocation_context: opaque request metadata for the current invocation - """ - self.print_text("\n[on_llm_before_invoke]\n", color="blue") - self.print_text(f"Model: {model}\n", color="blue") - self.print_text("Parameters:\n", color="blue") - for key, value in model_parameters.items(): - self.print_text(f"\t{key}: {value}\n", color="blue") - - if stop: - self.print_text(f"\tstop: {stop}\n", color="blue") - - if tools: - self.print_text("\tTools:\n", color="blue") - for tool in tools: - self.print_text(f"\t\t{tool.name}\n", color="blue") - - self.print_text(f"Stream: {stream}\n", color="blue") - if user: - self.print_text(f"User: {user}\n", color="blue") - - if invocation_context: - self.print_text(f"Invocation context: {dict(invocation_context)}\n", color="blue") - - self.print_text("Prompt messages:\n", color="blue") - for prompt_message in prompt_messages: - if prompt_message.name: - self.print_text(f"\tname: {prompt_message.name}\n", color="blue") - - self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue") - self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue") - - if stream: - self.print_text("\n[on_llm_new_chunk]") - - def on_new_chunk( - self, - llm_instance: AIModel, - chunk: LLMResultChunk, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - invocation_context: Mapping[str, object] | None = None, - ): - """ - On new chunk callback - - :param llm_instance: LLM instance - :param chunk: chunk - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - """ - _ = user, invocation_context - sys.stdout.write(cast(str, chunk.delta.message.content)) - sys.stdout.flush() - - def on_after_invoke( - self, - llm_instance: AIModel, - result: LLMResult, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - invocation_context: Mapping[str, object] | None = None, - ): - """ - After invoke callback - - :param llm_instance: LLM instance - :param result: result - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - """ - _ = user, invocation_context - self.print_text("\n[on_llm_after_invoke]\n", color="yellow") - self.print_text(f"Content: {result.message.content}\n", color="yellow") - - if result.message.tool_calls: - self.print_text("Tool calls:\n", color="yellow") - for tool_call in result.message.tool_calls: - self.print_text(f"\t{tool_call.id}\n", color="yellow") - self.print_text(f"\t{tool_call.function.name}\n", color="yellow") - self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow") - - self.print_text(f"Model: {result.model}\n", color="yellow") - self.print_text(f"Usage: {result.usage}\n", color="yellow") - self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow") - - def on_invoke_error( - self, - llm_instance: AIModel, - ex: Exception, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - invocation_context: Mapping[str, object] | None = None, - ): - """ - Invoke error callback - - :param llm_instance: LLM instance - :param ex: exception - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - """ - _ = user, invocation_context - self.print_text("\n[on_llm_invoke_error]\n", color="red") - logger.exception(ex) diff --git a/api/graphon/model_runtime/entities/__init__.py b/api/graphon/model_runtime/entities/__init__.py deleted file mode 100644 index a24e437d48..0000000000 --- a/api/graphon/model_runtime/entities/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -from .llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from .message_entities import ( - AssistantPromptMessage, - AudioPromptMessageContent, - DocumentPromptMessageContent, - ImagePromptMessageContent, - MultiModalPromptMessageContent, - PromptMessage, - PromptMessageContent, - PromptMessageContentType, - PromptMessageRole, - PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, - VideoPromptMessageContent, -) -from .model_entities import ModelPropertyKey - -__all__ = [ - "AssistantPromptMessage", - "AudioPromptMessageContent", - "DocumentPromptMessageContent", - "ImagePromptMessageContent", - "LLMMode", - "LLMResult", - "LLMResultChunk", - "LLMResultChunkDelta", - "LLMUsage", - "ModelPropertyKey", - "MultiModalPromptMessageContent", - "PromptMessage", - "PromptMessageContent", - "PromptMessageContentType", - "PromptMessageRole", - "PromptMessageTool", - "SystemPromptMessage", - "TextPromptMessageContent", - "ToolPromptMessage", - "UserPromptMessage", - "VideoPromptMessageContent", -] diff --git a/api/graphon/model_runtime/entities/common_entities.py b/api/graphon/model_runtime/entities/common_entities.py deleted file mode 100644 index b673efae22..0000000000 --- a/api/graphon/model_runtime/entities/common_entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from pydantic import BaseModel, model_validator - - -class I18nObject(BaseModel): - """ - Model class for i18n object. - """ - - zh_Hans: str | None = None - en_US: str - - @model_validator(mode="after") - def _(self): - if not self.zh_Hans: - self.zh_Hans = self.en_US - return self diff --git a/api/graphon/model_runtime/entities/defaults.py b/api/graphon/model_runtime/entities/defaults.py deleted file mode 100644 index bcce17c5d5..0000000000 --- a/api/graphon/model_runtime/entities/defaults.py +++ /dev/null @@ -1,130 +0,0 @@ -from graphon.model_runtime.entities.model_entities import DefaultParameterName - -PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { - DefaultParameterName.TEMPERATURE: { - "label": { - "en_US": "Temperature", - "zh_Hans": "温度", - }, - "type": "float", - "help": { - "en_US": "Controls randomness. Lower temperature results in less random completions." - " As the temperature approaches zero, the model will become deterministic and repetitive." - " Higher temperature results in more random completions.", - "zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。" - "较高的温度会导致更多的随机完成。", - }, - "required": False, - "default": 0.0, - "min": 0.0, - "max": 1.0, - "precision": 2, - }, - DefaultParameterName.TOP_P: { - "label": { - "en_US": "Top P", - "zh_Hans": "Top P", - }, - "type": "float", - "help": { - "en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options" - " are considered.", - "zh_Hans": "通过核心采样控制多样性:0.5 表示考虑了一半的所有可能性加权选项。", - }, - "required": False, - "default": 1.0, - "min": 0.0, - "max": 1.0, - "precision": 2, - }, - DefaultParameterName.TOP_K: { - "label": { - "en_US": "Top K", - "zh_Hans": "Top K", - }, - "type": "int", - "help": { - "en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.", - "zh_Hans": "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。", - }, - "required": False, - "default": 50, - "min": 1, - "max": 100, - "precision": 0, - }, - DefaultParameterName.PRESENCE_PENALTY: { - "label": { - "en_US": "Presence Penalty", - "zh_Hans": "存在惩罚", - }, - "type": "float", - "help": { - "en_US": "Applies a penalty to the log-probability of tokens already in the text.", - "zh_Hans": "对文本中已有的标记的对数概率施加惩罚。", - }, - "required": False, - "default": 0.0, - "min": 0.0, - "max": 1.0, - "precision": 2, - }, - DefaultParameterName.FREQUENCY_PENALTY: { - "label": { - "en_US": "Frequency Penalty", - "zh_Hans": "频率惩罚", - }, - "type": "float", - "help": { - "en_US": "Applies a penalty to the log-probability of tokens that appear in the text.", - "zh_Hans": "对文本中出现的标记的对数概率施加惩罚。", - }, - "required": False, - "default": 0.0, - "min": 0.0, - "max": 1.0, - "precision": 2, - }, - DefaultParameterName.MAX_TOKENS: { - "label": { - "en_US": "Max Tokens", - "zh_Hans": "最大 Token 数", - }, - "type": "int", - "help": { - "en_US": "Specifies the upper limit on the length of generated results." - " If the generated results are truncated, you can increase this parameter.", - "zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。", - }, - "required": False, - "default": 64, - "min": 1, - "max": 2048, - "precision": 0, - }, - DefaultParameterName.RESPONSE_FORMAT: { - "label": { - "en_US": "Response Format", - "zh_Hans": "回复格式", - }, - "type": "string", - "help": { - "en_US": "Set a response format, ensure the output from llm is a valid code block as possible," - " such as JSON, XML, etc.", - "zh_Hans": "设置一个返回格式,确保 llm 的输出尽可能是有效的代码块,如 JSON、XML 等", - }, - "required": False, - "options": ["JSON", "XML"], - }, - DefaultParameterName.JSON_SCHEMA: { - "label": { - "en_US": "JSON Schema", - }, - "type": "text", - "help": { - "en_US": "Set a response json schema will ensure LLM to adhere it.", - "zh_Hans": "设置返回的 json schema,llm 将按照它返回", - }, - "required": False, - }, -} diff --git a/api/graphon/model_runtime/entities/llm_entities.py b/api/graphon/model_runtime/entities/llm_entities.py deleted file mode 100644 index bfc80f21c5..0000000000 --- a/api/graphon/model_runtime/entities/llm_entities.py +++ /dev/null @@ -1,219 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from decimal import Decimal -from enum import StrEnum -from typing import Any, TypedDict, Union - -from pydantic import BaseModel, Field - -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage -from graphon.model_runtime.entities.model_entities import ModelUsage, PriceInfo - - -class LLMMode(StrEnum): - """ - Enum class for large language model mode. - """ - - COMPLETION = "completion" - CHAT = "chat" - - -class LLMUsageMetadata(TypedDict, total=False): - """ - TypedDict for LLM usage metadata. - All fields are optional. - """ - - prompt_tokens: int - completion_tokens: int - total_tokens: int - prompt_unit_price: Union[float, str] - completion_unit_price: Union[float, str] - total_price: Union[float, str] - currency: str - prompt_price_unit: Union[float, str] - completion_price_unit: Union[float, str] - prompt_price: Union[float, str] - completion_price: Union[float, str] - latency: float - time_to_first_token: float - time_to_generate: float - - -class LLMUsage(ModelUsage): - """ - Model class for llm usage. - """ - - prompt_tokens: int - prompt_unit_price: Decimal - prompt_price_unit: Decimal - prompt_price: Decimal - completion_tokens: int - completion_unit_price: Decimal - completion_price_unit: Decimal - completion_price: Decimal - total_tokens: int - total_price: Decimal - currency: str - latency: float - time_to_first_token: float | None = None - time_to_generate: float | None = None - - @classmethod - def empty_usage(cls): - return cls( - prompt_tokens=0, - prompt_unit_price=Decimal("0.0"), - prompt_price_unit=Decimal("0.0"), - prompt_price=Decimal("0.0"), - completion_tokens=0, - completion_unit_price=Decimal("0.0"), - completion_price_unit=Decimal("0.0"), - completion_price=Decimal("0.0"), - total_tokens=0, - total_price=Decimal("0.0"), - currency="USD", - latency=0.0, - time_to_first_token=None, - time_to_generate=None, - ) - - @classmethod - def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage: - """ - Create LLMUsage instance from metadata dictionary with default values. - - Args: - metadata: TypedDict containing usage metadata - - Returns: - LLMUsage instance with values from metadata or defaults - """ - prompt_tokens = metadata.get("prompt_tokens", 0) - completion_tokens = metadata.get("completion_tokens", 0) - total_tokens = metadata.get("total_tokens", 0) - - # If total_tokens is not provided but prompt and completion tokens are, - # calculate total_tokens - if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0): - total_tokens = prompt_tokens + completion_tokens - - return cls( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))), - completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))), - total_price=Decimal(str(metadata.get("total_price", 0))), - currency=metadata.get("currency", "USD"), - prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))), - completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))), - prompt_price=Decimal(str(metadata.get("prompt_price", 0))), - completion_price=Decimal(str(metadata.get("completion_price", 0))), - latency=metadata.get("latency", 0.0), - time_to_first_token=metadata.get("time_to_first_token"), - time_to_generate=metadata.get("time_to_generate"), - ) - - def plus(self, other: LLMUsage) -> LLMUsage: - """ - Add two LLMUsage instances together. - - :param other: Another LLMUsage instance to add - :return: A new LLMUsage instance with summed values - """ - if self.total_tokens == 0: - return other - else: - return LLMUsage( - prompt_tokens=self.prompt_tokens + other.prompt_tokens, - prompt_unit_price=other.prompt_unit_price, - prompt_price_unit=other.prompt_price_unit, - prompt_price=self.prompt_price + other.prompt_price, - completion_tokens=self.completion_tokens + other.completion_tokens, - completion_unit_price=other.completion_unit_price, - completion_price_unit=other.completion_price_unit, - completion_price=self.completion_price + other.completion_price, - total_tokens=self.total_tokens + other.total_tokens, - total_price=self.total_price + other.total_price, - currency=other.currency, - latency=self.latency + other.latency, - time_to_first_token=other.time_to_first_token, - time_to_generate=other.time_to_generate, - ) - - def __add__(self, other: LLMUsage) -> LLMUsage: - """ - Overload the + operator to add two LLMUsage instances. - - :param other: Another LLMUsage instance to add - :return: A new LLMUsage instance with summed values - """ - return self.plus(other) - - -class LLMResult(BaseModel): - """ - Model class for llm result. - """ - - id: str | None = None - model: str - prompt_messages: Sequence[PromptMessage] = Field(default_factory=list) - message: AssistantPromptMessage - usage: LLMUsage - system_fingerprint: str | None = None - reasoning_content: str | None = None - - -class LLMStructuredOutput(BaseModel): - """ - Model class for llm structured output. - """ - - structured_output: Mapping[str, Any] | None = None - - -class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput): - """ - Model class for llm result with structured output. - """ - - -class LLMResultChunkDelta(BaseModel): - """ - Model class for llm result chunk delta. - """ - - index: int - message: AssistantPromptMessage - usage: LLMUsage | None = None - finish_reason: str | None = None - - -class LLMResultChunk(BaseModel): - """ - Model class for llm result chunk. - """ - - model: str - prompt_messages: Sequence[PromptMessage] = Field(default_factory=list) - system_fingerprint: str | None = None - delta: LLMResultChunkDelta - - -class LLMResultChunkWithStructuredOutput(LLMResultChunk, LLMStructuredOutput): - """ - Model class for llm result chunk with structured output. - """ - - -class NumTokensResult(PriceInfo): - """ - Model class for number of tokens result. - """ - - tokens: int diff --git a/api/graphon/model_runtime/entities/message_entities.py b/api/graphon/model_runtime/entities/message_entities.py deleted file mode 100644 index 402bfdc606..0000000000 --- a/api/graphon/model_runtime/entities/message_entities.py +++ /dev/null @@ -1,279 +0,0 @@ -from __future__ import annotations - -from abc import ABC -from collections.abc import Mapping, Sequence -from enum import StrEnum, auto -from typing import Annotated, Any, Literal, Union - -from pydantic import BaseModel, Field, field_serializer, field_validator - - -class PromptMessageRole(StrEnum): - """ - Enum class for prompt message. - """ - - SYSTEM = auto() - USER = auto() - ASSISTANT = auto() - TOOL = auto() - - @classmethod - def value_of(cls, value: str) -> PromptMessageRole: - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid prompt message type value {value}") - - -class PromptMessageTool(BaseModel): - """ - Model class for prompt message tool. - """ - - name: str - description: str - parameters: dict - - -class PromptMessageFunction(BaseModel): - """ - Model class for prompt message function. - """ - - type: str = "function" - function: PromptMessageTool - - -class PromptMessageContentType(StrEnum): - """ - Enum class for prompt message content type. - """ - - TEXT = auto() - IMAGE = auto() - AUDIO = auto() - VIDEO = auto() - DOCUMENT = auto() - - -class PromptMessageContent(ABC, BaseModel): - """ - Model class for prompt message content. - """ - - type: PromptMessageContentType - - -class TextPromptMessageContent(PromptMessageContent): - """ - Model class for text prompt message content. - """ - - type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT # type: ignore - data: str - - -class MultiModalPromptMessageContent(PromptMessageContent): - """ - Model class for multi-modal prompt message content. - """ - - format: str = Field(default=..., description="the format of multi-modal file") - base64_data: str = Field(default="", description="the base64 data of multi-modal file") - url: str = Field(default="", description="the url of multi-modal file") - mime_type: str = Field(default=..., description="the mime type of multi-modal file") - filename: str = Field(default="", description="the filename of multi-modal file") - - @property - def data(self): - return self.url or f"data:{self.mime_type};base64,{self.base64_data}" - - -class VideoPromptMessageContent(MultiModalPromptMessageContent): - type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO # type: ignore - - -class AudioPromptMessageContent(MultiModalPromptMessageContent): - type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO # type: ignore - - -class ImagePromptMessageContent(MultiModalPromptMessageContent): - """ - Model class for image prompt message content. - """ - - class DETAIL(StrEnum): - LOW = auto() - HIGH = auto() - - type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE # type: ignore - detail: DETAIL = DETAIL.LOW - - -class DocumentPromptMessageContent(MultiModalPromptMessageContent): - type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT # type: ignore - - -PromptMessageContentUnionTypes = Annotated[ - Union[ - TextPromptMessageContent, - ImagePromptMessageContent, - DocumentPromptMessageContent, - AudioPromptMessageContent, - VideoPromptMessageContent, - ], - Field(discriminator="type"), -] - - -CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = { - PromptMessageContentType.TEXT: TextPromptMessageContent, - PromptMessageContentType.IMAGE: ImagePromptMessageContent, - PromptMessageContentType.AUDIO: AudioPromptMessageContent, - PromptMessageContentType.VIDEO: VideoPromptMessageContent, - PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent, -} - - -class PromptMessage(ABC, BaseModel): - """ - Model class for prompt message. - """ - - role: PromptMessageRole - content: str | list[PromptMessageContentUnionTypes] | None = None - name: str | None = None - - def is_empty(self) -> bool: - """ - Check if prompt message is empty. - - :return: True if prompt message is empty, False otherwise - """ - return not self.content - - def get_text_content(self) -> str: - """ - Get text content from prompt message. - - :return: Text content as string, empty string if no text content - """ - if isinstance(self.content, str): - return self.content - elif isinstance(self.content, list): - text_parts = [] - for item in self.content: - if isinstance(item, TextPromptMessageContent): - text_parts.append(item.data) - return "".join(text_parts) - else: - return "" - - @field_validator("content", mode="before") - @classmethod - def validate_content(cls, v): - if isinstance(v, list): - prompts = [] - for prompt in v: - if isinstance(prompt, PromptMessageContent): - if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent): - prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump()) - elif isinstance(prompt, dict): - prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt) - else: - raise ValueError(f"invalid prompt message {prompt}") - prompts.append(prompt) - return prompts - return v - - @field_serializer("content") - def serialize_content( - self, content: Union[str, Sequence[PromptMessageContent]] | None - ) -> str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent] | None: - if content is None or isinstance(content, str): - return content - if isinstance(content, list): - return [item.model_dump() if hasattr(item, "model_dump") else item for item in content] - return content - - -class UserPromptMessage(PromptMessage): - """ - Model class for user prompt message. - """ - - role: PromptMessageRole = PromptMessageRole.USER - - -class AssistantPromptMessage(PromptMessage): - """ - Model class for assistant prompt message. - """ - - class ToolCall(BaseModel): - """ - Model class for assistant prompt message tool call. - """ - - class ToolCallFunction(BaseModel): - """ - Model class for assistant prompt message tool call function. - """ - - name: str - arguments: str - - id: str - type: str - function: ToolCallFunction - - @field_validator("id", mode="before") - @classmethod - def transform_id_to_str(cls, value) -> str: - if not isinstance(value, str): - return str(value) - else: - return value - - role: PromptMessageRole = PromptMessageRole.ASSISTANT - tool_calls: list[ToolCall] = [] - - def is_empty(self) -> bool: - """ - Check if prompt message is empty. - - :return: True if prompt message is empty, False otherwise - """ - return super().is_empty() and not self.tool_calls - - -class SystemPromptMessage(PromptMessage): - """ - Model class for system prompt message. - """ - - role: PromptMessageRole = PromptMessageRole.SYSTEM - - -class ToolPromptMessage(PromptMessage): - """ - Model class for tool prompt message. - """ - - role: PromptMessageRole = PromptMessageRole.TOOL - tool_call_id: str - - def is_empty(self) -> bool: - """ - Check if prompt message is empty. - - :return: True if prompt message is empty, False otherwise - """ - return super().is_empty() and not self.tool_call_id diff --git a/api/graphon/model_runtime/entities/model_entities.py b/api/graphon/model_runtime/entities/model_entities.py deleted file mode 100644 index 5ec4970faf..0000000000 --- a/api/graphon/model_runtime/entities/model_entities.py +++ /dev/null @@ -1,242 +0,0 @@ -from __future__ import annotations - -from decimal import Decimal -from enum import StrEnum, auto -from typing import Any - -from pydantic import BaseModel, ConfigDict, model_validator - -from graphon.model_runtime.entities.common_entities import I18nObject - - -class ModelType(StrEnum): - """ - Enum class for model type. - """ - - LLM = auto() - TEXT_EMBEDDING = "text-embedding" - RERANK = auto() - SPEECH2TEXT = auto() - MODERATION = auto() - TTS = auto() - - @classmethod - def value_of(cls, origin_model_type: str) -> ModelType: - """ - Get model type from origin model type. - - :return: model type - """ - if origin_model_type in {"text-generation", cls.LLM}: - return cls.LLM - elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}: - return cls.TEXT_EMBEDDING - elif origin_model_type in {"reranking", cls.RERANK}: - return cls.RERANK - elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}: - return cls.SPEECH2TEXT - elif origin_model_type in {"tts", cls.TTS}: - return cls.TTS - elif origin_model_type == cls.MODERATION: - return cls.MODERATION - else: - raise ValueError(f"invalid origin model type {origin_model_type}") - - def to_origin_model_type(self) -> str: - """ - Get origin model type from model type. - - :return: origin model type - """ - if self == self.LLM: - return "text-generation" - elif self == self.TEXT_EMBEDDING: - return "embeddings" - elif self == self.RERANK: - return "reranking" - elif self == self.SPEECH2TEXT: - return "speech2text" - elif self == self.TTS: - return "tts" - elif self == self.MODERATION: - return "moderation" - else: - raise ValueError(f"invalid model type {self}") - - -class FetchFrom(StrEnum): - """ - Enum class for fetch from. - """ - - PREDEFINED_MODEL = "predefined-model" - CUSTOMIZABLE_MODEL = "customizable-model" - - -class ModelFeature(StrEnum): - """ - Enum class for llm feature. - """ - - TOOL_CALL = "tool-call" - MULTI_TOOL_CALL = "multi-tool-call" - AGENT_THOUGHT = "agent-thought" - VISION = auto() - STREAM_TOOL_CALL = "stream-tool-call" - DOCUMENT = auto() - VIDEO = auto() - AUDIO = auto() - STRUCTURED_OUTPUT = "structured-output" - - -class DefaultParameterName(StrEnum): - """ - Enum class for parameter template variable. - """ - - TEMPERATURE = auto() - TOP_P = auto() - TOP_K = auto() - PRESENCE_PENALTY = auto() - FREQUENCY_PENALTY = auto() - MAX_TOKENS = auto() - RESPONSE_FORMAT = auto() - JSON_SCHEMA = auto() - - @classmethod - def value_of(cls, value: Any) -> DefaultParameterName: - """ - Get parameter name from value. - - :param value: parameter value - :return: parameter name - """ - for name in cls: - if name.value == value: - return name - raise ValueError(f"invalid parameter name {value}") - - -class ParameterType(StrEnum): - """ - Enum class for parameter type. - """ - - FLOAT = auto() - INT = auto() - STRING = auto() - BOOLEAN = auto() - TEXT = auto() - - -class ModelPropertyKey(StrEnum): - """ - Enum class for model property key. - """ - - MODE = auto() - CONTEXT_SIZE = auto() - MAX_CHUNKS = auto() - FILE_UPLOAD_LIMIT = auto() - SUPPORTED_FILE_EXTENSIONS = auto() - MAX_CHARACTERS_PER_CHUNK = auto() - DEFAULT_VOICE = auto() - VOICES = auto() - WORD_LIMIT = auto() - AUDIO_TYPE = auto() - MAX_WORKERS = auto() - - -class ProviderModel(BaseModel): - """ - Model class for provider model. - """ - - model: str - label: I18nObject - model_type: ModelType - features: list[ModelFeature] | None = None - fetch_from: FetchFrom - model_properties: dict[ModelPropertyKey, Any] - deprecated: bool = False - model_config = ConfigDict(protected_namespaces=()) - - @property - def support_structure_output(self) -> bool: - return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features - - -class ParameterRule(BaseModel): - """ - Model class for parameter rule. - """ - - name: str - use_template: str | None = None - label: I18nObject - type: ParameterType - help: I18nObject | None = None - required: bool = False - default: Any | None = None - min: float | None = None - max: float | None = None - precision: int | None = None - options: list[str] = [] - - -class PriceConfig(BaseModel): - """ - Model class for pricing info. - """ - - input: Decimal - output: Decimal | None = None - unit: Decimal - currency: str - - -class AIModelEntity(ProviderModel): - """ - Model class for AI model. - """ - - parameter_rules: list[ParameterRule] = [] - pricing: PriceConfig | None = None - - @model_validator(mode="after") - def validate_model(self): - supported_schema_keys = ["json_schema"] - schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None) - if not schema_key: - return self - if self.features is None: - self.features = [ModelFeature.STRUCTURED_OUTPUT] - else: - if ModelFeature.STRUCTURED_OUTPUT not in self.features: - self.features.append(ModelFeature.STRUCTURED_OUTPUT) - return self - - -class ModelUsage(BaseModel): - pass - - -class PriceType(StrEnum): - """ - Enum class for price type. - """ - - INPUT = auto() - OUTPUT = auto() - - -class PriceInfo(BaseModel): - """ - Model class for price info. - """ - - unit_price: Decimal - unit: Decimal - total_amount: Decimal - currency: str diff --git a/api/graphon/model_runtime/entities/provider_entities.py b/api/graphon/model_runtime/entities/provider_entities.py deleted file mode 100644 index 8e6c516fb9..0000000000 --- a/api/graphon/model_runtime/entities/provider_entities.py +++ /dev/null @@ -1,179 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum, auto - -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType - - -class ConfigurateMethod(StrEnum): - """ - Enum class for configurate method of provider model. - """ - - PREDEFINED_MODEL = "predefined-model" - CUSTOMIZABLE_MODEL = "customizable-model" - - -class FormType(StrEnum): - """ - Enum class for form type. - """ - - TEXT_INPUT = "text-input" - SECRET_INPUT = "secret-input" - SELECT = auto() - RADIO = auto() - SWITCH = auto() - - -class FormShowOnObject(BaseModel): - """ - Model class for form show on. - """ - - variable: str - value: str - - -class FormOption(BaseModel): - """ - Model class for form option. - """ - - label: I18nObject - value: str - show_on: list[FormShowOnObject] = [] - - @model_validator(mode="after") - def _(self): - if not self.label: - self.label = I18nObject(en_US=self.value) - return self - - -class CredentialFormSchema(BaseModel): - """ - Model class for credential form schema. - """ - - variable: str - label: I18nObject - type: FormType - required: bool = True - default: str | None = None - options: list[FormOption] | None = None - placeholder: I18nObject | None = None - max_length: int = 0 - show_on: list[FormShowOnObject] = [] - - -class ProviderCredentialSchema(BaseModel): - """ - Model class for provider credential schema. - """ - - credential_form_schemas: list[CredentialFormSchema] - - -class FieldModelSchema(BaseModel): - label: I18nObject - placeholder: I18nObject | None = None - - -class ModelCredentialSchema(BaseModel): - """ - Model class for model credential schema. - """ - - model: FieldModelSchema - credential_form_schemas: list[CredentialFormSchema] - - -class SimpleProviderEntity(BaseModel): - """ - Simplified provider schema exposed to callers. - - `provider` is the canonical runtime identifier. `provider_name` is an optional - compatibility alias for short-name lookups and is empty when no alias exists. - """ - - provider: str - provider_name: str = "" - label: I18nObject - icon_small: I18nObject | None = None - icon_small_dark: I18nObject | None = None - supported_model_types: Sequence[ModelType] - models: list[AIModelEntity] = [] - - -class ProviderHelpEntity(BaseModel): - """ - Model class for provider help. - """ - - title: I18nObject - url: I18nObject - - -class ProviderEntity(BaseModel): - """ - Runtime-native provider schema. - - `provider` is the canonical runtime identifier. `provider_name` is a - compatibility alias for callers that still resolve providers by short name and - is empty when no alias exists. - """ - - provider: str - provider_name: str = "" - label: I18nObject - description: I18nObject | None = None - icon_small: I18nObject | None = None - icon_small_dark: I18nObject | None = None - background: str | None = None - help: ProviderHelpEntity | None = None - supported_model_types: Sequence[ModelType] - configurate_methods: list[ConfigurateMethod] - models: list[AIModelEntity] = Field(default_factory=list) - provider_credential_schema: ProviderCredentialSchema | None = None - model_credential_schema: ModelCredentialSchema | None = None - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - # position from plugin _position.yaml - position: dict[str, list[str]] | None = {} - - @field_validator("models", mode="before") - @classmethod - def validate_models(cls, v): - # returns EmptyList if v is empty - if not v: - return [] - return v - - def to_simple_provider(self) -> SimpleProviderEntity: - """ - Convert to simple provider. - - :return: simple provider - """ - return SimpleProviderEntity( - provider=self.provider, - provider_name=self.provider_name, - label=self.label, - icon_small=self.icon_small, - supported_model_types=self.supported_model_types, - models=self.models, - ) - - -class ProviderConfig(BaseModel): - """ - Model class for provider config. - """ - - provider: str - credentials: dict diff --git a/api/graphon/model_runtime/entities/rerank_entities.py b/api/graphon/model_runtime/entities/rerank_entities.py deleted file mode 100644 index 8a0bb5fac2..0000000000 --- a/api/graphon/model_runtime/entities/rerank_entities.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import TypedDict - -from pydantic import BaseModel - - -class MultimodalRerankInput(TypedDict): - content: str - content_type: str - - -class RerankDocument(BaseModel): - """ - Model class for rerank document. - """ - - index: int - text: str - score: float - - -class RerankResult(BaseModel): - """ - Model class for rerank result. - """ - - model: str - docs: list[RerankDocument] diff --git a/api/graphon/model_runtime/entities/text_embedding_entities.py b/api/graphon/model_runtime/entities/text_embedding_entities.py deleted file mode 100644 index 08ffd83b5b..0000000000 --- a/api/graphon/model_runtime/entities/text_embedding_entities.py +++ /dev/null @@ -1,47 +0,0 @@ -from decimal import Decimal -from enum import StrEnum, auto - -from pydantic import BaseModel - -from graphon.model_runtime.entities.model_entities import ModelUsage - - -class EmbeddingInputType(StrEnum): - """Embedding request input variants understood by the model runtime.""" - - DOCUMENT = auto() - QUERY = auto() - - -class EmbeddingUsage(ModelUsage): - """ - Model class for embedding usage. - """ - - tokens: int - total_tokens: int - unit_price: Decimal - price_unit: Decimal - total_price: Decimal - currency: str - latency: float - - -class EmbeddingResult(BaseModel): - """ - Model class for text embedding result. - """ - - model: str - embeddings: list[list[float]] - usage: EmbeddingUsage - - -class FileEmbeddingResult(BaseModel): - """ - Model class for file embedding result. - """ - - model: str - embeddings: list[list[float]] - usage: EmbeddingUsage diff --git a/api/graphon/model_runtime/errors/__init__.py b/api/graphon/model_runtime/errors/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/model_runtime/errors/invoke.py b/api/graphon/model_runtime/errors/invoke.py deleted file mode 100644 index 1a57078b98..0000000000 --- a/api/graphon/model_runtime/errors/invoke.py +++ /dev/null @@ -1,41 +0,0 @@ -class InvokeError(ValueError): - """Base class for all LLM exceptions.""" - - description: str | None = None - - def __init__(self, description: str | None = None): - if description is not None: - self.description = description - - def __str__(self): - return self.description or self.__class__.__name__ - - -class InvokeConnectionError(InvokeError): - """Raised when the Invoke returns connection error.""" - - description = "Connection Error" - - -class InvokeServerUnavailableError(InvokeError): - """Raised when the Invoke returns server unavailable error.""" - - description = "Server Unavailable Error" - - -class InvokeRateLimitError(InvokeError): - """Raised when the Invoke returns rate limit error.""" - - description = "Rate Limit Error" - - -class InvokeAuthorizationError(InvokeError): - """Raised when the Invoke returns authorization error.""" - - description = "Incorrect model credentials provided, please check and try again. " - - -class InvokeBadRequestError(InvokeError): - """Raised when the Invoke returns bad request.""" - - description = "Bad Request Error" diff --git a/api/graphon/model_runtime/errors/validate.py b/api/graphon/model_runtime/errors/validate.py deleted file mode 100644 index 16bebcc67d..0000000000 --- a/api/graphon/model_runtime/errors/validate.py +++ /dev/null @@ -1,6 +0,0 @@ -class CredentialsValidateFailedError(ValueError): - """ - Credentials validate failed error - """ - - pass diff --git a/api/graphon/model_runtime/memory/__init__.py b/api/graphon/model_runtime/memory/__init__.py deleted file mode 100644 index 2d954486c3..0000000000 --- a/api/graphon/model_runtime/memory/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory - -__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"] diff --git a/api/graphon/model_runtime/memory/prompt_message_memory.py b/api/graphon/model_runtime/memory/prompt_message_memory.py deleted file mode 100644 index 03e26e9ff5..0000000000 --- a/api/graphon/model_runtime/memory/prompt_message_memory.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from typing import Protocol - -from graphon.model_runtime.entities import PromptMessage - -DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000 - - -class PromptMessageMemory(Protocol): - """Port for loading memory as prompt messages.""" - - def get_history_prompt_messages( - self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None - ) -> Sequence[PromptMessage]: - """Return historical prompt messages constrained by token/message limits.""" - ... diff --git a/api/graphon/model_runtime/model_providers/__base/__init__.py b/api/graphon/model_runtime/model_providers/__base/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/model_runtime/model_providers/__base/ai_model.py b/api/graphon/model_runtime/model_providers/__base/ai_model.py deleted file mode 100644 index 1700ec9740..0000000000 --- a/api/graphon/model_runtime/model_providers/__base/ai_model.py +++ /dev/null @@ -1,247 +0,0 @@ -import decimal - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE -from graphon.model_runtime.entities.model_entities import ( - AIModelEntity, - DefaultParameterName, - ModelType, - PriceConfig, - PriceInfo, - PriceType, -) -from graphon.model_runtime.entities.provider_entities import ProviderEntity -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.runtime import ModelRuntime - - -class AIModel: - """ - Runtime-facing base class for all model providers. - - This stays a regular Python class because instances hold live collaborators - such as the provider schema and runtime adapter rather than user input that - benefits from Pydantic validation. Subclasses must pin ``model_type`` via a - class attribute; the base class is not meant to be instantiated directly. - """ - - model_type: ModelType - provider_schema: ProviderEntity - model_runtime: ModelRuntime - started_at: float - - def __init__( - self, - provider_schema: ProviderEntity, - model_runtime: ModelRuntime, - *, - started_at: float = 0, - ) -> None: - if getattr(type(self), "model_type", None) is None: - raise TypeError("AIModel subclasses must define model_type as a class attribute") - - self.model_type = type(self).model_type - self.provider_schema = provider_schema - self.model_runtime = model_runtime - self.started_at = started_at - - @property - def provider(self) -> str: - return self.provider_schema.provider - - @property - def provider_display_name(self) -> str: - return self.provider_schema.label.en_US - - @property - def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]: - """ - Map model invoke error to unified error. - - The key is the error type thrown to the caller, and the value contains - runtime-facing exception types that should be normalized to it. - """ - return { - InvokeConnectionError: [InvokeConnectionError], - InvokeServerUnavailableError: [InvokeServerUnavailableError], - InvokeRateLimitError: [InvokeRateLimitError], - InvokeAuthorizationError: [InvokeAuthorizationError], - InvokeBadRequestError: [InvokeBadRequestError], - ValueError: [ValueError], - } - - def _transform_invoke_error(self, error: Exception) -> Exception: - """ - Transform invoke error to unified error - - :param error: model invoke error - :return: unified error - """ - for invoke_error, model_errors in self._invoke_error_mapping.items(): - if isinstance(error, tuple(model_errors)): - if invoke_error == InvokeAuthorizationError: - return InvokeAuthorizationError( - description=( - f"[{self.provider_display_name}] Incorrect model credentials provided, " - "please check and try again." - ) - ) - elif isinstance(invoke_error, InvokeError): - return InvokeError( - description=f"[{self.provider_display_name}] {invoke_error.description}, {str(error)}" - ) - else: - return error - - return InvokeError(description=f"[{self.provider_display_name}] Error: {str(error)}") - - def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo: - """ - Get price for given model and tokens - - :param model: model name - :param credentials: model credentials - :param price_type: price type - :param tokens: number of tokens - :return: price info - """ - # get model schema - model_schema = self.get_model_schema(model, credentials) - - # get price info from predefined model schema - price_config: PriceConfig | None = None - if model_schema and model_schema.pricing: - price_config = model_schema.pricing - - # get unit price - unit_price = None - if price_config: - if price_type == PriceType.INPUT: - unit_price = price_config.input - elif price_type == PriceType.OUTPUT and price_config.output is not None: - unit_price = price_config.output - - if unit_price is None: - return PriceInfo( - unit_price=decimal.Decimal("0.0"), - unit=decimal.Decimal("0.0"), - total_amount=decimal.Decimal("0.0"), - currency="USD", - ) - - # calculate total amount - if not price_config: - raise ValueError(f"Price config not found for model {model}") - total_amount = tokens * unit_price * price_config.unit - total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP) - - return PriceInfo( - unit_price=unit_price, - unit=price_config.unit, - total_amount=total_amount, - currency=price_config.currency, - ) - - def get_model_schema(self, model: str, credentials: dict | None = None) -> AIModelEntity | None: - """ - Get model schema by model name and credentials - - :param model: model name - :param credentials: model credentials - :return: model schema - """ - return self.model_runtime.get_model_schema( - provider=self.provider, - model_type=self.model_type, - model=model, - credentials=credentials or {}, - ) - - def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: - """ - Get customizable model schema from credentials - - :param model: model name - :param credentials: model credentials - :return: model schema - """ - - # get customizable model schema - schema = self.get_customizable_model_schema(model, credentials) - if not schema: - return None - - # fill in the template - new_parameter_rules = [] - for parameter_rule in schema.parameter_rules: - if parameter_rule.use_template: - try: - default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template) - default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) - if not parameter_rule.max and "max" in default_parameter_rule: - parameter_rule.max = default_parameter_rule["max"] - if not parameter_rule.min and "min" in default_parameter_rule: - parameter_rule.min = default_parameter_rule["min"] - if not parameter_rule.default and "default" in default_parameter_rule: - parameter_rule.default = default_parameter_rule["default"] - if not parameter_rule.precision and "precision" in default_parameter_rule: - parameter_rule.precision = default_parameter_rule["precision"] - if not parameter_rule.required and "required" in default_parameter_rule: - parameter_rule.required = default_parameter_rule["required"] - if not parameter_rule.help and "help" in default_parameter_rule: - parameter_rule.help = I18nObject( - en_US=default_parameter_rule["help"]["en_US"], - ) - if ( - parameter_rule.help - and not parameter_rule.help.en_US - and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"]) - ): - parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"] - if ( - parameter_rule.help - and not parameter_rule.help.zh_Hans - and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"]) - ): - parameter_rule.help.zh_Hans = default_parameter_rule["help"].get( - "zh_Hans", default_parameter_rule["help"]["en_US"] - ) - except ValueError: - pass - - new_parameter_rules.append(parameter_rule) - - schema.parameter_rules = new_parameter_rules - - return schema - - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: - """ - Get customizable model schema - - :param model: model name - :param credentials: model credentials - :return: model schema - """ - return None - - def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName): - """ - Get default parameter rule for given name - - :param name: parameter name - :return: parameter rule - """ - default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name) - - if not default_parameter_rule: - raise Exception(f"Invalid model parameter rule name {name}") - - return default_parameter_rule diff --git a/api/graphon/model_runtime/model_providers/__base/large_language_model.py b/api/graphon/model_runtime/model_providers/__base/large_language_model.py deleted file mode 100644 index 0f909646a1..0000000000 --- a/api/graphon/model_runtime/model_providers/__base/large_language_model.py +++ /dev/null @@ -1,638 +0,0 @@ -import logging -import time -import uuid -from collections.abc import Callable, Generator, Iterator, Mapping, Sequence -from typing import Union - -from graphon.model_runtime.callbacks.base_callback import Callback -from graphon.model_runtime.callbacks.logging_callback import LoggingCallback -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageContentUnionTypes, - PromptMessageTool, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.model_entities import ( - ModelType, - PriceType, -) -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - -logger = logging.getLogger(__name__) - - -def _gen_tool_call_id() -> str: - return f"chatcmpl-tool-{str(uuid.uuid4().hex)}" - - -def _run_callbacks(callbacks: Sequence[Callback] | None, *, event: str, invoke: Callable[[Callback], None]) -> None: - if not callbacks: - return - - for callback in callbacks: - try: - invoke(callback) - except Exception as e: - if callback.raise_error: - raise - logger.warning("Callback %s %s failed with error %s", callback.__class__.__name__, event, e) - - -def _get_or_create_tool_call( - existing_tools_calls: list[AssistantPromptMessage.ToolCall], - tool_call_id: str, -) -> AssistantPromptMessage.ToolCall: - """ - Get or create a tool call by ID. - - If `tool_call_id` is empty, returns the most recently created tool call. - """ - if not tool_call_id: - if not existing_tools_calls: - raise ValueError("tool_call_id is empty but no existing tool call is available to apply the delta") - return existing_tools_calls[-1] - - tool_call = next((tool_call for tool_call in existing_tools_calls if tool_call.id == tool_call_id), None) - if tool_call is None: - tool_call = AssistantPromptMessage.ToolCall( - id=tool_call_id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), - ) - existing_tools_calls.append(tool_call) - - return tool_call - - -def _merge_tool_call_delta( - tool_call: AssistantPromptMessage.ToolCall, - delta: AssistantPromptMessage.ToolCall, -) -> None: - if delta.id: - tool_call.id = delta.id - if delta.type: - tool_call.type = delta.type - if delta.function.name: - tool_call.function.name = delta.function.name - if delta.function.arguments: - tool_call.function.arguments += delta.function.arguments - - -def _build_llm_result_from_chunks( - model: str, - prompt_messages: Sequence[PromptMessage], - chunks: Iterator[LLMResultChunk], -) -> LLMResult: - """ - Build a single `LLMResult` by accumulating all returned chunks. - - Some models only support streaming output (e.g. Qwen3 open-source edition) - and the plugin side may still implement the response via a chunked stream, - so all chunks must be consumed and concatenated into a single ``LLMResult``. - - The ``usage`` is taken from the last chunk that carries it, which is the - typical convention for streaming responses (the final chunk contains the - aggregated token counts). - """ - content = "" - content_list: list[PromptMessageContentUnionTypes] = [] - usage = LLMUsage.empty_usage() - system_fingerprint: str | None = None - tools_calls: list[AssistantPromptMessage.ToolCall] = [] - - try: - for chunk in chunks: - if isinstance(chunk.delta.message.content, str): - content += chunk.delta.message.content - elif isinstance(chunk.delta.message.content, list): - content_list.extend(chunk.delta.message.content) - - if chunk.delta.message.tool_calls: - _increase_tool_call(chunk.delta.message.tool_calls, tools_calls) - - if chunk.delta.usage: - usage = chunk.delta.usage - if chunk.system_fingerprint: - system_fingerprint = chunk.system_fingerprint - except Exception: - logger.exception("Error while consuming non-stream plugin chunk iterator.") - raise - finally: - # Drain any remaining chunks to release underlying streaming resources (e.g. HTTP connections). - close = getattr(chunks, "close", None) - if callable(close): - close() - - return LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=content or content_list, - tool_calls=tools_calls, - ), - usage=usage, - system_fingerprint=system_fingerprint, - ) - - -def _invoke_llm_via_runtime( - *, - llm_model: "LargeLanguageModel", - provider: str, - model: str, - credentials: dict, - model_parameters: dict, - prompt_messages: Sequence[PromptMessage], - tools: list[PromptMessageTool] | None, - stop: Sequence[str] | None, - stream: bool, -) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: - return llm_model.model_runtime.invoke_llm( - provider=provider, - model=model, - credentials=credentials, - model_parameters=model_parameters, - prompt_messages=list(prompt_messages), - tools=tools, - stop=stop, - stream=stream, - ) - - -def _normalize_non_stream_runtime_result( - model: str, - prompt_messages: Sequence[PromptMessage], - result: Union[LLMResult, Iterator[LLMResultChunk]], -) -> LLMResult: - if isinstance(result, LLMResult): - return result - return _build_llm_result_from_chunks(model=model, prompt_messages=prompt_messages, chunks=result) - - -def _increase_tool_call( - new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall] -): - """ - Merge incremental tool call updates into existing tool calls. - - :param new_tool_calls: List of new tool call deltas to be merged. - :param existing_tools_calls: List of existing tool calls to be modified IN-PLACE. - """ - - for new_tool_call in new_tool_calls: - # generate ID for tool calls with function name but no ID to track them - if new_tool_call.function.name and not new_tool_call.id: - new_tool_call.id = _gen_tool_call_id() - - tool_call = _get_or_create_tool_call(existing_tools_calls, new_tool_call.id) - _merge_tool_call_delta(tool_call, new_tool_call) - - -class LargeLanguageModel(AIModel): - """ - Model class for large language model. - """ - - model_type: ModelType = ModelType.LLM - - def invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict | None = None, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, - stream: bool = True, - callbacks: list[Callback] | None = None, - ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: - """ - Invoke large language model - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param callbacks: callbacks - :return: full response or stream response chunk generator result - """ - # validate and filter model parameters - if model_parameters is None: - model_parameters = {} - - self.started_at = time.perf_counter() - - callbacks = callbacks or [] - - if logger.isEnabledFor(logging.DEBUG): - callbacks.append(LoggingCallback()) - - # trigger before invoke callbacks - self._trigger_before_invoke_callbacks( - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - callbacks=callbacks, - ) - - result: Union[LLMResult, Generator[LLMResultChunk, None, None]] - - try: - result = _invoke_llm_via_runtime( - llm_model=self, - provider=self.provider, - model=model, - credentials=credentials, - model_parameters=model_parameters, - prompt_messages=prompt_messages, - tools=tools, - stop=stop, - stream=stream, - ) - - if not stream: - result = _normalize_non_stream_runtime_result( - model=model, prompt_messages=prompt_messages, result=result - ) - except Exception as e: - self._trigger_invoke_error_callbacks( - model=model, - ex=e, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - callbacks=callbacks, - ) - - # TODO - raise self._transform_invoke_error(e) - - if stream and not isinstance(result, LLMResult): - return self._invoke_result_generator( - model=model, - result=result, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - callbacks=callbacks, - ) - elif isinstance(result, LLMResult): - self._trigger_after_invoke_callbacks( - model=model, - result=result, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - callbacks=callbacks, - ) - # Following https://github.com/langgenius/dify/issues/17799, - # we removed the prompt_messages from the chunk on the plugin daemon side. - # To ensure compatibility, we add the prompt_messages back here. - result.prompt_messages = prompt_messages - return result - raise NotImplementedError("unsupported invoke result type", type(result)) - - def _invoke_result_generator( - self, - model: str, - result: Generator[LLMResultChunk, None, None], - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - invocation_context: Mapping[str, object] | None = None, - callbacks: list[Callback] | None = None, - ) -> Generator[LLMResultChunk, None, None]: - """ - Invoke result generator - - :param result: result generator - :return: result generator - """ - callbacks = callbacks or [] - message_content: list[PromptMessageContentUnionTypes] = [] - usage = None - system_fingerprint = None - real_model = model - - def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None): - if not content: - return - if isinstance(content, list): - message_content.extend(content) - return - if isinstance(content, str): - message_content.append(TextPromptMessageContent(data=content)) - return - - try: - for chunk in result: - # Following https://github.com/langgenius/dify/issues/17799, - # we removed the prompt_messages from the chunk on the plugin daemon side. - # To ensure compatibility, we add the prompt_messages back here. - chunk.prompt_messages = prompt_messages - yield chunk - - self._trigger_new_chunk_callbacks( - chunk=chunk, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - callbacks=callbacks, - ) - - _update_message_content(chunk.delta.message.content) - - real_model = chunk.model - if chunk.delta.usage: - usage = chunk.delta.usage - - if chunk.system_fingerprint: - system_fingerprint = chunk.system_fingerprint - except Exception as e: - raise self._transform_invoke_error(e) - - assistant_message = AssistantPromptMessage(content=message_content) - self._trigger_after_invoke_callbacks( - model=model, - result=LLMResult( - model=real_model, - prompt_messages=prompt_messages, - message=assistant_message, - usage=usage or LLMUsage.empty_usage(), - system_fingerprint=system_fingerprint, - ), - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - callbacks=callbacks, - ) - - def get_num_tokens( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None, - ) -> int: - """ - Get number of tokens for given prompt messages - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param tools: tools for tool calling - :return: - """ - return self.model_runtime.get_llm_num_tokens( - provider=self.provider, - model_type=self.model_type, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, - ) - - def calc_response_usage( - self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int - ) -> LLMUsage: - """ - Calculate response usage - - :param model: model name - :param credentials: model credentials - :param prompt_tokens: prompt tokens - :param completion_tokens: completion tokens - :return: usage - """ - # get prompt price info - prompt_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=prompt_tokens, - ) - - # get completion price info - completion_price_info = self.get_price( - model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens - ) - - # transform usage - usage = LLMUsage( - prompt_tokens=prompt_tokens, - prompt_unit_price=prompt_price_info.unit_price, - prompt_price_unit=prompt_price_info.unit, - prompt_price=prompt_price_info.total_amount, - completion_tokens=completion_tokens, - completion_unit_price=completion_price_info.unit_price, - completion_price_unit=completion_price_info.unit, - completion_price=completion_price_info.total_amount, - total_tokens=prompt_tokens + completion_tokens, - total_price=prompt_price_info.total_amount + completion_price_info.total_amount, - currency=prompt_price_info.currency, - latency=time.perf_counter() - self.started_at, - ) - - return usage - - def _trigger_before_invoke_callbacks( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - invocation_context: Mapping[str, object] | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger before invoke callbacks - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - :param callbacks: callbacks - """ - _run_callbacks( - callbacks, - event="on_before_invoke", - invoke=lambda callback: callback.on_before_invoke( - llm_instance=self, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - ), - ) - - def _trigger_new_chunk_callbacks( - self, - chunk: LLMResultChunk, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - invocation_context: Mapping[str, object] | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger new chunk callbacks - - :param chunk: chunk - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - """ - _run_callbacks( - callbacks, - event="on_new_chunk", - invoke=lambda callback: callback.on_new_chunk( - llm_instance=self, - chunk=chunk, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - ), - ) - - def _trigger_after_invoke_callbacks( - self, - model: str, - result: LLMResult, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - invocation_context: Mapping[str, object] | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger after invoke callbacks - - :param model: model name - :param result: result - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - :param callbacks: callbacks - """ - _run_callbacks( - callbacks, - event="on_after_invoke", - invoke=lambda callback: callback.on_after_invoke( - llm_instance=self, - result=result, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - ), - ) - - def _trigger_invoke_error_callbacks( - self, - model: str, - ex: Exception, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - invocation_context: Mapping[str, object] | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger invoke error callbacks - - :param model: model name - :param ex: exception - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - :param callbacks: callbacks - """ - _run_callbacks( - callbacks, - event="on_invoke_error", - invoke=lambda callback: callback.on_invoke_error( - llm_instance=self, - ex=ex, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - ), - ) diff --git a/api/graphon/model_runtime/model_providers/__base/moderation_model.py b/api/graphon/model_runtime/model_providers/__base/moderation_model.py deleted file mode 100644 index 01f6842998..0000000000 --- a/api/graphon/model_runtime/model_providers/__base/moderation_model.py +++ /dev/null @@ -1,33 +0,0 @@ -import time - -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - - -class ModerationModel(AIModel): - """ - Model class for moderation model. - """ - - model_type: ModelType = ModelType.MODERATION - - def invoke(self, model: str, credentials: dict, text: str) -> bool: - """ - Invoke moderation model - - :param model: model name - :param credentials: model credentials - :param text: text to moderate - :return: false if text is safe, true otherwise - """ - self.started_at = time.perf_counter() - - try: - return self.model_runtime.invoke_moderation( - provider=self.provider, - model=model, - credentials=credentials, - text=text, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/graphon/model_runtime/model_providers/__base/rerank_model.py b/api/graphon/model_runtime/model_providers/__base/rerank_model.py deleted file mode 100644 index 94b2b5a4fb..0000000000 --- a/api/graphon/model_runtime/model_providers/__base/rerank_model.py +++ /dev/null @@ -1,76 +0,0 @@ -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - - -class RerankModel(AIModel): - """ - Base Model class for rerank model. - """ - - model_type: ModelType = ModelType.RERANK - - def invoke( - self, - model: str, - credentials: dict, - query: str, - docs: list[str], - score_threshold: float | None = None, - top_n: int | None = None, - ) -> RerankResult: - """ - Invoke rerank model - - :param model: model name - :param credentials: model credentials - :param query: search query - :param docs: docs for reranking - :param score_threshold: score threshold - :param top_n: top n - :return: rerank result - """ - try: - return self.model_runtime.invoke_rerank( - provider=self.provider, - model=model, - credentials=credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - ) - except Exception as e: - raise self._transform_invoke_error(e) - - def invoke_multimodal_rerank( - self, - model: str, - credentials: dict, - query: MultimodalRerankInput, - docs: list[MultimodalRerankInput], - score_threshold: float | None = None, - top_n: int | None = None, - ) -> RerankResult: - """ - Invoke multimodal rerank model - :param model: model name - :param credentials: model credentials - :param query: search query - :param docs: docs for reranking - :param score_threshold: score threshold - :param top_n: top n - :return: rerank result - """ - try: - return self.model_runtime.invoke_multimodal_rerank( - provider=self.provider, - model=model, - credentials=credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/graphon/model_runtime/model_providers/__base/speech2text_model.py b/api/graphon/model_runtime/model_providers/__base/speech2text_model.py deleted file mode 100644 index 4f5d648639..0000000000 --- a/api/graphon/model_runtime/model_providers/__base/speech2text_model.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import IO - -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - - -class Speech2TextModel(AIModel): - """ - Model class for speech2text model. - """ - - model_type: ModelType = ModelType.SPEECH2TEXT - - def invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str: - """ - Invoke speech to text model - - :param model: model name - :param credentials: model credentials - :param file: audio file - :return: text for given audio file - """ - try: - return self.model_runtime.invoke_speech_to_text( - provider=self.provider, - model=model, - credentials=credentials, - file=file, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/graphon/model_runtime/model_providers/__base/text_embedding_model.py b/api/graphon/model_runtime/model_providers/__base/text_embedding_model.py deleted file mode 100644 index c8b4a0a6af..0000000000 --- a/api/graphon/model_runtime/model_providers/__base/text_embedding_model.py +++ /dev/null @@ -1,98 +0,0 @@ -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - - -class TextEmbeddingModel(AIModel): - """ - Model class for text embedding model. - """ - - model_type: ModelType = ModelType.TEXT_EMBEDDING - - def invoke( - self, - model: str, - credentials: dict, - texts: list[str] | None = None, - multimodel_documents: list[dict] | None = None, - input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, - ) -> EmbeddingResult: - """ - Invoke text embedding model - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :param files: files to embed - :param input_type: input type - :return: embeddings result - """ - try: - if texts: - return self.model_runtime.invoke_text_embedding( - provider=self.provider, - model=model, - credentials=credentials, - texts=texts, - input_type=input_type, - ) - if multimodel_documents: - return self.model_runtime.invoke_multimodal_embedding( - provider=self.provider, - model=model, - credentials=credentials, - documents=multimodel_documents, - input_type=input_type, - ) - raise ValueError("No texts or files provided") - except Exception as e: - raise self._transform_invoke_error(e) - - def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> list[int]: - """ - Get number of tokens for given prompt messages - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :return: - """ - return self.model_runtime.get_text_embedding_num_tokens( - provider=self.provider, - model=model, - credentials=credentials, - texts=texts, - ) - - def _get_context_size(self, model: str, credentials: dict) -> int: - """ - Get context size for given embedding model - - :param model: model name - :param credentials: model credentials - :return: context size - """ - model_schema = self.get_model_schema(model, credentials) - - if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties: - content_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE] - return content_size - - return 1000 - - def _get_max_chunks(self, model: str, credentials: dict) -> int: - """ - Get max chunks for given embedding model - - :param model: model name - :param credentials: model credentials - :return: max chunks - """ - model_schema = self.get_model_schema(model, credentials) - - if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties: - max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] - return max_chunks - - return 1 diff --git a/api/graphon/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py b/api/graphon/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py deleted file mode 100644 index 3967acf07b..0000000000 --- a/api/graphon/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py +++ /dev/null @@ -1,53 +0,0 @@ -import logging -from threading import Lock -from typing import Any - -logger = logging.getLogger(__name__) - -_tokenizer: Any | None = None -_lock = Lock() - - -class GPT2Tokenizer: - @staticmethod - def _get_num_tokens_by_gpt2(text: str) -> int: - """ - use gpt2 tokenizer to get num tokens - """ - _tokenizer = GPT2Tokenizer.get_encoder() - tokens = _tokenizer.encode(text) # type: ignore - return len(tokens) - - @staticmethod - def get_num_tokens(text: str) -> int: - # Because this process needs more cpu resource, we turn this back before we find a better way to handle it. - # - # future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text) - # result = future.result() - # return cast(int, result) - return GPT2Tokenizer._get_num_tokens_by_gpt2(text) - - @staticmethod - def get_encoder(): - global _tokenizer, _lock - if _tokenizer is not None: - return _tokenizer - with _lock: - if _tokenizer is None: - # Try to use tiktoken to get the tokenizer because it is faster - # - try: - import tiktoken - - _tokenizer = tiktoken.get_encoding("gpt2") - except Exception: - from os.path import abspath, dirname, join - - from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer - - base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), "gpt2") - _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) - logger.info("Fallback to Transformers' GPT-2 tokenizer from tiktoken") - - return _tokenizer diff --git a/api/graphon/model_runtime/model_providers/__base/tts_model.py b/api/graphon/model_runtime/model_providers/__base/tts_model.py deleted file mode 100644 index 6846f3c403..0000000000 --- a/api/graphon/model_runtime/model_providers/__base/tts_model.py +++ /dev/null @@ -1,58 +0,0 @@ -import logging -from collections.abc import Iterable - -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - -logger = logging.getLogger(__name__) - - -class TTSModel(AIModel): - """ - Model class for TTS model. - """ - - model_type: ModelType = ModelType.TTS - - def invoke( - self, - model: str, - credentials: dict, - content_text: str, - voice: str, - ) -> Iterable[bytes]: - """ - Invoke large language model - - :param model: model name - :param credentials: model credentials - :param voice: model timbre - :param content_text: text content to be translated - :return: translated audio file - """ - try: - return self.model_runtime.invoke_tts( - provider=self.provider, - model=model, - credentials=credentials, - content_text=content_text, - voice=voice, - ) - except Exception as e: - raise self._transform_invoke_error(e) - - def get_tts_model_voices(self, model: str, credentials: dict, language: str | None = None): - """ - Retrieves the list of voices supported by a given text-to-speech (TTS) model. - - :param language: The language for which the voices are requested. - :param model: The name of the TTS model. - :param credentials: The credentials required to access the TTS model. - :return: A list of voices supported by the TTS model. - """ - return self.model_runtime.get_tts_model_voices( - provider=self.provider, - model=model, - credentials=credentials, - language=language, - ) diff --git a/api/graphon/model_runtime/model_providers/__init__.py b/api/graphon/model_runtime/model_providers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/model_runtime/model_providers/_position.yaml b/api/graphon/model_runtime/model_providers/_position.yaml deleted file mode 100644 index fb02de3a67..0000000000 --- a/api/graphon/model_runtime/model_providers/_position.yaml +++ /dev/null @@ -1,43 +0,0 @@ -- openai -- deepseek -- anthropic -- azure_openai -- google -- vertex_ai -- nvidia -- nvidia_nim -- cohere -- upstage -- bedrock -- togetherai -- openrouter -- ollama -- mistralai -- groq -- replicate -- huggingface_hub -- xinference -- triton_inference_server -- zhipuai -- baichuan -- spark -- minimax -- tongyi -- wenxin -- moonshot -- tencent -- jina -- chatglm -- yi -- openllm -- localai -- volcengine_maas -- openai_api_compatible -- hunyuan -- siliconflow -- perfxcloud -- zhinao -- fireworks -- mixedbread -- nomic -- voyage diff --git a/api/graphon/model_runtime/model_providers/model_provider_factory.py b/api/graphon/model_runtime/model_providers/model_provider_factory.py deleted file mode 100644 index 1ea30c7120..0000000000 --- a/api/graphon/model_runtime/model_providers/model_provider_factory.py +++ /dev/null @@ -1,173 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence - -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity -from graphon.model_runtime.model_providers.__base.ai_model import AIModel -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel -from graphon.model_runtime.runtime import ModelRuntime -from graphon.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator -from graphon.model_runtime.schema_validators.provider_credential_schema_validator import ( - ProviderCredentialSchemaValidator, -) - - -class ModelProviderFactory: - """Factory for provider schemas and model-type instances backed by a runtime adapter.""" - - def __init__(self, model_runtime: ModelRuntime): - if model_runtime is None: - raise ValueError("model_runtime is required.") - self.model_runtime = model_runtime - - def get_providers(self) -> Sequence[ProviderEntity]: - """ - Get all providers. - """ - return list(self.get_model_providers()) - - def get_model_providers(self) -> Sequence[ProviderEntity]: - """ - Get all model providers exposed by the runtime adapter. - """ - return self.model_runtime.fetch_model_providers() - - def get_provider_schema(self, provider: str) -> ProviderEntity: - """ - Get provider schema. - """ - return self.get_model_provider(provider=provider) - - def get_model_provider(self, provider: str) -> ProviderEntity: - """ - Get provider schema. - """ - provider_entity = self._resolve_provider(provider) - if provider_entity is None: - raise ValueError(f"Invalid provider: {provider}") - - return provider_entity - - def provider_credentials_validate(self, *, provider: str, credentials: dict): - """ - Validate provider credentials. - """ - provider_entity = self.get_model_provider(provider=provider) - - provider_credential_schema = provider_entity.provider_credential_schema - if not provider_credential_schema: - raise ValueError(f"Provider {provider} does not have provider_credential_schema") - - validator = ProviderCredentialSchemaValidator(provider_credential_schema) - filtered_credentials = validator.validate_and_filter(credentials) - - self.model_runtime.validate_provider_credentials( - provider=provider_entity.provider, - credentials=filtered_credentials, - ) - - return filtered_credentials - - def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict): - """ - Validate model credentials. - """ - provider_entity = self.get_model_provider(provider=provider) - - model_credential_schema = provider_entity.model_credential_schema - if not model_credential_schema: - raise ValueError(f"Provider {provider} does not have model_credential_schema") - - validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) - filtered_credentials = validator.validate_and_filter(credentials) - - self.model_runtime.validate_model_credentials( - provider=provider_entity.provider, - model_type=model_type, - model=model, - credentials=filtered_credentials, - ) - - return filtered_credentials - - def get_model_schema( - self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None - ) -> AIModelEntity | None: - """ - Get model schema. - """ - provider_entity = self.get_model_provider(provider) - return self.model_runtime.get_model_schema( - provider=provider_entity.provider, - model_type=model_type, - model=model, - credentials=credentials or {}, - ) - - def get_models( - self, - *, - provider: str | None = None, - model_type: ModelType | None = None, - provider_configs: list[ProviderConfig] | None = None, - ) -> list[SimpleProviderEntity]: - """ - Get all models for given model type. - """ - providers = [] - for provider_entity in self.get_model_providers(): - if provider and not self._matches_provider(provider_entity, provider): - continue - - if model_type and model_type not in provider_entity.supported_model_types: - continue - - simple_provider_schema = provider_entity.to_simple_provider() - if model_type is not None: - simple_provider_schema.models = [ - model_schema for model_schema in provider_entity.models if model_schema.model_type == model_type - ] - providers.append(simple_provider_schema) - - return providers - - def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel: - """ - Get model type instance by provider name and model type. - """ - provider_schema = self.get_model_provider(provider) - - if model_type == ModelType.LLM: - return LargeLanguageModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - if model_type == ModelType.TEXT_EMBEDDING: - return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - if model_type == ModelType.RERANK: - return RerankModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - if model_type == ModelType.SPEECH2TEXT: - return Speech2TextModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - if model_type == ModelType.MODERATION: - return ModerationModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - if model_type == ModelType.TTS: - return TTSModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - - raise ValueError(f"Unsupported model type: {model_type}") - - def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: - """ - Get provider icon. - """ - provider_entity = self.get_model_provider(provider) - return self.model_runtime.get_provider_icon(provider=provider_entity.provider, icon_type=icon_type, lang=lang) - - def _resolve_provider(self, provider: str) -> ProviderEntity | None: - return next((item for item in self.get_model_providers() if self._matches_provider(item, provider)), None) - - @staticmethod - def _matches_provider(provider_entity: ProviderEntity, provider: str) -> bool: - return provider in (provider_entity.provider, provider_entity.provider_name) diff --git a/api/graphon/model_runtime/runtime.py b/api/graphon/model_runtime/runtime.py deleted file mode 100644 index 79862bab8b..0000000000 --- a/api/graphon/model_runtime/runtime.py +++ /dev/null @@ -1,159 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator, Iterable, Sequence -from typing import IO, Any, Protocol, Union, runtime_checkable - -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.model_runtime.entities.provider_entities import ProviderEntity -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult - - -@runtime_checkable -class ModelRuntime(Protocol): - """Port for provider discovery, schema lookup, and model execution. - - `provider` is the model runtime's canonical provider identifier. Adapters may - derive transport-specific details from it, but those details stay outside - this boundary. - """ - - def fetch_model_providers(self) -> Sequence[ProviderEntity]: ... - - def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: ... - - def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: ... - - def validate_model_credentials( - self, - *, - provider: str, - model_type: ModelType, - model: str, - credentials: dict[str, Any], - ) -> None: ... - - def get_model_schema( - self, - *, - provider: str, - model_type: ModelType, - model: str, - credentials: dict[str, Any], - ) -> AIModelEntity | None: ... - - def invoke_llm( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - model_parameters: dict[str, Any], - prompt_messages: Sequence[PromptMessage], - tools: list[PromptMessageTool] | None, - stop: Sequence[str] | None, - stream: bool, - ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: ... - - def get_llm_num_tokens( - self, - *, - provider: str, - model_type: ModelType, - model: str, - credentials: dict[str, Any], - prompt_messages: Sequence[PromptMessage], - tools: Sequence[PromptMessageTool] | None, - ) -> int: ... - - def invoke_text_embedding( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - texts: list[str], - input_type: EmbeddingInputType, - ) -> EmbeddingResult: ... - - def invoke_multimodal_embedding( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - documents: list[dict[str, Any]], - input_type: EmbeddingInputType, - ) -> EmbeddingResult: ... - - def get_text_embedding_num_tokens( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - texts: list[str], - ) -> list[int]: ... - - def invoke_rerank( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - query: str, - docs: list[str], - score_threshold: float | None, - top_n: int | None, - ) -> RerankResult: ... - - def invoke_multimodal_rerank( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - query: MultimodalRerankInput, - docs: list[MultimodalRerankInput], - score_threshold: float | None, - top_n: int | None, - ) -> RerankResult: ... - - def invoke_tts( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - content_text: str, - voice: str, - ) -> Iterable[bytes]: ... - - def get_tts_model_voices( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - language: str | None, - ) -> Any: ... - - def invoke_speech_to_text( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - file: IO[bytes], - ) -> str: ... - - def invoke_moderation( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - text: str, - ) -> bool: ... diff --git a/api/graphon/model_runtime/schema_validators/__init__.py b/api/graphon/model_runtime/schema_validators/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/model_runtime/schema_validators/common_validator.py b/api/graphon/model_runtime/schema_validators/common_validator.py deleted file mode 100644 index 984507081b..0000000000 --- a/api/graphon/model_runtime/schema_validators/common_validator.py +++ /dev/null @@ -1,92 +0,0 @@ -from typing import Union, cast - -from graphon.model_runtime.entities.provider_entities import CredentialFormSchema, FormType - - -class CommonValidator: - def _validate_and_filter_credential_form_schemas( - self, credential_form_schemas: list[CredentialFormSchema], credentials: dict - ): - need_validate_credential_form_schema_map = {} - for credential_form_schema in credential_form_schemas: - if not credential_form_schema.show_on: - need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema - continue - - all_show_on_match = True - for show_on_object in credential_form_schema.show_on: - if show_on_object.variable not in credentials: - all_show_on_match = False - break - - if credentials[show_on_object.variable] != show_on_object.value: - all_show_on_match = False - break - - if all_show_on_match: - need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema - - # Iterate over the remaining credential_form_schemas, verify each credential_form_schema - validated_credentials = {} - for credential_form_schema in need_validate_credential_form_schema_map.values(): - # add the value of the credential_form_schema corresponding to it to validated_credentials - result = self._validate_credential_form_schema(credential_form_schema, credentials) - if result: - validated_credentials[credential_form_schema.variable] = result - - return validated_credentials - - def _validate_credential_form_schema( - self, credential_form_schema: CredentialFormSchema, credentials: dict - ) -> Union[str, bool, None]: - """ - Validate credential form schema - - :param credential_form_schema: credential form schema - :param credentials: credentials - :return: validated credential form schema value - """ - # If the variable does not exist in credentials - value: Union[str, bool, None] = None - if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: - # If required is True, an exception is thrown - if credential_form_schema.required: - raise ValueError(f"Variable {credential_form_schema.variable} is required") - else: - # Get the value of default - if credential_form_schema.default: - # If it exists, add it to validated_credentials - return credential_form_schema.default - else: - # If default does not exist, skip - return None - - # Get the value corresponding to the variable from credentials - value = cast(str, credentials[credential_form_schema.variable]) - - # If max_length=0, no validation is performed - if credential_form_schema.max_length: - if len(value) > credential_form_schema.max_length: - raise ValueError( - f"Variable {credential_form_schema.variable} length should not be" - f" greater than {credential_form_schema.max_length}" - ) - - # check the type of value - if not isinstance(value, str): - raise ValueError(f"Variable {credential_form_schema.variable} should be string") - - if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}: - # If the value is in options, no validation is performed - if credential_form_schema.options: - if value not in [option.value for option in credential_form_schema.options]: - raise ValueError(f"Variable {credential_form_schema.variable} is not in options") - - if credential_form_schema.type == FormType.SWITCH: - # If the value is not in ['true', 'false'], an exception is thrown - if value.lower() not in {"true", "false"}: - raise ValueError(f"Variable {credential_form_schema.variable} should be true or false") - - value = value.lower() == "true" - - return value diff --git a/api/graphon/model_runtime/schema_validators/model_credential_schema_validator.py b/api/graphon/model_runtime/schema_validators/model_credential_schema_validator.py deleted file mode 100644 index 9e4830c1b7..0000000000 --- a/api/graphon/model_runtime/schema_validators/model_credential_schema_validator.py +++ /dev/null @@ -1,27 +0,0 @@ -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ModelCredentialSchema -from graphon.model_runtime.schema_validators.common_validator import CommonValidator - - -class ModelCredentialSchemaValidator(CommonValidator): - def __init__(self, model_type: ModelType, model_credential_schema: ModelCredentialSchema): - self.model_type = model_type - self.model_credential_schema = model_credential_schema - - def validate_and_filter(self, credentials: dict): - """ - Validate model credentials - - :param credentials: model credentials - :return: filtered credentials - """ - - if self.model_credential_schema is None: - raise ValueError("Model credential schema is None") - - # get the credential_form_schemas in provider_credential_schema - credential_form_schemas = self.model_credential_schema.credential_form_schemas - - credentials["__model_type"] = self.model_type.value - - return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials) diff --git a/api/graphon/model_runtime/schema_validators/provider_credential_schema_validator.py b/api/graphon/model_runtime/schema_validators/provider_credential_schema_validator.py deleted file mode 100644 index 05fd3ce142..0000000000 --- a/api/graphon/model_runtime/schema_validators/provider_credential_schema_validator.py +++ /dev/null @@ -1,19 +0,0 @@ -from graphon.model_runtime.entities.provider_entities import ProviderCredentialSchema -from graphon.model_runtime.schema_validators.common_validator import CommonValidator - - -class ProviderCredentialSchemaValidator(CommonValidator): - def __init__(self, provider_credential_schema: ProviderCredentialSchema): - self.provider_credential_schema = provider_credential_schema - - def validate_and_filter(self, credentials: dict): - """ - Validate provider credentials - - :param credentials: provider credentials - :return: validated provider credentials - """ - # get the credential_form_schemas in provider_credential_schema - credential_form_schemas = self.provider_credential_schema.credential_form_schemas - - return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials) diff --git a/api/graphon/model_runtime/utils/__init__.py b/api/graphon/model_runtime/utils/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/model_runtime/utils/encoders.py b/api/graphon/model_runtime/utils/encoders.py deleted file mode 100644 index 13abf74767..0000000000 --- a/api/graphon/model_runtime/utils/encoders.py +++ /dev/null @@ -1,218 +0,0 @@ -import dataclasses -import datetime -from collections import defaultdict, deque -from collections.abc import Callable, Sequence -from decimal import Decimal -from enum import Enum -from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network -from pathlib import Path, PurePath -from re import Pattern -from types import GeneratorType -from typing import Any, Literal, Union -from uuid import UUID - -from pydantic import BaseModel -from pydantic.networks import AnyUrl, NameEmail -from pydantic.types import SecretBytes, SecretStr -from pydantic_core import Url -from pydantic_extra_types.color import Color - - -def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: - return model.model_dump(mode=mode, **kwargs) - - -# Taken from Pydantic v1 as is -def isoformat(o: Union[datetime.date, datetime.time]) -> str: - return o.isoformat() - - -# Taken from Pydantic v1 as is -# TODO: pv2 should this return strings instead? -def decimal_encoder(dec_value: Decimal) -> Union[int, float]: - """ - Encodes a Decimal as int of there's no exponent, otherwise float - - This is useful when we use ConstrainedDecimal to represent Numeric(x,0) - where a integer (but not int typed) is used. Encoding this as a float - results in failed round-tripping between encode and parse. - Our Id type is a prime example of this. - - >>> decimal_encoder(Decimal("1.0")) - 1.0 - - >>> decimal_encoder(Decimal("1")) - 1 - """ - if dec_value.as_tuple().exponent >= 0: # type: ignore[operator] - return int(dec_value) - else: - return float(dec_value) - - -ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { - bytes: lambda o: o.decode(), - Color: str, - datetime.date: isoformat, - datetime.datetime: isoformat, - datetime.time: isoformat, - datetime.timedelta: lambda td: td.total_seconds(), - Decimal: decimal_encoder, - Enum: lambda o: o.value, - frozenset: list, - deque: list, - GeneratorType: list, - IPv4Address: str, - IPv4Interface: str, - IPv4Network: str, - IPv6Address: str, - IPv6Interface: str, - IPv6Network: str, - NameEmail: str, - Path: str, - Pattern: lambda o: o.pattern, - SecretBytes: str, - SecretStr: str, - set: list, - UUID: str, - Url: str, - AnyUrl: str, -} - - -def generate_encoders_by_class_tuples( - type_encoder_map: dict[Any, Callable[[Any], Any]], -) -> dict[Callable[[Any], Any], tuple[Any, ...]]: - encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple) - for type_, encoder in type_encoder_map.items(): - encoders_by_class_tuples[encoder] += (type_,) - return encoders_by_class_tuples - - -encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) - - -def jsonable_encoder( - obj: Any, - by_alias: bool = True, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - custom_encoder: dict[Any, Callable[[Any], Any]] | None = None, - excluded_key_prefixes: Sequence[str] = (), -) -> Any: - custom_encoder = custom_encoder or {} - if custom_encoder: - if type(obj) in custom_encoder: - return custom_encoder[type(obj)](obj) - else: - for encoder_type, encoder_instance in custom_encoder.items(): - if isinstance(obj, encoder_type): - return encoder_instance(obj) - if isinstance(obj, BaseModel): - obj_dict = _model_dump( - obj, - mode="json", - include=None, - exclude=None, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - exclude_defaults=exclude_defaults, - ) - if "__root__" in obj_dict: - obj_dict = obj_dict["__root__"] - return jsonable_encoder( - obj_dict, - exclude_none=exclude_none, - exclude_defaults=exclude_defaults, - excluded_key_prefixes=excluded_key_prefixes, - ) - if dataclasses.is_dataclass(obj): - # Ensure obj is a dataclass instance, not a dataclass type - if not isinstance(obj, type): - obj_dict = dataclasses.asdict(obj) - return jsonable_encoder( - obj_dict, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - excluded_key_prefixes=excluded_key_prefixes, - ) - if isinstance(obj, Enum): - return obj.value - if isinstance(obj, PurePath): - return str(obj) - if isinstance(obj, str | int | float | type(None)): - return obj - if isinstance(obj, Decimal): - return format(obj, "f") - if isinstance(obj, dict): - encoded_dict = {} - for key, value in obj.items(): - if isinstance(key, str) and any(key.startswith(prefix) for prefix in excluded_key_prefixes): - continue - if value is None and exclude_none: - continue - - encoded_key = jsonable_encoder( - key, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - excluded_key_prefixes=excluded_key_prefixes, - ) - encoded_value = jsonable_encoder( - value, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - excluded_key_prefixes=excluded_key_prefixes, - ) - encoded_dict[encoded_key] = encoded_value - return encoded_dict - if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque): - encoded_list = [] - for item in obj: - encoded_list.append( - jsonable_encoder( - item, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - excluded_key_prefixes=excluded_key_prefixes, - ) - ) - return encoded_list - - if type(obj) in ENCODERS_BY_TYPE: - return ENCODERS_BY_TYPE[type(obj)](obj) - for encoder, classes_tuple in encoders_by_class_tuples.items(): - if isinstance(obj, classes_tuple): - return encoder(obj) - - try: - data = dict(obj) # type: ignore - except Exception as e: - errors: list[Exception] = [] - errors.append(e) - try: - data = vars(obj) # type: ignore - except Exception as e: - errors.append(e) - raise ValueError(str(errors)) from e - return jsonable_encoder( - data, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - excluded_key_prefixes=excluded_key_prefixes, - ) diff --git a/api/graphon/node_events/__init__.py b/api/graphon/node_events/__init__.py deleted file mode 100644 index a2bbf9f176..0000000000 --- a/api/graphon/node_events/__init__.py +++ /dev/null @@ -1,48 +0,0 @@ -from .agent import AgentLogEvent -from .base import NodeEventBase, NodeRunResult -from .iteration import ( - IterationFailedEvent, - IterationNextEvent, - IterationStartedEvent, - IterationSucceededEvent, -) -from .loop import ( - LoopFailedEvent, - LoopNextEvent, - LoopStartedEvent, - LoopSucceededEvent, -) -from .node import ( - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, - ModelInvokeCompletedEvent, - PauseRequestedEvent, - RunRetrieverResourceEvent, - RunRetryEvent, - StreamChunkEvent, - StreamCompletedEvent, - VariableUpdatedEvent, -) - -__all__ = [ - "AgentLogEvent", - "HumanInputFormFilledEvent", - "HumanInputFormTimeoutEvent", - "IterationFailedEvent", - "IterationNextEvent", - "IterationStartedEvent", - "IterationSucceededEvent", - "LoopFailedEvent", - "LoopNextEvent", - "LoopStartedEvent", - "LoopSucceededEvent", - "ModelInvokeCompletedEvent", - "NodeEventBase", - "NodeRunResult", - "PauseRequestedEvent", - "RunRetrieverResourceEvent", - "RunRetryEvent", - "StreamChunkEvent", - "StreamCompletedEvent", - "VariableUpdatedEvent", -] diff --git a/api/graphon/node_events/agent.py b/api/graphon/node_events/agent.py deleted file mode 100644 index bf295ec774..0000000000 --- a/api/graphon/node_events/agent.py +++ /dev/null @@ -1,18 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import Field - -from .base import NodeEventBase - - -class AgentLogEvent(NodeEventBase): - message_id: str = Field(..., description="id") - label: str = Field(..., description="label") - node_execution_id: str = Field(..., description="node execution id") - parent_id: str | None = Field(..., description="parent id") - error: str | None = Field(..., description="error") - status: str = Field(..., description="status") - data: Mapping[str, Any] = Field(..., description="data") - metadata: Mapping[str, Any] = Field(default_factory=dict, description="metadata") - node_id: str = Field(..., description="node id") diff --git a/api/graphon/node_events/base.py b/api/graphon/node_events/base.py deleted file mode 100644 index dcd1672428..0000000000 --- a/api/graphon/node_events/base.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import BaseModel, Field - -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMUsage - - -class NodeEventBase(BaseModel): - """Base class for all node events""" - - pass - - -def _default_metadata(): - v: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} - return v - - -class NodeRunResult(BaseModel): - """ - Node Run Result. - """ - - status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.PENDING - - inputs: Mapping[str, Any] = Field(default_factory=dict) - process_data: Mapping[str, Any] = Field(default_factory=dict) - outputs: Mapping[str, Any] = Field(default_factory=dict) - metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = Field(default_factory=_default_metadata) - llm_usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage) - - edge_source_handle: str = "source" # source handle id of node with multiple branches - - error: str = "" - error_type: str = "" - - # single step node run retry - retry_index: int = 0 diff --git a/api/graphon/node_events/iteration.py b/api/graphon/node_events/iteration.py deleted file mode 100644 index 744ddea628..0000000000 --- a/api/graphon/node_events/iteration.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import NodeEventBase - - -class IterationStartedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class IterationNextEvent(NodeEventBase): - index: int = Field(..., description="index") - pre_iteration_output: Any = None - - -class IterationSucceededEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class IterationFailedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/graphon/node_events/loop.py b/api/graphon/node_events/loop.py deleted file mode 100644 index 3ae230f9f6..0000000000 --- a/api/graphon/node_events/loop.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import NodeEventBase - - -class LoopStartedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class LoopNextEvent(NodeEventBase): - index: int = Field(..., description="index") - pre_loop_output: Any = None - - -class LoopSucceededEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class LoopFailedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/graphon/node_events/node.py b/api/graphon/node_events/node.py deleted file mode 100644 index 17f1494cf2..0000000000 --- a/api/graphon/node_events/node.py +++ /dev/null @@ -1,72 +0,0 @@ -from collections.abc import Mapping, Sequence -from datetime import datetime -from typing import Any - -from pydantic import Field - -from graphon.entities.pause_reason import PauseReason -from graphon.file import File -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import NodeRunResult -from graphon.variables.variables import Variable - -from .base import NodeEventBase - - -class RunRetrieverResourceEvent(NodeEventBase): - retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources") - context: str = Field(..., description="context") - context_files: list[File] | None = Field(default=None, description="context files") - - -class ModelInvokeCompletedEvent(NodeEventBase): - text: str - usage: LLMUsage - finish_reason: str | None = None - reasoning_content: str | None = None - structured_output: dict | None = None - - -class RunRetryEvent(NodeEventBase): - error: str = Field(..., description="error") - retry_index: int = Field(..., description="Retry attempt number") - start_at: datetime = Field(..., description="Retry start time") - - -class StreamChunkEvent(NodeEventBase): - # Spec-compliant fields - selector: Sequence[str] = Field( - ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])" - ) - chunk: str = Field(..., description="the actual chunk content") - is_final: bool = Field(default=False, description="indicates if this is the last chunk") - - -class StreamCompletedEvent(NodeEventBase): - node_run_result: NodeRunResult = Field(..., description="run result") - - -class VariableUpdatedEvent(NodeEventBase): - """Notify the engine that a single variable should be applied to the shared pool.""" - - variable: Variable = Field(..., description="Updated variable payload to apply.") - - -class PauseRequestedEvent(NodeEventBase): - reason: PauseReason = Field(..., description="pause reason") - - -class HumanInputFormFilledEvent(NodeEventBase): - """Event emitted when a human input form is submitted.""" - - node_title: str - rendered_content: str - action_id: str - action_text: str - - -class HumanInputFormTimeoutEvent(NodeEventBase): - """Event emitted when a human input form times out.""" - - node_title: str - expiration_time: datetime diff --git a/api/graphon/nodes/__init__.py b/api/graphon/nodes/__init__.py deleted file mode 100644 index 2d376d104d..0000000000 --- a/api/graphon/nodes/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from graphon.enums import BuiltinNodeTypes - -__all__ = ["BuiltinNodeTypes"] diff --git a/api/graphon/nodes/answer/__init__.py b/api/graphon/nodes/answer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/nodes/answer/answer_node.py b/api/graphon/nodes/answer/answer_node.py deleted file mode 100644 index c5261a7939..0000000000 --- a/api/graphon/nodes/answer/answer_node.py +++ /dev/null @@ -1,70 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.answer.entities import AnswerNodeData -from graphon.nodes.base.node import Node -from graphon.nodes.base.template import Template -from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from graphon.variables import ArrayFileSegment, FileSegment, Segment - - -class AnswerNode(Node[AnswerNodeData]): - node_type = BuiltinNodeTypes.ANSWER - execution_type = NodeExecutionType.RESPONSE - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - segments = self.graph_runtime_state.variable_pool.convert_template(self.node_data.answer) - files = self._extract_files_from_segments(segments.value) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"answer": segments.markdown, "files": ArrayFileSegment(value=files)}, - ) - - def _extract_files_from_segments(self, segments: Sequence[Segment]): - """Extract all files from segments containing FileSegment or ArrayFileSegment instances. - - FileSegment contains a single file, while ArrayFileSegment contains multiple files. - This method flattens all files into a single list. - """ - files = [] - for segment in segments: - if isinstance(segment, FileSegment): - # Single file - wrap in list for consistency - files.append(segment.value) - elif isinstance(segment, ArrayFileSegment): - # Multiple files - extend the list - files.extend(segment.value) - return files - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: AnswerNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - variable_template_parser = VariableTemplateParser(template=node_data.answer) - variable_selectors = variable_template_parser.extract_variable_selectors() - - variable_mapping = {} - for variable_selector in variable_selectors: - variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector - - return variable_mapping - - def get_streaming_template(self) -> Template: - """ - Get the template for streaming. - - Returns: - Template instance for this Answer node - """ - return Template.from_answer_template(self.node_data.answer) diff --git a/api/graphon/nodes/answer/entities.py b/api/graphon/nodes/answer/entities.py deleted file mode 100644 index c49f1f3895..0000000000 --- a/api/graphon/nodes/answer/entities.py +++ /dev/null @@ -1,67 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum, auto - -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - - -class AnswerNodeData(BaseNodeData): - """ - Answer Node Data. - """ - - type: NodeType = BuiltinNodeTypes.ANSWER - answer: str = Field(..., description="answer template string") - - -class GenerateRouteChunk(BaseModel): - """ - Generate Route Chunk. - """ - - class ChunkType(StrEnum): - VAR = auto() - TEXT = auto() - - type: ChunkType = Field(..., description="generate route chunk type") - - -class VarGenerateRouteChunk(GenerateRouteChunk): - """ - Var Generate Route Chunk. - """ - - type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR - """generate route chunk type""" - value_selector: Sequence[str] = Field(..., description="value selector") - - -class TextGenerateRouteChunk(GenerateRouteChunk): - """ - Text Generate Route Chunk. - """ - - type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT - """generate route chunk type""" - text: str = Field(..., description="text") - - -class AnswerNodeDoubleLink(BaseModel): - node_id: str = Field(..., description="node id") - source_node_ids: list[str] = Field(..., description="source node ids") - target_node_ids: list[str] = Field(..., description="target node ids") - - -class AnswerStreamGenerateRoute(BaseModel): - """ - AnswerStreamGenerateRoute entity - """ - - answer_dependencies: dict[str, list[str]] = Field( - ..., description="answer dependencies (answer node id -> dependent answer node ids)" - ) - answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( - ..., description="answer generate route (answer node id -> generate route chunks)" - ) diff --git a/api/graphon/nodes/base/__init__.py b/api/graphon/nodes/base/__init__.py deleted file mode 100644 index 036e25895d..0000000000 --- a/api/graphon/nodes/base/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState -from .usage_tracking_mixin import LLMUsageTrackingMixin - -__all__ = [ - "BaseIterationNodeData", - "BaseIterationState", - "BaseLoopNodeData", - "BaseLoopState", - "LLMUsageTrackingMixin", -] diff --git a/api/graphon/nodes/base/entities.py b/api/graphon/nodes/base/entities.py deleted file mode 100644 index 94b88c097d..0000000000 --- a/api/graphon/nodes/base/entities.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from enum import StrEnum -from typing import Any - -from pydantic import BaseModel, field_validator - -from graphon.entities.base_node_data import BaseNodeData - - -class VariableSelector(BaseModel): - """ - Variable Selector. - """ - - variable: str - value_selector: Sequence[str] - - -class OutputVariableType(StrEnum): - STRING = "string" - NUMBER = "number" - INTEGER = "integer" - SECRET = "secret" - BOOLEAN = "boolean" - OBJECT = "object" - FILE = "file" - ARRAY = "array" - ARRAY_STRING = "array[string]" - ARRAY_NUMBER = "array[number]" - ARRAY_OBJECT = "array[object]" - ARRAY_BOOLEAN = "array[boolean]" - ARRAY_FILE = "array[file]" - ANY = "any" - ARRAY_ANY = "array[any]" - - -class OutputVariableEntity(BaseModel): - """ - Output Variable Entity. - """ - - variable: str - value_type: OutputVariableType = OutputVariableType.ANY - value_selector: Sequence[str] - - @field_validator("value_type", mode="before") - @classmethod - def normalize_value_type(cls, v: Any) -> Any: - """ - Normalize value_type to handle case-insensitive array types. - Converts 'Array[...]' to 'array[...]' for backward compatibility. - """ - if isinstance(v, str) and v.startswith("Array["): - return v.lower() - return v - - -class BaseIterationNodeData(BaseNodeData): - start_node_id: str | None = None - - -class BaseIterationState(BaseModel): - iteration_node_id: str - index: int - inputs: dict - - class MetaData(BaseModel): - pass - - metadata: MetaData - - -class BaseLoopNodeData(BaseNodeData): - start_node_id: str | None = None - - -class BaseLoopState(BaseModel): - loop_node_id: str - index: int - inputs: dict - - class MetaData(BaseModel): - pass - - metadata: MetaData diff --git a/api/graphon/nodes/base/node.py b/api/graphon/nodes/base/node.py deleted file mode 100644 index 613ff4f037..0000000000 --- a/api/graphon/nodes/base/node.py +++ /dev/null @@ -1,787 +0,0 @@ -from __future__ import annotations - -import logging -import operator -from abc import abstractmethod -from collections.abc import Generator, Mapping, Sequence -from datetime import UTC, datetime -from functools import singledispatchmethod -from types import MappingProxyType -from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin -from uuid import uuid4 - -from graphon.entities import GraphInitParams -from graphon.entities.base_node_data import BaseNodeData, RetryConfig -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ( - ErrorStrategy, - NodeExecutionType, - NodeState, - NodeType, - WorkflowNodeExecutionStatus, -) -from graphon.graph_events import ( - GraphNodeEventBase, - NodeRunAgentLogEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunRetrieverResourceEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, -) -from graphon.node_events import ( - AgentLogEvent, - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, - IterationFailedEvent, - IterationNextEvent, - IterationStartedEvent, - IterationSucceededEvent, - LoopFailedEvent, - LoopNextEvent, - LoopStartedEvent, - LoopSucceededEvent, - NodeEventBase, - NodeRunResult, - PauseRequestedEvent, - RunRetrieverResourceEvent, - StreamChunkEvent, - StreamCompletedEvent, - VariableUpdatedEvent, -) -from graphon.runtime import GraphRuntimeState - -NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData) -_MISSING_RUN_CONTEXT_VALUE = object() - -logger = logging.getLogger(__name__) - - -class Node(Generic[NodeDataT]): - """BaseNode serves as the foundational class for all node implementations. - - Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output` - attribute to track files generated by the LLM). However, these states are not persisted - when the workflow is suspended or resumed. If a node needs its state to be preserved - across workflow suspension and resumption, it should include the relevant state data - in its output. - """ - - node_type: ClassVar[NodeType] - execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE - _node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData - - def __init_subclass__(cls, **kwargs: Any) -> None: - """ - Automatically extract and validate the node data type from the generic parameter. - - When a subclass is defined as `class MyNode(Node[MyNodeData])`, this method: - 1. Inspects `__orig_bases__` to find the `Node[T]` parameterization - 2. Extracts `T` (e.g., `MyNodeData`) from the generic argument - 3. Validates that `T` is a proper `BaseNodeData` subclass - 4. Stores it in `_node_data_type` for automatic hydration in `__init__` - - This eliminates the need for subclasses to manually implement boilerplate - accessor methods like `_get_title()`, `_get_error_strategy()`, etc. - - How it works: - :: - - class CodeNode(Node[CodeNodeData]): - │ │ - │ └─────────────────────────────────┐ - │ │ - ▼ ▼ - ┌─────────────────────────────┐ ┌─────────────────────────────────┐ - │ __orig_bases__ = ( │ │ CodeNodeData(BaseNodeData) │ - │ Node[CodeNodeData], │ │ title: str │ - │ ) │ │ desc: str | None │ - └──────────────┬──────────────┘ │ ... │ - │ └─────────────────────────────────┘ - ▼ ▲ - ┌─────────────────────────────┐ │ - │ get_origin(base) -> Node │ │ - │ get_args(base) -> ( │ │ - │ CodeNodeData, │ ──────────────────────┘ - │ ) │ - └──────────────┬──────────────┘ - │ - ▼ - ┌─────────────────────────────┐ - │ Validate: │ - │ - Is it a type? │ - │ - Is it a BaseNodeData │ - │ subclass? │ - └──────────────┬──────────────┘ - │ - ▼ - ┌─────────────────────────────┐ - │ cls._node_data_type = │ - │ CodeNodeData │ - └─────────────────────────────┘ - - Later, in __init__: - :: - - config["data"] ──► _node_data_type.model_validate(..., from_attributes=True) - │ - ▼ - CodeNodeData instance - (stored in self._node_data) - - Example: - class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted - node_type = BuiltinNodeTypes.CODE - # No need to implement _get_title, _get_error_strategy, etc. - """ - super().__init_subclass__(**kwargs) - - if cls is Node: - return - - node_data_type = cls._extract_node_data_type_from_generic() - - if node_data_type is None: - raise TypeError(f"{cls.__name__} must inherit from Node[T] with a BaseNodeData subtype") - - cls._node_data_type = node_data_type - - # Skip base class itself - if cls is Node: - return - # Only treat nodes from the base graphon package as production - # registrations. Higher-layer packages may still register subclasses, - # but graphon itself should not know their module identities. - # This prevents test helper subclasses from polluting the global registry and - # accidentally overriding real node types (e.g., a test Answer node). - module_name = getattr(cls, "__module__", "") - # Only register concrete subclasses that define node_type and version() - node_type = cls.node_type - version = cls.version() - bucket = Node._registry.setdefault(node_type, {}) - if module_name.startswith("graphon.nodes."): - # Production node definitions take precedence and may override - bucket[version] = cls # type: ignore[index] - else: - # External/test subclasses may register but must not override production - bucket.setdefault(version, cls) # type: ignore[index] - # Maintain a "latest" pointer preferring numeric versions; fallback to lexicographic - version_keys = [v for v in bucket if v != "latest"] - numeric_pairs: list[tuple[str, int]] = [] - for v in version_keys: - numeric_pairs.append((v, int(v))) - if numeric_pairs: - latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0] - else: - latest_key = max(version_keys) if version_keys else version - bucket["latest"] = bucket[latest_key] - Node._registry_version += 1 - - @classmethod - def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None: - """ - Extract the node data type from the generic parameter `Node[T]`. - - Inspects `__orig_bases__` to find the `Node[T]` parameterization and extracts `T`. - - Returns: - The extracted BaseNodeData subtype, or None if not found. - - Raises: - TypeError: If the generic argument is invalid (not exactly one argument, - or not a BaseNodeData subtype). - """ - # __orig_bases__ contains the original generic bases before type erasure. - # For `class CodeNode(Node[CodeNodeData])`, this would be `(Node[CodeNodeData],)`. - for base in getattr(cls, "__orig_bases__", ()): # type: ignore[attr-defined] - origin = get_origin(base) # Returns `Node` for `Node[CodeNodeData]` - if origin is Node: - args = get_args(base) # Returns `(CodeNodeData,)` for `Node[CodeNodeData]` - if len(args) != 1: - raise TypeError(f"{cls.__name__} must specify exactly one node data generic argument") - - candidate = args[0] - if not isinstance(candidate, type) or not issubclass(candidate, BaseNodeData): - raise TypeError(f"{cls.__name__} must parameterize Node with a BaseNodeData subtype") - - return candidate - - return None - - # Global registry populated via __init_subclass__ - _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {} - _registry_version: ClassVar[int] = 0 - - @classmethod - def get_registry_version(cls) -> int: - return cls._registry_version - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - ) -> None: - self._graph_init_params = graph_init_params - self._run_context = MappingProxyType(dict(graph_init_params.run_context)) - self.id = id - self.workflow_id = graph_init_params.workflow_id - self.graph_config = graph_init_params.graph_config - self.workflow_call_depth = graph_init_params.call_depth - self.graph_runtime_state = graph_runtime_state - self.state: NodeState = NodeState.UNKNOWN # node execution state - - node_id = config["id"] - - self._node_id = node_id - self._node_execution_id: str = "" - self._start_at = datetime.now(UTC).replace(tzinfo=None) - - self._node_data = self.validate_node_data(config["data"]) - - self.post_init() - - @classmethod - def validate_node_data(cls, node_data: BaseNodeData | Mapping[str, Any]) -> NodeDataT: - """Validate shared graph node payloads against the subclass-declared NodeData model. - - Re-validate from a dumped payload instead of `from_attributes=True` so compatibility - extras stored on `BaseNodeData` survive the handoff to the concrete node data model. - Human Input delivery methods are one such extra field until graphon owns that schema. - """ - if isinstance(node_data, BaseNodeData): - payload = node_data.model_dump(mode="python") - else: - payload = dict(node_data) - return cast(NodeDataT, cls._node_data_type.model_validate(payload)) - - def init_node_data(self, data: BaseNodeData | Mapping[str, Any]) -> None: - """Hydrate `_node_data` for legacy callers that bypass `__init__`.""" - self._node_data = self.validate_node_data(cast(BaseNodeData, data)) - - def post_init(self) -> None: - """Optional hook for subclasses requiring extra initialization.""" - return - - @property - def graph_init_params(self) -> GraphInitParams: - return self._graph_init_params - - @property - def run_context(self) -> Mapping[str, Any]: - return self._run_context - - def get_run_context_value(self, key: str, default: Any = None) -> Any: - return self._run_context.get(key, default) - - def require_run_context_value(self, key: str) -> Any: - value = self.get_run_context_value(key, _MISSING_RUN_CONTEXT_VALUE) - if value is _MISSING_RUN_CONTEXT_VALUE: - raise ValueError(f"run_context missing required key: {key}") - return value - - @property - def execution_id(self) -> str: - return self._node_execution_id - - def ensure_execution_id(self) -> str: - if self._node_execution_id: - return self._node_execution_id - - resumed_execution_id = self._restore_execution_id_from_runtime_state() - if resumed_execution_id: - self._node_execution_id = resumed_execution_id - return self._node_execution_id - - self._node_execution_id = str(uuid4()) - return self._node_execution_id - - def _restore_execution_id_from_runtime_state(self) -> str | None: - graph_execution = self.graph_runtime_state.graph_execution - try: - node_executions = graph_execution.node_executions - except AttributeError: - return None - if not isinstance(node_executions, dict): - return None - node_execution = node_executions.get(self._node_id) - if node_execution is None: - return None - execution_id = node_execution.execution_id - if not execution_id: - return None - return str(execution_id) - - @abstractmethod - def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: - """ - Run node - :return: - """ - raise NotImplementedError - - def populate_start_event(self, event: NodeRunStartedEvent) -> None: - """Allow subclasses to enrich the started event without cross-node imports in the base class.""" - _ = event - - def run(self) -> Generator[GraphNodeEventBase, None, None]: - execution_id = self.ensure_execution_id() - self._start_at = datetime.now(UTC).replace(tzinfo=None) - - # Create and push start event with required fields - start_event = NodeRunStartedEvent( - id=execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.title, - in_iteration_id=None, - start_at=self._start_at, - ) - try: - self.populate_start_event(start_event) - except Exception: - logger.warning("Failed to populate start event for node %s", self._node_id, exc_info=True) - yield start_event - - try: - result = self._run() - - # Handle NodeRunResult - if isinstance(result, NodeRunResult): - yield self._convert_node_run_result_to_graph_node_event(result) - return - - # Handle event stream - for event in result: - # NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase - if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance] - yield self._dispatch(event) - elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance] - event.id = self.execution_id - yield event - else: - yield event - except Exception as e: - logger.exception("Node %s failed to run", self._node_id) - result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - error_type="WorkflowNodeError", - ) - finished_at = datetime.now(UTC).replace(tzinfo=None) - yield NodeRunFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=result, - error=str(e), - ) - - @classmethod - def extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - config: NodeConfigDict, - ) -> Mapping[str, Sequence[str]]: - """Extracts references variable selectors from node configuration. - - The `config` parameter represents the configuration for a specific node type and corresponds - to the `data` field in the node definition object. - - The returned mapping has the following structure: - - {'1747829548239.#1747829667553.result#': ['1747829667553', 'result']} - - For loop and iteration nodes, the mapping may look like this: - - { - "1748332301644.input_selector": ["1748332363630", "result"], - "1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"], - } - - where `1748332301644` is the ID of the loop / iteration node, - and `1748332325079` is the ID of the node inside the loop or iteration node. - - Here, the key consists of two parts: the current node ID (provided as the `node_id` - parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector, - enclosed in `#` symbols. These two parts are separated by a dot (`.`). - - The value is a list of string representing the variable selector, where the first element is the node ID - of the referenced variable, and the second element is the variable name within that node. - - The meaning of the above response is: - - The node with ID `1747829548239` references the variable `result` from the node with - ID `1747829667553`. For example, if `1747829548239` is a LLM node, its prompt may contain a - reference to the `result` output variable of node `1747829667553`. - - :param graph_config: graph config - :param config: node config - :return: - """ - node_id = config["id"] - node_data = cls.validate_node_data(config["data"]) - data = cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, - node_id=node_id, - node_data=node_data, - ) - return data - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: NodeDataT, - ) -> Mapping[str, Sequence[str]]: - return {} - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: - """ - Check if this node blocks the output of specific variables. - - This method is used to determine if a node must complete execution before - the specified variables can be used in streaming output. - - :param variable_selectors: Set of variable selectors, each as a tuple (e.g., ('conversation', 'str')) - :return: True if this node blocks output of any of the specified variables, False otherwise - """ - return False - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return {} - - @classmethod - @abstractmethod - def version(cls) -> str: - """`node_version` returns the version of current node type.""" - # NOTE(QuantumGhost): Node versions must remain unique per `NodeType` so - # registry lookups can resolve numeric versions and `latest`. - raise NotImplementedError("subclasses of BaseNode must implement `version` method.") - - @classmethod - def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]: - """Return a read-only view of the currently registered node classes. - - This accessor intentionally performs no imports. The embedding layer that - owns bootstrap (for example `core.workflow.node_factory`) must import any - extension node packages before calling it so their subclasses register via - `__init_subclass__`. - """ - return {node_type: MappingProxyType(version_map) for node_type, version_map in cls._registry.items()} - - @property - def retry(self) -> bool: - return False - - def _get_error_strategy(self) -> ErrorStrategy | None: - """Get the error strategy for this node.""" - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - """Get the retry configuration for this node.""" - return self._node_data.retry_config - - def _get_title(self) -> str: - """Get the node title.""" - return self._node_data.title - - def _get_description(self) -> str | None: - """Get the node description.""" - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - """Get the default values dictionary for this node.""" - return self._node_data.default_value_dict - - # Public interface properties that delegate to abstract methods - @property - def error_strategy(self) -> ErrorStrategy | None: - """Get the error strategy for this node.""" - return self._get_error_strategy() - - @property - def retry_config(self) -> RetryConfig: - """Get the retry configuration for this node.""" - return self._get_retry_config() - - @property - def title(self) -> str: - """Get the node title.""" - return self._get_title() - - @property - def description(self) -> str | None: - """Get the node description.""" - return self._get_description() - - @property - def default_value_dict(self) -> dict[str, Any]: - """Get the default values dictionary for this node.""" - return self._get_default_value_dict() - - @property - def node_data(self) -> NodeDataT: - """Typed access to this node's configuration data.""" - return self._node_data - - def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase: - finished_at = datetime.now(UTC).replace(tzinfo=None) - match result.status: - case WorkflowNodeExecutionStatus.FAILED: - return NodeRunFailedEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=result, - error=result.error, - ) - case WorkflowNodeExecutionStatus.SUCCEEDED: - return NodeRunSucceededEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=result, - ) - case _: - raise Exception(f"result status {result.status} not supported") - - @singledispatchmethod - def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase: - raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}") - - @_dispatch.register - def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent: - return NodeRunStreamChunkEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - selector=event.selector, - chunk=event.chunk, - is_final=event.is_final, - ) - - @_dispatch.register - def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent: - finished_at = datetime.now(UTC).replace(tzinfo=None) - match event.node_run_result.status: - case WorkflowNodeExecutionStatus.SUCCEEDED: - return NodeRunSucceededEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=event.node_run_result, - ) - case WorkflowNodeExecutionStatus.FAILED: - return NodeRunFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=event.node_run_result, - error=event.node_run_result.error, - ) - case _: - raise NotImplementedError( - f"Node {self._node_id} does not support status {event.node_run_result.status}" - ) - - @_dispatch.register - def _(self, event: VariableUpdatedEvent) -> NodeRunVariableUpdatedEvent: - return NodeRunVariableUpdatedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - variable=event.variable, - ) - - @_dispatch.register - def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent: - return NodeRunPauseRequestedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED), - reason=event.reason, - ) - - @_dispatch.register - def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent: - return NodeRunAgentLogEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - message_id=event.message_id, - label=event.label, - node_execution_id=event.node_execution_id, - parent_id=event.parent_id, - error=event.error, - status=event.status, - data=event.data, - metadata=event.metadata, - ) - - @_dispatch.register - def _(self, event: HumanInputFormFilledEvent): - return NodeRunHumanInputFormFilledEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=event.node_title, - rendered_content=event.rendered_content, - action_id=event.action_id, - action_text=event.action_text, - ) - - @_dispatch.register - def _(self, event: HumanInputFormTimeoutEvent): - return NodeRunHumanInputFormTimeoutEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=event.node_title, - expiration_time=event.expiration_time, - ) - - @_dispatch.register - def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent: - return NodeRunLoopStartedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - metadata=event.metadata, - predecessor_node_id=event.predecessor_node_id, - ) - - @_dispatch.register - def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent: - return NodeRunLoopNextEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - index=event.index, - pre_loop_output=event.pre_loop_output, - ) - - @_dispatch.register - def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent: - return NodeRunLoopSucceededEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - ) - - @_dispatch.register - def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent: - return NodeRunLoopFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - error=event.error, - ) - - @_dispatch.register - def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent: - return NodeRunIterationStartedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - metadata=event.metadata, - predecessor_node_id=event.predecessor_node_id, - ) - - @_dispatch.register - def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent: - return NodeRunIterationNextEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - index=event.index, - pre_iteration_output=event.pre_iteration_output, - ) - - @_dispatch.register - def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent: - return NodeRunIterationSucceededEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - ) - - @_dispatch.register - def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent: - return NodeRunIterationFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - error=event.error, - ) - - @_dispatch.register - def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: - return NodeRunRetrieverResourceEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - retriever_resources=event.retriever_resources, - context=event.context, - node_version=self.version(), - ) diff --git a/api/graphon/nodes/base/template.py b/api/graphon/nodes/base/template.py deleted file mode 100644 index 311de4a6ea..0000000000 --- a/api/graphon/nodes/base/template.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Template structures for Response nodes (Answer and End). - -This module provides a unified template structure for both Answer and End nodes, -similar to SegmentGroup but focused on template representation without values. -""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Any, Union - -from graphon.nodes.base.variable_template_parser import VariableTemplateParser - - -@dataclass(frozen=True) -class TemplateSegment(ABC): - """Base class for template segments.""" - - @abstractmethod - def __str__(self) -> str: - """String representation of the segment.""" - pass - - -@dataclass(frozen=True) -class TextSegment(TemplateSegment): - """A text segment in a template.""" - - text: str - - def __str__(self) -> str: - return self.text - - -@dataclass(frozen=True) -class VariableSegment(TemplateSegment): - """A variable reference segment in a template.""" - - selector: Sequence[str] - variable_name: str | None = None # Optional variable name for End nodes - - def __str__(self) -> str: - return "{{#" + ".".join(self.selector) + "#}}" - - -# Type alias for segments -TemplateSegmentUnion = Union[TextSegment, VariableSegment] - - -@dataclass(frozen=True) -class Template: - """Unified template structure for Response nodes. - - Similar to SegmentGroup, but represents the template structure - without variable values - only marking variable selectors. - """ - - segments: list[TemplateSegmentUnion] - - @classmethod - def from_answer_template(cls, template_str: str) -> Template: - """Create a Template from an Answer node template string. - - Example: - "Hello, {{#node1.name#}}" -> [TextSegment("Hello, "), VariableSegment(["node1", "name"])] - - Args: - template_str: The answer template string - - Returns: - Template instance - """ - parser = VariableTemplateParser(template_str) - segments: list[TemplateSegmentUnion] = [] - - # Extract variable selectors to find all variables - variable_selectors = parser.extract_variable_selectors() - var_map = {var.variable: var.value_selector for var in variable_selectors} - - # Parse template to get ordered segments - # We need to split the template by variable placeholders while preserving order - import re - - # Create a regex pattern that matches variable placeholders - pattern = r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}" - - # Split template while keeping the delimiters (variable placeholders) - parts = re.split(pattern, template_str) - - for i, part in enumerate(parts): - if not part: - continue - - # Check if this part is a variable reference (odd indices after split) - if i % 2 == 1: # Odd indices are variable keys - # Remove the # symbols from the variable key - var_key = part - if var_key in var_map: - segments.append(VariableSegment(selector=list(var_map[var_key]))) - else: - # This shouldn't happen with valid templates - segments.append(TextSegment(text="{{" + part + "}}")) - else: - # Even indices are text segments - segments.append(TextSegment(text=part)) - - return cls(segments=segments) - - @classmethod - def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template: - """Create a Template from an End node outputs configuration. - - End nodes are treated as templates of concatenated variables with newlines. - - Example: - [{"variable": "text", "value_selector": ["node1", "text"]}, - {"variable": "result", "value_selector": ["node2", "result"]}] - -> - [VariableSegment(["node1", "text"]), - TextSegment("\n"), - VariableSegment(["node2", "result"])] - - Args: - outputs_config: List of output configurations with variable and value_selector - - Returns: - Template instance - """ - segments: list[TemplateSegmentUnion] = [] - - for i, output in enumerate(outputs_config): - if i > 0: - # Add newline separator between variables - segments.append(TextSegment(text="\n")) - - value_selector = output.get("value_selector", []) - variable_name = output.get("variable", "") - if value_selector: - segments.append(VariableSegment(selector=list(value_selector), variable_name=variable_name)) - - if len(segments) > 0 and isinstance(segments[-1], TextSegment): - segments = segments[:-1] - - return cls(segments=segments) - - def __str__(self) -> str: - """String representation of the template.""" - return "".join(str(segment) for segment in self.segments) diff --git a/api/graphon/nodes/base/usage_tracking_mixin.py b/api/graphon/nodes/base/usage_tracking_mixin.py deleted file mode 100644 index 955bfe6726..0000000000 --- a/api/graphon/nodes/base/usage_tracking_mixin.py +++ /dev/null @@ -1,28 +0,0 @@ -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState - - -class LLMUsageTrackingMixin: - """Provides shared helpers for merging and recording LLM usage within workflow nodes.""" - - graph_runtime_state: GraphRuntimeState - - @staticmethod - def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage: - """Return a combined usage snapshot, preserving zero-value inputs.""" - if new_usage is None or new_usage.total_tokens <= 0: - return current - if current.total_tokens == 0: - return new_usage - return current.plus(new_usage) - - def _accumulate_usage(self, usage: LLMUsage) -> None: - """Push usage into the graph runtime accumulator for downstream reporting.""" - if usage.total_tokens <= 0: - return - - current_usage = self.graph_runtime_state.llm_usage - if current_usage.total_tokens == 0: - self.graph_runtime_state.llm_usage = usage.model_copy() - else: - self.graph_runtime_state.llm_usage = current_usage.plus(usage) diff --git a/api/graphon/nodes/base/variable_template_parser.py b/api/graphon/nodes/base/variable_template_parser.py deleted file mode 100644 index de5e619e8c..0000000000 --- a/api/graphon/nodes/base/variable_template_parser.py +++ /dev/null @@ -1,130 +0,0 @@ -import re -from collections.abc import Mapping, Sequence -from typing import Any - -from .entities import VariableSelector - -REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") - -SELECTOR_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") - - -def extract_selectors_from_template(template: str, /) -> Sequence[VariableSelector]: - parts = SELECTOR_PATTERN.split(template) - selectors = [] - for part in filter(lambda x: x, parts): - if "." in part and part[0] == "#" and part[-1] == "#": - selectors.append(VariableSelector(variable=f"{part}", value_selector=part[1:-1].split("."))) - return selectors - - -class VariableTemplateParser: - """ - !NOTE: Consider to use the new `segments` module instead of this class. - - A class for parsing and manipulating template variables in a string. - - Rules: - - 1. Template variables must be enclosed in `{{}}`. - 2. The template variable Key can only be: #node_id.var1.var2#. - 3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2. - - Example usage: - - template = "Hello, {{#node_id.query.name#}}! Your age is {{#node_id.query.age#}}." - parser = VariableTemplateParser(template) - - # Extract template variable keys - variable_keys = parser.extract() - print(variable_keys) - # Output: ['#node_id.query.name#', '#node_id.query.age#'] - - # Extract variable selectors - variable_selectors = parser.extract_variable_selectors() - print(variable_selectors) - # Output: [VariableSelector(variable='#node_id.query.name#', value_selector=['node_id', 'query', 'name']), - # VariableSelector(variable='#node_id.query.age#', value_selector=['node_id', 'query', 'age'])] - - # Format the template string - inputs = {'#node_id.query.name#': 'John', '#node_id.query.age#': 25}} - formatted_string = parser.format(inputs) - print(formatted_string) - # Output: "Hello, John! Your age is 25." - """ - - def __init__(self, template: str): - self.template = template - self.variable_keys = self.extract() - - def extract(self): - """ - Extracts all the template variable keys from the template string. - - Returns: - A list of template variable keys. - """ - # Regular expression to match the template rules - matches = re.findall(REGEX, self.template) - - first_group_matches = [match[0] for match in matches] - - return list(set(first_group_matches)) - - def extract_variable_selectors(self) -> list[VariableSelector]: - """ - Extracts the variable selectors from the template variable keys. - - Returns: - A list of VariableSelector objects representing the variable selectors. - """ - variable_selectors = [] - for variable_key in self.variable_keys: - remove_hash = variable_key.replace("#", "") - split_result = remove_hash.split(".") - if len(split_result) < 2: - continue - - variable_selectors.append(VariableSelector(variable=variable_key, value_selector=split_result)) - - return variable_selectors - - def format(self, inputs: Mapping[str, Any]) -> str: - """ - Formats the template string by replacing the template variables with their corresponding values. - - Args: - inputs: A dictionary containing the values for the template variables. - - Returns: - The formatted string with template variables replaced by their values. - """ - - def replacer(match): - key = match.group(1) - value = inputs.get(key, match.group(0)) # return original matched string if key not found - - if value is None: - value = "" - # convert the value to string - if isinstance(value, list | dict | bool | int | float): - value = str(value) - - # remove template variables if required - return VariableTemplateParser.remove_template_variables(value) - - prompt = re.sub(REGEX, replacer, self.template) - return re.sub(r"<\|.*?\|>", "", prompt) - - @classmethod - def remove_template_variables(cls, text: str): - """ - Removes the template variables from the given text. - - Args: - text: The text from which to remove the template variables. - - Returns: - The text with template variables removed. - """ - return re.sub(REGEX, r"{\1}", text) diff --git a/api/graphon/nodes/code/__init__.py b/api/graphon/nodes/code/__init__.py deleted file mode 100644 index 8c6dcc7fcc..0000000000 --- a/api/graphon/nodes/code/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .code_node import CodeNode - -__all__ = ["CodeNode"] diff --git a/api/graphon/nodes/code/code_node.py b/api/graphon/nodes/code/code_node.py deleted file mode 100644 index c2eea0bec1..0000000000 --- a/api/graphon/nodes/code/code_node.py +++ /dev/null @@ -1,493 +0,0 @@ -from collections.abc import Mapping, Sequence -from decimal import Decimal -from textwrap import dedent -from typing import TYPE_CHECKING, Any, Protocol, cast - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.code.entities import CodeLanguage, CodeNodeData -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.variables.segments import ArrayFileSegment -from graphon.variables.types import SegmentType - -from .exc import ( - CodeNodeError, - DepthLimitError, - OutputValidationError, -) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - - -class WorkflowCodeExecutor(Protocol): - def execute( - self, - *, - language: CodeLanguage, - code: str, - inputs: Mapping[str, Any], - ) -> Mapping[str, Any]: ... - - def is_execution_error(self, error: Exception) -> bool: ... - - -def _build_default_config(*, language: CodeLanguage, code: str) -> Mapping[str, object]: - return { - "type": "code", - "config": { - "variables": [ - {"variable": "arg1", "value_selector": []}, - {"variable": "arg2", "value_selector": []}, - ], - "code_language": language, - "code": code, - "outputs": {"result": {"type": "string", "children": None}}, - }, - } - - -_DEFAULT_CODE_BY_LANGUAGE: Mapping[CodeLanguage, str] = { - CodeLanguage.PYTHON3: dedent( - """ - def main(arg1: str, arg2: str): - return { - "result": arg1 + arg2, - } - """ - ), - CodeLanguage.JAVASCRIPT: dedent( - """ - function main({arg1, arg2}) { - return { - result: arg1 + arg2 - } - } - """ - ), -} - - -class CodeNode(Node[CodeNodeData]): - node_type = BuiltinNodeTypes.CODE - _limits: CodeNodeLimits - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - code_executor: WorkflowCodeExecutor, - code_limits: CodeNodeLimits, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._code_executor: WorkflowCodeExecutor = code_executor - self._limits = code_limits - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - """ - Get default config of node. - :param filters: filter by node config parameters. - :return: - """ - code_language = CodeLanguage.PYTHON3 - if filters: - code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3)) - - default_code = _DEFAULT_CODE_BY_LANGUAGE.get(code_language) - if default_code is None: - raise CodeNodeError(f"Unsupported code language: {code_language}") - return _build_default_config(language=code_language, code=default_code) - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - # Get code language - code_language = self.node_data.code_language - code = self.node_data.code - - # Get variables - variables = {} - for variable_selector in self.node_data.variables: - variable_name = variable_selector.variable - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if isinstance(variable, ArrayFileSegment): - variables[variable_name] = [v.to_dict() for v in variable.value] if variable.value else None - else: - variables[variable_name] = variable.to_object() if variable else None - # Run code - try: - result = self._code_executor.execute( - language=code_language, - code=code, - inputs=variables, - ) - - # Transform result - result = self._transform_result(result=result, output_schema=self.node_data.outputs) - except CodeNodeError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ - ) - except Exception as e: - if not self._code_executor.is_execution_error(e): - raise - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ - ) - - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) - - def _check_string(self, value: str | None, variable: str) -> str | None: - """ - Check string - :param value: value - :param variable: variable - :return: - """ - if value is None: - return None - - if len(value) > self._limits.max_string_length: - raise OutputValidationError( - f"The length of output variable `{variable}` must be" - f" less than {self._limits.max_string_length} characters" - ) - - return value.replace("\x00", "") - - def _check_boolean(self, value: bool | None, variable: str) -> bool | None: - if value is None: - return None - - return value - - def _check_number(self, value: int | float | None, variable: str) -> int | float | None: - """ - Check number - :param value: value - :param variable: variable - :return: - """ - if value is None: - return None - - if value > self._limits.max_number or value < self._limits.min_number: - raise OutputValidationError( - f"Output variable `{variable}` is out of range," - f" it must be between {self._limits.min_number} and {self._limits.max_number}." - ) - - if isinstance(value, float): - decimal_value = Decimal(str(value)).normalize() - precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator] - # raise error if precision is too high - if precision > self._limits.max_precision: - raise OutputValidationError( - f"Output variable `{variable}` has too high precision," - f" it must be less than {self._limits.max_precision} digits." - ) - - return value - - def _transform_result( - self, - result: Mapping[str, Any], - output_schema: dict[str, CodeNodeData.Output] | None, - prefix: str = "", - depth: int = 1, - ): - # TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes. - # Note that `_transform_result` may produce lists containing `None` values, - # which don't conform to the type requirements of `Array*Segment` classes. - if depth > self._limits.max_depth: - raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.") - - transformed_result: dict[str, Any] = {} - if output_schema is None: - # validate output thought instance type - for output_name, output_value in result.items(): - if isinstance(output_value, dict): - self._transform_result( - result=output_value, - output_schema=None, - prefix=f"{prefix}.{output_name}" if prefix else output_name, - depth=depth + 1, - ) - elif isinstance(output_value, bool): - self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name) - elif isinstance(output_value, int | float): - self._check_number( - value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name - ) - elif isinstance(output_value, str): - self._check_string( - value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name - ) - elif isinstance(output_value, list): - first_element = output_value[0] if len(output_value) > 0 else None - if first_element is not None: - if isinstance(first_element, int | float) and all( - value is None or isinstance(value, int | float) for value in output_value - ): - for i, value in enumerate(output_value): - self._check_number( - value=value, - variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", - ) - elif isinstance(first_element, str) and all( - value is None or isinstance(value, str) for value in output_value - ): - for i, value in enumerate(output_value): - self._check_string( - value=value, - variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", - ) - elif ( - isinstance(first_element, dict) - and all(value is None or isinstance(value, dict) for value in output_value) - or isinstance(first_element, list) - and all(value is None or isinstance(value, list) for value in output_value) - ): - for i, value in enumerate(output_value): - if value is not None: - self._transform_result( - result=value, - output_schema=None, - prefix=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", - depth=depth + 1, - ) - else: - raise OutputValidationError( - f"Output {prefix}.{output_name} is not a valid array." - f" make sure all elements are of the same type." - ) - elif output_value is None: - pass - else: - raise OutputValidationError(f"Output {prefix}.{output_name} is not a valid type.") - - return result - - parameters_validated = {} - for output_name, output_config in output_schema.items(): - dot = "." if prefix else "" - if output_name not in result: - raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.") - - if output_config.type == SegmentType.OBJECT: - # check if output is object - if not isinstance(result.get(output_name), dict): - if result[output_name] is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an object," - f" got {type(result.get(output_name))} instead." - ) - else: - transformed_result[output_name] = self._transform_result( - result=result[output_name], - output_schema=output_config.children, - prefix=f"{prefix}.{output_name}", - depth=depth + 1, - ) - elif output_config.type == SegmentType.NUMBER: - # check if number available - value = result.get(output_name) - if value is not None and not isinstance(value, (int, float)): - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not a number," - f" got {type(result.get(output_name))} instead." - ) - checked = self._check_number(value=value, variable=f"{prefix}{dot}{output_name}") - # If the output is a boolean and the output schema specifies a NUMBER type, - # convert the boolean value to an integer. - # - # This ensures compatibility with existing workflows that may use - # `True` and `False` as values for NUMBER type outputs. - transformed_result[output_name] = self._convert_boolean_to_int(checked) - - elif output_config.type == SegmentType.STRING: - # check if string available - value = result.get(output_name) - if value is not None and not isinstance(value, str): - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} must be a string, got {type(value).__name__} instead" - ) - transformed_result[output_name] = self._check_string( - value=value, - variable=f"{prefix}{dot}{output_name}", - ) - elif output_config.type == SegmentType.BOOLEAN: - transformed_result[output_name] = self._check_boolean( - value=result[output_name], - variable=f"{prefix}{dot}{output_name}", - ) - elif output_config.type == SegmentType.ARRAY_NUMBER: - # check if array of number available - value = result[output_name] - if not isinstance(value, list): - if value is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead." - ) - else: - if len(value) > self._limits.max_number_array_length: - raise OutputValidationError( - f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {self._limits.max_number_array_length} elements." - ) - - for i, inner_value in enumerate(value): - if not isinstance(inner_value, (int, float)): - raise OutputValidationError( - f"The element at index {i} of output variable `{prefix}{dot}{output_name}` must be" - f" a number." - ) - _ = self._check_number(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]") - transformed_result[output_name] = [ - # If the element is a boolean and the output schema specifies a `array[number]` type, - # convert the boolean value to an integer. - # - # This ensures compatibility with existing workflows that may use - # `True` and `False` as values for NUMBER type outputs. - self._convert_boolean_to_int(v) - for v in value - ] - elif output_config.type == SegmentType.ARRAY_STRING: - # check if array of string available - if not isinstance(result[output_name], list): - if result[output_name] is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array," - f" got {type(result.get(output_name))} instead." - ) - else: - if len(result[output_name]) > self._limits.max_string_array_length: - raise OutputValidationError( - f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {self._limits.max_string_array_length} elements." - ) - - transformed_result[output_name] = [ - self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") - for i, value in enumerate(result[output_name]) - ] - elif output_config.type == SegmentType.ARRAY_OBJECT: - # check if array of object available - if not isinstance(result[output_name], list): - if result[output_name] is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array," - f" got {type(result.get(output_name))} instead." - ) - else: - if len(result[output_name]) > self._limits.max_object_array_length: - raise OutputValidationError( - f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {self._limits.max_object_array_length} elements." - ) - - for i, value in enumerate(result[output_name]): - if not isinstance(value, dict): - if value is None: - pass - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name}[{i}] is not an object," - f" got {type(value)} instead at index {i}." - ) - - transformed_result[output_name] = [ - None - if value is None - else self._transform_result( - result=value, - output_schema=output_config.children, - prefix=f"{prefix}{dot}{output_name}[{i}]", - depth=depth + 1, - ) - for i, value in enumerate(result[output_name]) - ] - elif output_config.type == SegmentType.ARRAY_BOOLEAN: - # check if array of object available - value = result[output_name] - if not isinstance(value, list): - if value is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array," - f" got {type(result.get(output_name))} instead." - ) - else: - for i, inner_value in enumerate(value): - if inner_value is not None and not isinstance(inner_value, bool): - raise OutputValidationError( - f"Output {prefix}{dot}{output_name}[{i}] is not a boolean," - f" got {type(inner_value)} instead." - ) - _ = self._check_boolean(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]") - transformed_result[output_name] = value - - else: - raise OutputValidationError(f"Output type {output_config.type} is not supported.") - - parameters_validated[output_name] = True - - # check if all output parameters are validated - if len(parameters_validated) != len(result): - raise CodeNodeError("Not all output parameters are validated.") - - return transformed_result - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: CodeNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - return { - node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables - } - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled - - @staticmethod - def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None: - """This function convert boolean to integers when the output schema specifies a NUMBER type. - - This ensures compatibility with existing workflows that may use - `True` and `False` as values for NUMBER type outputs. - """ - if value is None: - return None - if isinstance(value, bool): - return int(value) - return value diff --git a/api/graphon/nodes/code/entities.py b/api/graphon/nodes/code/entities.py deleted file mode 100644 index dc89d64495..0000000000 --- a/api/graphon/nodes/code/entities.py +++ /dev/null @@ -1,57 +0,0 @@ -from enum import StrEnum -from typing import Annotated, Literal - -from pydantic import AfterValidator, BaseModel - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base.entities import VariableSelector -from graphon.variables.types import SegmentType - - -class CodeLanguage(StrEnum): - PYTHON3 = "python3" - JINJA2 = "jinja2" - JAVASCRIPT = "javascript" - - -_ALLOWED_OUTPUT_FROM_CODE = frozenset( - [ - SegmentType.STRING, - SegmentType.NUMBER, - SegmentType.OBJECT, - SegmentType.BOOLEAN, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_BOOLEAN, - ] -) - - -def _validate_type(segment_type: SegmentType) -> SegmentType: - if segment_type not in _ALLOWED_OUTPUT_FROM_CODE: - raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}") - return segment_type - - -class CodeNodeData(BaseNodeData): - """ - Code Node Data. - """ - - type: NodeType = BuiltinNodeTypes.CODE - - class Output(BaseModel): - type: Annotated[SegmentType, AfterValidator(_validate_type)] - children: dict[str, "CodeNodeData.Output"] | None = None - - class Dependency(BaseModel): - name: str - version: str - - variables: list[VariableSelector] - code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] - code: str - outputs: dict[str, Output] - dependencies: list[Dependency] | None = None diff --git a/api/graphon/nodes/code/exc.py b/api/graphon/nodes/code/exc.py deleted file mode 100644 index d6334fd554..0000000000 --- a/api/graphon/nodes/code/exc.py +++ /dev/null @@ -1,16 +0,0 @@ -class CodeNodeError(ValueError): - """Base class for code node errors.""" - - pass - - -class OutputValidationError(CodeNodeError): - """Raised when there is an output validation error.""" - - pass - - -class DepthLimitError(CodeNodeError): - """Raised when the depth limit is reached.""" - - pass diff --git a/api/graphon/nodes/code/limits.py b/api/graphon/nodes/code/limits.py deleted file mode 100644 index a6b9e9e68e..0000000000 --- a/api/graphon/nodes/code/limits.py +++ /dev/null @@ -1,13 +0,0 @@ -from dataclasses import dataclass - - -@dataclass(frozen=True) -class CodeNodeLimits: - max_string_length: int - max_number: int | float - min_number: int | float - max_precision: int - max_depth: int - max_number_array_length: int - max_string_array_length: int - max_object_array_length: int diff --git a/api/graphon/nodes/document_extractor/__init__.py b/api/graphon/nodes/document_extractor/__init__.py deleted file mode 100644 index 9922e3949d..0000000000 --- a/api/graphon/nodes/document_extractor/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .entities import DocumentExtractorNodeData, UnstructuredApiConfig -from .node import DocumentExtractorNode - -__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData", "UnstructuredApiConfig"] diff --git a/api/graphon/nodes/document_extractor/entities.py b/api/graphon/nodes/document_extractor/entities.py deleted file mode 100644 index 026a0cd224..0000000000 --- a/api/graphon/nodes/document_extractor/entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from collections.abc import Sequence -from dataclasses import dataclass - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - - -class DocumentExtractorNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.DOCUMENT_EXTRACTOR - variable_selector: Sequence[str] - - -@dataclass(frozen=True) -class UnstructuredApiConfig: - api_url: str | None = None - api_key: str = "" diff --git a/api/graphon/nodes/document_extractor/exc.py b/api/graphon/nodes/document_extractor/exc.py deleted file mode 100644 index 5caf00ebc5..0000000000 --- a/api/graphon/nodes/document_extractor/exc.py +++ /dev/null @@ -1,14 +0,0 @@ -class DocumentExtractorError(ValueError): - """Base exception for errors related to the DocumentExtractorNode.""" - - -class FileDownloadError(DocumentExtractorError): - """Exception raised when there's an error downloading a file.""" - - -class UnsupportedFileTypeError(DocumentExtractorError): - """Exception raised when trying to extract text from an unsupported file type.""" - - -class TextExtractionError(DocumentExtractorError): - """Exception raised when there's an error during text extraction from a file.""" diff --git a/api/graphon/nodes/document_extractor/node.py b/api/graphon/nodes/document_extractor/node.py deleted file mode 100644 index be46481e7d..0000000000 --- a/api/graphon/nodes/document_extractor/node.py +++ /dev/null @@ -1,782 +0,0 @@ -import csv -import io -import json -import logging -import os -import tempfile -import zipfile -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -import charset_normalizer -import docx -import pandas as pd -import pypandoc -import pypdfium2 -import webvtt -import yaml -from docx.document import Document -from docx.oxml.table import CT_Tbl -from docx.oxml.text.paragraph import CT_P -from docx.table import Table -from docx.text.paragraph import Paragraph - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, file_manager -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.protocols import HttpClientProtocol -from graphon.variables import ArrayFileSegment -from graphon.variables.segments import ArrayStringSegment, FileSegment - -from .entities import DocumentExtractorNodeData, UnstructuredApiConfig -from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - - -class DocumentExtractorNode(Node[DocumentExtractorNodeData]): - """ - Extracts text content from various file types. - Supports plain text, PDF, and DOC/DOCX files. - """ - - node_type = BuiltinNodeTypes.DOCUMENT_EXTRACTOR - - @classmethod - def version(cls) -> str: - return "1" - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - unstructured_api_config: UnstructuredApiConfig | None = None, - http_client: HttpClientProtocol, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig() - self._http_client = http_client - - def _run(self): - variable_selector = self.node_data.variable_selector - variable = self.graph_runtime_state.variable_pool.get(variable_selector) - - if variable is None: - error_message = f"File variable not found for selector: {variable_selector}" - return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) - if variable.value and not isinstance(variable, ArrayFileSegment | FileSegment): - error_message = f"Variable {variable_selector} is not an ArrayFileSegment" - return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) - - value = variable.value - inputs = {"variable_selector": variable_selector} - if isinstance(value, list): - value = list(filter(lambda x: x, value)) - process_data = {"documents": value if isinstance(value, list) else [value]} - - if not value: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={"text": ArrayStringSegment(value=[])}, - ) - - try: - if isinstance(value, list): - extracted_text_list = [ - _extract_text_from_file( - self._http_client, file, unstructured_api_config=self._unstructured_api_config - ) - for file in value - ] - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={"text": ArrayStringSegment(value=extracted_text_list)}, - ) - elif isinstance(value, File): - extracted_text = _extract_text_from_file( - self._http_client, value, unstructured_api_config=self._unstructured_api_config - ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={"text": extracted_text}, - ) - else: - raise DocumentExtractorError(f"Unsupported variable type: {type(value)}") - except DocumentExtractorError as e: - logger.warning(e, exc_info=True) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=inputs, - process_data=process_data, - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: DocumentExtractorNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - return {node_id + ".files": node_data.variable_selector} - - -def _extract_text_by_mime_type( - *, - file_content: bytes, - mime_type: str, - unstructured_api_config: UnstructuredApiConfig, -) -> str: - """Extract text from a file based on its MIME type.""" - match mime_type: - case "text/plain" | "text/html" | "text/htm" | "text/markdown" | "text/xml": - return _extract_text_from_plain_text(file_content) - case "application/pdf": - return _extract_text_from_pdf(file_content) - case "application/msword": - return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config) - case "application/vnd.openxmlformats-officedocument.wordprocessingml.document": - return _extract_text_from_docx(file_content) - case "text/csv": - return _extract_text_from_csv(file_content) - case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel": - return _extract_text_from_excel(file_content) - case "application/vnd.ms-powerpoint": - return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config) - case "application/vnd.openxmlformats-officedocument.presentationml.presentation": - return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config) - case "application/epub+zip": - return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config) - case "message/rfc822": - return _extract_text_from_eml(file_content) - case "application/vnd.ms-outlook": - return _extract_text_from_msg(file_content) - case "application/json": - return _extract_text_from_json(file_content) - case "application/x-yaml" | "text/yaml": - return _extract_text_from_yaml(file_content) - case "text/vtt": - return _extract_text_from_vtt(file_content) - case "text/properties": - return _extract_text_from_properties(file_content) - case _: - raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") - - -def _extract_text_by_file_extension( - *, - file_content: bytes, - file_extension: str, - unstructured_api_config: UnstructuredApiConfig, -) -> str: - """Extract text from a file based on its file extension.""" - match file_extension: - case ( - ".txt" - | ".markdown" - | ".md" - | ".mdx" - | ".html" - | ".htm" - | ".xml" - | ".c" - | ".h" - | ".cpp" - | ".hpp" - | ".cc" - | ".cxx" - | ".c++" - | ".py" - | ".js" - | ".ts" - | ".jsx" - | ".tsx" - | ".java" - | ".php" - | ".rb" - | ".go" - | ".rs" - | ".swift" - | ".kt" - | ".scala" - | ".sh" - | ".bash" - | ".bat" - | ".ps1" - | ".sql" - | ".r" - | ".m" - | ".pl" - | ".lua" - | ".vim" - | ".asm" - | ".s" - | ".css" - | ".scss" - | ".less" - | ".sass" - | ".ini" - | ".cfg" - | ".conf" - | ".toml" - | ".env" - | ".log" - | ".vtt" - ): - return _extract_text_from_plain_text(file_content) - case ".json": - return _extract_text_from_json(file_content) - case ".yaml" | ".yml": - return _extract_text_from_yaml(file_content) - case ".pdf": - return _extract_text_from_pdf(file_content) - case ".doc": - return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config) - case ".docx": - return _extract_text_from_docx(file_content) - case ".csv": - return _extract_text_from_csv(file_content) - case ".xls" | ".xlsx": - return _extract_text_from_excel(file_content) - case ".ppt": - return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config) - case ".pptx": - return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config) - case ".epub": - return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config) - case ".eml": - return _extract_text_from_eml(file_content) - case ".msg": - return _extract_text_from_msg(file_content) - case ".properties": - return _extract_text_from_properties(file_content) - case _: - raise UnsupportedFileTypeError(f"Unsupported Extension Type: {file_extension}") - - -def _extract_text_from_plain_text(file_content: bytes) -> str: - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content, cp_isolation=["utf_8", "latin_1", "cp1252"]).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - return file_content.decode(encoding, errors="ignore") - except (UnicodeDecodeError, LookupError) as e: - # If decoding fails, try with utf-8 as last resort - try: - return file_content.decode("utf-8", errors="ignore") - except UnicodeDecodeError: - raise TextExtractionError(f"Failed to decode plain text file: {e}") from e - - -def _extract_text_from_json(file_content: bytes) -> str: - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - json_data = json.loads(file_content.decode(encoding, errors="ignore")) - return json.dumps(json_data, indent=2, ensure_ascii=False) - except (UnicodeDecodeError, LookupError, json.JSONDecodeError) as e: - # If decoding fails, try with utf-8 as last resort - try: - json_data = json.loads(file_content.decode("utf-8", errors="ignore")) - return json.dumps(json_data, indent=2, ensure_ascii=False) - except (UnicodeDecodeError, json.JSONDecodeError): - raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e - - -def _extract_text_from_yaml(file_content: bytes) -> str: - """Extract the content from yaml file""" - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore")) - return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) - except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e: - # If decoding fails, try with utf-8 as last resort - try: - yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore")) - return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) - except (UnicodeDecodeError, yaml.YAMLError): - raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e - - -def _extract_text_from_pdf(file_content: bytes) -> str: - try: - pdf_file = io.BytesIO(file_content) - pdf_document = pypdfium2.PdfDocument(pdf_file, autoclose=True) - text = "" - for page in pdf_document: - text_page = page.get_textpage() - text += text_page.get_text_range() - text_page.close() - page.close() - return text - except Exception as e: - raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e - - -def _extract_text_from_doc(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - """ - Extract text from a DOC file. - """ - from unstructured.partition.api import partition_via_api - - if not unstructured_api_config.api_url: - raise TextExtractionError("Unstructured API URL is not configured for DOC file processing.") - api_key = unstructured_api_config.api_key or "" - - try: - with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - return "\n".join([getattr(element, "text", "") for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from DOC: {str(e)}") from e - - -def parser_docx_part(block, doc: Document, content_items, i): - if isinstance(block, CT_P): - content_items.append((i, "paragraph", Paragraph(block, doc))) - elif isinstance(block, CT_Tbl): - content_items.append((i, "table", Table(block, doc))) - - -def _normalize_docx_zip(file_content: bytes) -> bytes: - """ - Some DOCX files (e.g. exported by Evernote on Windows) are malformed: - ZIP entry names use backslash (\\) as path separator instead of the forward - slash (/) required by both the ZIP spec and OOXML. On Linux/Mac the entry - "word\\document.xml" is never found when python-docx looks for - "word/document.xml", which triggers a KeyError about a missing relationship. - - This function rewrites the ZIP in-memory, normalizing all entry names to - use forward slashes without touching any actual document content. - """ - try: - with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin: - out_buf = io.BytesIO() - with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout: - for item in zin.infolist(): - data = zin.read(item.filename) - # Normalize backslash path separators to forward slash - item.filename = item.filename.replace("\\", "/") - zout.writestr(item, data) - return out_buf.getvalue() - except zipfile.BadZipFile: - # Not a valid zip — return as-is and let python-docx report the real error - return file_content - - -def _extract_text_from_docx(file_content: bytes) -> str: - """ - Extract text from a DOCX file. - For now support only paragraph and table add more if needed - """ - try: - doc_file = io.BytesIO(file_content) - try: - doc = docx.Document(doc_file) - except Exception as e: - logger.warning("Failed to parse DOCX, attempting to normalize ZIP entry paths: %s", e) - # Some DOCX files exported by tools like Evernote on Windows use - # backslash path separators in ZIP entries and/or single-quoted XML - # attributes, both of which break python-docx on Linux. Normalize and retry. - file_content = _normalize_docx_zip(file_content) - doc = docx.Document(io.BytesIO(file_content)) - text = [] - - # Keep track of paragraph and table positions - content_items: list[tuple[int, str, Table | Paragraph]] = [] - - it = iter(doc.element.body) - part = next(it, None) - i = 0 - while part is not None: - parser_docx_part(part, doc, content_items, i) - i = i + 1 - part = next(it, None) - - # Process sorted content - for _, item_type, item in content_items: - if item_type == "paragraph": - if isinstance(item, Table): - continue - text.append(item.text) - elif item_type == "table": - # Process tables - if not isinstance(item, Table): - continue - try: - # Check if any cell in the table has text - has_content = False - for row in item.rows: - if any(cell.text.strip() for cell in row.cells): - has_content = True - break - - if has_content: - cell_texts = [cell.text.replace("\n", "
") for cell in item.rows[0].cells] - markdown_table = f"| {' | '.join(cell_texts)} |\n" - markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n" - - for row in item.rows[1:]: - # Replace newlines with
in each cell - row_cells = [cell.text.replace("\n", "
") for cell in row.cells] - markdown_table += "| " + " | ".join(row_cells) + " |\n" - - text.append(markdown_table) - except Exception as e: - logger.warning("Failed to extract table from DOC: %s", e) - continue - - return "\n".join(text) - - except Exception as e: - raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e - - -def _download_file_content(http_client: HttpClientProtocol, file: File) -> bytes: - """Download the content of a file based on its transfer method.""" - try: - if file.transfer_method == FileTransferMethod.REMOTE_URL: - if file.remote_url is None: - raise FileDownloadError("Missing URL for remote file") - response = http_client.get(file.remote_url) - response.raise_for_status() - return response.content - else: - return file_manager.download(file) - except Exception as e: - raise FileDownloadError(f"Error downloading file: {str(e)}") from e - - -def _extract_text_from_file( - http_client: HttpClientProtocol, file: File, *, unstructured_api_config: UnstructuredApiConfig -) -> str: - file_content = _download_file_content(http_client, file) - if file.extension: - extracted_text = _extract_text_by_file_extension( - file_content=file_content, - file_extension=file.extension, - unstructured_api_config=unstructured_api_config, - ) - elif file.mime_type: - extracted_text = _extract_text_by_mime_type( - file_content=file_content, - mime_type=file.mime_type, - unstructured_api_config=unstructured_api_config, - ) - else: - raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing") - return extracted_text - - -def _extract_text_from_csv(file_content: bytes) -> str: - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - try: - csv_file = io.StringIO(file_content.decode(encoding, errors="ignore")) - except (UnicodeDecodeError, LookupError): - # If decoding fails, try with utf-8 as last resort - csv_file = io.StringIO(file_content.decode("utf-8", errors="ignore")) - - csv_reader = csv.reader(csv_file) - rows = list(csv_reader) - - if not rows: - return "" - - # Combine multi-line text in the header row - header_row = [cell.replace("\n", " ").replace("\r", "") for cell in rows[0]] - - # Create Markdown table - markdown_table = "| " + " | ".join(header_row) + " |\n" - markdown_table += "| " + " | ".join(["-" * len(col) for col in rows[0]]) + " |\n" - - # Process each data row and combine multi-line text in each cell - for row in rows[1:]: - processed_row = [cell.replace("\n", " ").replace("\r", "") for cell in row] - markdown_table += "| " + " | ".join(processed_row) + " |\n" - - return markdown_table - except Exception as e: - raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e - - -def _extract_text_from_excel(file_content: bytes) -> str: - """Extract text from an Excel file using pandas.""" - - def _construct_markdown_table(df: pd.DataFrame) -> str: - """Manually construct a Markdown table from a DataFrame.""" - # Construct the header row - header_row = "| " + " | ".join(df.columns) + " |" - - # Construct the separator row - separator_row = "| " + " | ".join(["-" * len(col) for col in df.columns]) + " |" - - # Construct the data rows - data_rows = [] - for _, row in df.iterrows(): - data_row = "| " + " | ".join(map(str, row)) + " |" - data_rows.append(data_row) - - # Combine all rows into a single string - markdown_table = "\n".join([header_row, separator_row] + data_rows) - return markdown_table - - try: - excel_file = pd.ExcelFile(io.BytesIO(file_content)) - markdown_table = "" - for sheet_name in excel_file.sheet_names: - try: - df = excel_file.parse(sheet_name=sheet_name) - df.dropna(how="all", inplace=True) - - # Combine multi-line text in each cell into a single line - df = df.map(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) - - # Combine multi-line text in column names into a single line - df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns]) - - # Manually construct the Markdown table - markdown_table += _construct_markdown_table(df) + "\n\n" - except Exception: - continue - return markdown_table - except Exception as e: - raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e - - -def _extract_text_from_ppt(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - from unstructured.partition.api import partition_via_api - from unstructured.partition.ppt import partition_ppt - - api_key = unstructured_api_config.api_key or "" - - try: - if unstructured_api_config.api_url: - with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - else: - with io.BytesIO(file_content) as file: - elements = partition_ppt(file=file) - return "\n".join([getattr(element, "text", "") for element in elements]) - - except Exception as e: - raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e - - -def _extract_text_from_pptx(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - from unstructured.partition.api import partition_via_api - from unstructured.partition.pptx import partition_pptx - - api_key = unstructured_api_config.api_key or "" - - try: - if unstructured_api_config.api_url: - with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - else: - with io.BytesIO(file_content) as file: - elements = partition_pptx(file=file) - return "\n".join([getattr(element, "text", "") for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e - - -def _extract_text_from_epub(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - from unstructured.partition.api import partition_via_api - from unstructured.partition.epub import partition_epub - - api_key = unstructured_api_config.api_key or "" - - try: - if unstructured_api_config.api_url: - with tempfile.NamedTemporaryFile(suffix=".epub", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - else: - pypandoc.download_pandoc() - with io.BytesIO(file_content) as file: - elements = partition_epub(file=file) - return "\n".join([str(element) for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from EPUB: {str(e)}") from e - - -def _extract_text_from_eml(file_content: bytes) -> str: - from unstructured.partition.email import partition_email - - try: - with io.BytesIO(file_content) as file: - elements = partition_email(file=file) - return "\n".join([str(element) for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from EML: {str(e)}") from e - - -def _extract_text_from_msg(file_content: bytes) -> str: - from unstructured.partition.msg import partition_msg - - try: - with io.BytesIO(file_content) as file: - elements = partition_msg(file=file) - return "\n".join([str(element) for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e - - -def _extract_text_from_vtt(vtt_bytes: bytes) -> str: - text = _extract_text_from_plain_text(vtt_bytes) - - # remove bom - text = text.lstrip("\ufeff") - - raw_results = [] - for caption in webvtt.from_string(text): - raw_results.append((caption.voice, caption.text)) - - # Merge consecutive utterances by the same speaker - merged_results = [] - if raw_results: - current_speaker, current_text = raw_results[0] - - for i in range(1, len(raw_results)): - spk, txt = raw_results[i] - if spk is None: - merged_results.append((None, current_text)) - continue - - if spk == current_speaker: - # If it is the same speaker, merge the utterances (joined by space) - current_text += " " + txt - else: - # If the speaker changes, register the utterance so far and move on - merged_results.append((current_speaker, current_text)) - current_speaker, current_text = spk, txt - - # Add the last element - merged_results.append((current_speaker, current_text)) - else: - merged_results = raw_results - - # Return the result in the specified format: Speaker "text" style - formatted = [f'{spk or ""} "{txt}"' for spk, txt in merged_results] - return "\n".join(formatted) - - -def _extract_text_from_properties(file_content: bytes) -> str: - try: - text = _extract_text_from_plain_text(file_content) - lines = text.splitlines() - result = [] - for line in lines: - line = line.strip() - # Preserve comments and empty lines - if not line or line.startswith("#") or line.startswith("!"): - result.append(line) - continue - - if "=" in line: - key, value = line.split("=", 1) - elif ":" in line: - key, value = line.split(":", 1) - else: - key, value = line, "" - - result.append(f"{key.strip()}: {value.strip()}") - - return "\n".join(result) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from properties file: {str(e)}") from e diff --git a/api/graphon/nodes/end/__init__.py b/api/graphon/nodes/end/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/nodes/end/end_node.py b/api/graphon/nodes/end/end_node.py deleted file mode 100644 index 11b9e58644..0000000000 --- a/api/graphon/nodes/end/end_node.py +++ /dev/null @@ -1,47 +0,0 @@ -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.base.template import Template -from graphon.nodes.end.entities import EndNodeData - - -class EndNode(Node[EndNodeData]): - node_type = BuiltinNodeTypes.END - execution_type = NodeExecutionType.RESPONSE - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run node - collect all outputs at once. - - This method runs after streaming is complete (if streaming was enabled). - It collects all output variables and returns them. - """ - output_variables = self.node_data.outputs - - outputs = {} - for variable_selector in output_variables: - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - value = variable.to_object() if variable is not None else None - outputs[variable_selector.variable] = value - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=outputs, - outputs=outputs, - ) - - def get_streaming_template(self) -> Template: - """ - Get the template for streaming. - - Returns: - Template instance for this End node - """ - outputs_config = [ - {"variable": output.variable, "value_selector": output.value_selector} for output in self.node_data.outputs - ] - return Template.from_end_outputs(outputs_config) diff --git a/api/graphon/nodes/end/entities.py b/api/graphon/nodes/end/entities.py deleted file mode 100644 index 839aed7e4b..0000000000 --- a/api/graphon/nodes/end/entities.py +++ /dev/null @@ -1,27 +0,0 @@ -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base.entities import OutputVariableEntity - - -class EndNodeData(BaseNodeData): - """ - END Node Data. - """ - - type: NodeType = BuiltinNodeTypes.END - outputs: list[OutputVariableEntity] - - -class EndStreamParam(BaseModel): - """ - EndStreamParam entity - """ - - end_dependencies: dict[str, list[str]] = Field( - ..., description="end dependencies (end node id -> dependent node ids)" - ) - end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field( - ..., description="end stream variable selector mapping (end node id -> stream variable selectors)" - ) diff --git a/api/graphon/nodes/http_request/__init__.py b/api/graphon/nodes/http_request/__init__.py deleted file mode 100644 index b29099db23..0000000000 --- a/api/graphon/nodes/http_request/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from .config import build_http_request_config, resolve_http_request_config -from .entities import ( - HTTP_REQUEST_CONFIG_FILTER_KEY, - BodyData, - HttpRequestNodeAuthorization, - HttpRequestNodeBody, - HttpRequestNodeConfig, - HttpRequestNodeData, -) -from .node import HttpRequestNode - -__all__ = [ - "HTTP_REQUEST_CONFIG_FILTER_KEY", - "BodyData", - "HttpRequestNode", - "HttpRequestNodeAuthorization", - "HttpRequestNodeBody", - "HttpRequestNodeConfig", - "HttpRequestNodeData", - "build_http_request_config", - "resolve_http_request_config", -] diff --git a/api/graphon/nodes/http_request/config.py b/api/graphon/nodes/http_request/config.py deleted file mode 100644 index 53bf6c7ae4..0000000000 --- a/api/graphon/nodes/http_request/config.py +++ /dev/null @@ -1,33 +0,0 @@ -from collections.abc import Mapping - -from .entities import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNodeConfig - - -def build_http_request_config( - *, - max_connect_timeout: int = 10, - max_read_timeout: int = 600, - max_write_timeout: int = 600, - max_binary_size: int = 10 * 1024 * 1024, - max_text_size: int = 1 * 1024 * 1024, - ssl_verify: bool = True, - ssrf_default_max_retries: int = 3, -) -> HttpRequestNodeConfig: - return HttpRequestNodeConfig( - max_connect_timeout=max_connect_timeout, - max_read_timeout=max_read_timeout, - max_write_timeout=max_write_timeout, - max_binary_size=max_binary_size, - max_text_size=max_text_size, - ssl_verify=ssl_verify, - ssrf_default_max_retries=ssrf_default_max_retries, - ) - - -def resolve_http_request_config(filters: Mapping[str, object] | None) -> HttpRequestNodeConfig: - if not filters: - raise ValueError("http_request_config is required to build HTTP request default config") - config = filters.get(HTTP_REQUEST_CONFIG_FILTER_KEY) - if not isinstance(config, HttpRequestNodeConfig): - raise ValueError("http_request_config must be an HttpRequestNodeConfig instance") - return config diff --git a/api/graphon/nodes/http_request/entities.py b/api/graphon/nodes/http_request/entities.py deleted file mode 100644 index 6fa067bdd1..0000000000 --- a/api/graphon/nodes/http_request/entities.py +++ /dev/null @@ -1,241 +0,0 @@ -import mimetypes -from collections.abc import Sequence -from dataclasses import dataclass -from email.message import Message -from typing import Any, Literal - -import charset_normalizer -import httpx -from pydantic import BaseModel, Field, ValidationInfo, field_validator - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - -HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config" - - -class HttpRequestNodeAuthorizationConfig(BaseModel): - type: Literal["basic", "bearer", "custom"] - api_key: str - header: str = "" - - -class HttpRequestNodeAuthorization(BaseModel): - type: Literal["no-auth", "api-key"] - config: HttpRequestNodeAuthorizationConfig | None = None - - @field_validator("config", mode="before") - @classmethod - def check_config(cls, v: HttpRequestNodeAuthorizationConfig, values: ValidationInfo): - """ - Check config, if type is no-auth, config should be None, otherwise it should be a dict. - """ - if values.data["type"] == "no-auth": - return None - else: - if not v or not isinstance(v, dict): - raise ValueError("config should be a dict") - - return v - - -class BodyData(BaseModel): - key: str = "" - type: Literal["file", "text"] - value: str = "" - file: Sequence[str] = Field(default_factory=list) - - -class HttpRequestNodeBody(BaseModel): - type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json", "binary"] - data: Sequence[BodyData] = Field(default_factory=list) - - @field_validator("data", mode="before") - @classmethod - def check_data(cls, v: Any): - """For compatibility, if body is not set, return empty list.""" - if not v: - return [] - if isinstance(v, str): - return [BodyData(key="", type="text", value=v)] - return v - - -class HttpRequestNodeTimeout(BaseModel): - connect: int | None = None - read: int | None = None - write: int | None = None - - -@dataclass(frozen=True, slots=True) -class HttpRequestNodeConfig: - max_connect_timeout: int - max_read_timeout: int - max_write_timeout: int - max_binary_size: int - max_text_size: int - ssl_verify: bool - ssrf_default_max_retries: int - - def default_timeout(self) -> "HttpRequestNodeTimeout": - return HttpRequestNodeTimeout( - connect=self.max_connect_timeout, - read=self.max_read_timeout, - write=self.max_write_timeout, - ) - - -class HttpRequestNodeData(BaseNodeData): - """ - Code Node Data. - """ - - type: NodeType = BuiltinNodeTypes.HTTP_REQUEST - method: Literal[ - "get", - "post", - "put", - "patch", - "delete", - "head", - "options", - "GET", - "POST", - "PUT", - "PATCH", - "DELETE", - "HEAD", - "OPTIONS", - ] - url: str - authorization: HttpRequestNodeAuthorization - headers: str - params: str - body: HttpRequestNodeBody | None = None - timeout: HttpRequestNodeTimeout | None = None - ssl_verify: bool | None = None - - -class Response: - headers: dict[str, str] - response: httpx.Response - _cached_text: str | None - - def __init__(self, response: httpx.Response): - self.response = response - self.headers = dict(response.headers) - self._cached_text = None - - @property - def is_file(self): - """ - Determine if the response contains a file by checking: - 1. Content-Disposition header (RFC 6266) - 2. Content characteristics - 3. MIME type analysis - """ - content_type = self.content_type.split(";")[0].strip().lower() - parsed_content_disposition = self.parsed_content_disposition - - # Check if it's explicitly marked as an attachment - if parsed_content_disposition: - disp_type = parsed_content_disposition.get_content_disposition() # Returns 'attachment', 'inline', or None - filename = parsed_content_disposition.get_filename() # Returns filename if present, None otherwise - if disp_type == "attachment" or filename is not None: - return True - - # For 'text/' types, only 'csv' should be downloaded as file - if content_type.startswith("text/") and "csv" not in content_type: - return False - - # For application types, try to detect if it's a text-based format - if content_type.startswith("application/"): - # Common text-based application types - if any( - text_type in content_type - for text_type in ("json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql") - ): - return False - - # Try to detect if content is text-based by sampling first few bytes - try: - # Sample first 1024 bytes for text detection - content_sample = self.response.content[:1024] - content_sample.decode("utf-8") - # If we can decode as UTF-8 and find common text patterns, likely not a file - text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ") - if any(marker in content_sample for marker in text_markers): - return False - except UnicodeDecodeError: - # If we can't decode as UTF-8, likely a binary file - return True - - # For other types, use MIME type analysis - main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or "")) - if main_type: - return main_type.split("/")[0] in ("application", "image", "audio", "video") - - # For unknown types, check if it's a media type - return any(media_type in content_type for media_type in ("image/", "audio/", "video/")) - - @property - def content_type(self) -> str: - return self.headers.get("content-type", "") - - @property - def text(self) -> str: - """ - Get response text with robust encoding detection. - - Uses charset_normalizer for better encoding detection than httpx's default, - which helps handle Chinese and other non-ASCII characters properly. - """ - # Check cache first - if hasattr(self, "_cached_text") and self._cached_text is not None: - return self._cached_text - - # Try charset_normalizer for robust encoding detection first - detected_encoding = charset_normalizer.from_bytes(self.response.content).best() - if detected_encoding and detected_encoding.encoding: - try: - text = self.response.content.decode(detected_encoding.encoding) - self._cached_text = text - return text - except (UnicodeDecodeError, TypeError, LookupError): - # Fallback to httpx's encoding detection if charset_normalizer fails - pass - - # Fallback to httpx's built-in encoding detection - text = self.response.text - self._cached_text = text - return text - - @property - def content(self) -> bytes: - return self.response.content - - @property - def status_code(self) -> int: - return self.response.status_code - - @property - def size(self) -> int: - return len(self.content) - - @property - def readable_size(self) -> str: - if self.size < 1024: - return f"{self.size} bytes" - elif self.size < 1024 * 1024: - return f"{(self.size / 1024):.2f} KB" - else: - return f"{(self.size / 1024 / 1024):.2f} MB" - - @property - def parsed_content_disposition(self) -> Message | None: - content_disposition = self.headers.get("content-disposition", "") - if content_disposition: - msg = Message() - msg["content-disposition"] = content_disposition - return msg - return None diff --git a/api/graphon/nodes/http_request/exc.py b/api/graphon/nodes/http_request/exc.py deleted file mode 100644 index 46613c9e86..0000000000 --- a/api/graphon/nodes/http_request/exc.py +++ /dev/null @@ -1,26 +0,0 @@ -class HttpRequestNodeError(ValueError): - """Custom error for HTTP request node.""" - - -class AuthorizationConfigError(HttpRequestNodeError): - """Raised when authorization config is missing or invalid.""" - - -class FileFetchError(HttpRequestNodeError): - """Raised when a file cannot be fetched.""" - - -class InvalidHttpMethodError(HttpRequestNodeError): - """Raised when an invalid HTTP method is used.""" - - -class ResponseSizeError(HttpRequestNodeError): - """Raised when the response size exceeds the allowed threshold.""" - - -class RequestBodyError(HttpRequestNodeError): - """Raised when the request body is invalid.""" - - -class InvalidURLError(HttpRequestNodeError): - """Raised when the URL is invalid.""" diff --git a/api/graphon/nodes/http_request/executor.py b/api/graphon/nodes/http_request/executor.py deleted file mode 100644 index 0c6f4ecd3a..0000000000 --- a/api/graphon/nodes/http_request/executor.py +++ /dev/null @@ -1,488 +0,0 @@ -import base64 -import json -import secrets -import string -from collections.abc import Callable, Mapping -from copy import deepcopy -from typing import Any, Literal -from urllib.parse import urlencode, urlparse - -import httpx -from json_repair import repair_json - -from graphon.file.enums import FileTransferMethod -from graphon.runtime import VariablePool -from graphon.variables.segments import ArrayFileSegment, FileSegment - -from ..protocols import FileManagerProtocol, HttpClientProtocol -from .entities import ( - HttpRequestNodeAuthorization, - HttpRequestNodeConfig, - HttpRequestNodeData, - HttpRequestNodeTimeout, - Response, -) -from .exc import ( - AuthorizationConfigError, - FileFetchError, - HttpRequestNodeError, - InvalidHttpMethodError, - InvalidURLError, - RequestBodyError, - ResponseSizeError, -) - -BODY_TYPE_TO_CONTENT_TYPE = { - "json": "application/json", - "x-www-form-urlencoded": "application/x-www-form-urlencoded", - "form-data": "multipart/form-data", - "raw-text": "text/plain", -} - - -class Executor: - method: Literal[ - "get", - "head", - "post", - "put", - "delete", - "patch", - "options", - "GET", - "POST", - "PUT", - "PATCH", - "DELETE", - "HEAD", - "OPTIONS", - ] - url: str - params: list[tuple[str, str]] | None - content: str | bytes | None - data: Mapping[str, Any] | None - files: list[tuple[str, tuple[str | None, bytes, str]]] | None - json: Any - headers: dict[str, str] - auth: HttpRequestNodeAuthorization - timeout: HttpRequestNodeTimeout - max_retries: int - - boundary: str - - def __init__( - self, - *, - node_data: HttpRequestNodeData, - timeout: HttpRequestNodeTimeout, - variable_pool: VariablePool, - http_request_config: HttpRequestNodeConfig, - max_retries: int | None = None, - ssl_verify: bool | None = None, - http_client: HttpClientProtocol, - file_manager: FileManagerProtocol, - ): - self._http_request_config = http_request_config - # If authorization API key is present, convert the API key using the variable pool - if node_data.authorization.type == "api-key": - if node_data.authorization.config is None: - raise AuthorizationConfigError("authorization config is required") - node_data.authorization.config.api_key = variable_pool.convert_template( - node_data.authorization.config.api_key - ).text - # Validate that API key is not empty after template conversion - if not node_data.authorization.config.api_key or not node_data.authorization.config.api_key.strip(): - raise AuthorizationConfigError( - "API key is required for authorization but was empty. Please provide a valid API key." - ) - - self.url = node_data.url - self.method = node_data.method - self.auth = node_data.authorization - self.timeout = timeout - self.ssl_verify = ssl_verify if ssl_verify is not None else node_data.ssl_verify - if self.ssl_verify is None: - self.ssl_verify = self._http_request_config.ssl_verify - if not isinstance(self.ssl_verify, bool): - raise ValueError("ssl_verify must be a boolean") - self.params = None - self.headers = {} - self.content = None - self.files = None - self.data = None - self.json = None - self.max_retries = ( - max_retries if max_retries is not None else self._http_request_config.ssrf_default_max_retries - ) - self._http_client = http_client - self._file_manager = file_manager - - # init template - self.variable_pool = variable_pool - self.node_data = node_data - self._initialize() - - def _initialize(self): - self._init_url() - self._init_params() - self._init_headers() - self._init_body() - - def _init_url(self): - self.url = self.variable_pool.convert_template(self.node_data.url).text - - # check if url is a valid URL - if not self.url: - raise InvalidURLError("url is required") - if not self.url.startswith(("http://", "https://")): - raise InvalidURLError("url should start with http:// or https://") - - def _init_params(self): - """ - Almost same as _init_headers(), difference: - 1. response a list tuple to support same key, like 'aa=1&aa=2' - 2. param value may have '\n', we need to splitlines then extract the variable value. - """ - result = [] - for line in self.node_data.params.splitlines(): - if not (line := line.strip()): - continue - - key, *value = line.split(":", 1) - if not (key := key.strip()): - continue - - value_str = value[0].strip() if value else "" - result.append( - (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text) - ) - - if result: - self.params = result - - def _init_headers(self): - """ - Convert the header string of frontend to a dictionary. - - Each line in the header string represents a key-value pair. - Keys and values are separated by ':'. - Empty values are allowed. - - Examples: - 'aa:bb\n cc:dd' -> {'aa': 'bb', 'cc': 'dd'} - 'aa:\n cc:dd\n' -> {'aa': '', 'cc': 'dd'} - 'aa\n cc : dd' -> {'aa': '', 'cc': 'dd'} - - """ - headers = self.variable_pool.convert_template(self.node_data.headers).text - self.headers = { - key.strip(): (value[0].strip() if value else "") - for line in headers.splitlines() - if line.strip() - for key, *value in [line.split(":", 1)] - } - - def _init_body(self): - body = self.node_data.body - if body is not None: - data = body.data - match body.type: - case "none": - self.content = "" - case "raw-text": - if len(data) != 1: - raise RequestBodyError("raw-text body type should have exactly one item") - self.content = self.variable_pool.convert_template(data[0].value).text - case "json": - if len(data) != 1: - raise RequestBodyError("json body type should have exactly one item") - json_string = self.variable_pool.convert_template(data[0].value).text - try: - repaired = repair_json(json_string) - json_object = json.loads(repaired, strict=False) - except json.JSONDecodeError as e: - raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e - self.json = json_object - # self.json = self._parse_object_contains_variables(json_object) - case "binary": - if len(data) != 1: - raise RequestBodyError("binary body type should have exactly one item") - file_selector = data[0].file - file_variable = self.variable_pool.get_file(file_selector) - if file_variable is None: - raise FileFetchError(f"cannot fetch file with selector {file_selector}") - file = file_variable.value - self.content = self._file_manager.download(file) - case "x-www-form-urlencoded": - form_data = { - self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( - item.value - ).text - for item in data - } - self.data = form_data - case "form-data": - form_data = { - self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( - item.value - ).text - for item in filter(lambda item: item.type == "text", data) - } - file_selectors = { - self.variable_pool.convert_template(item.key).text: item.file - for item in filter(lambda item: item.type == "file", data) - } - - # get files from file_selectors, add support for array file variables - files_list = [] - for key, selector in file_selectors.items(): - segment = self.variable_pool.get(selector) - if isinstance(segment, FileSegment): - files_list.append((key, [segment.value])) - elif isinstance(segment, ArrayFileSegment): - files_list.append((key, list(segment.value))) - - # get files from file_manager - files: dict[str, list[tuple[str | None, bytes, str]]] = {} - for key, files_in_segment in files_list: - for file in files_in_segment: - if file.reference is not None or ( - file.transfer_method == FileTransferMethod.REMOTE_URL and file.remote_url is not None - ): - file_tuple = ( - file.filename, - self._file_manager.download(file), - file.mime_type or "application/octet-stream", - ) - if key not in files: - files[key] = [] - files[key].append(file_tuple) - - # convert files to list for httpx request - # If there are no actual files, we still need to force httpx to use `multipart/form-data`. - # This is achieved by inserting a harmless placeholder file that will be ignored by the server. - if not files: - self.files = [("__multipart_placeholder__", ("", b"", "application/octet-stream"))] - if files: - self.files = [] - for key, file_tuples in files.items(): - for file_tuple in file_tuples: - self.files.append((key, file_tuple)) - - self.data = form_data - - def _assembling_headers(self) -> dict[str, Any]: - authorization = deepcopy(self.auth) - headers = deepcopy(self.headers) or {} - if self.auth.type == "api-key": - if self.auth.config is None: - raise AuthorizationConfigError("self.authorization config is required") - if authorization.config is None: - raise AuthorizationConfigError("authorization config is required") - - if not authorization.config.header: - authorization.config.header = "Authorization" - - if self.auth.config.type == "bearer" and authorization.config.api_key: - headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" - elif self.auth.config.type == "basic" and authorization.config.api_key: - credentials = authorization.config.api_key - if ":" in credentials: - encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8") - else: - encoded_credentials = credentials - headers[authorization.config.header] = f"Basic {encoded_credentials}" - elif self.auth.config.type == "custom": - if authorization.config.header and authorization.config.api_key: - headers[authorization.config.header] = authorization.config.api_key - - # Handle Content-Type for multipart/form-data requests - # Fix for issue #23829: Missing boundary when using multipart/form-data - body = self.node_data.body - if body and body.type == "form-data": - # For multipart/form-data with files (including placeholder files), - # remove any manually set Content-Type header to let httpx handle - # For multipart/form-data, if any files are present (including placeholder files), - # we must remove any manually set Content-Type header. This is because httpx needs to - # automatically set the Content-Type and boundary for multipart encoding whenever files - # are included, even if they are placeholders, to avoid boundary issues and ensure correct - # file upload behaviour. Manually setting Content-Type can cause httpx to fail to set the - # boundary, resulting in invalid requests. - if self.files: - # Remove Content-Type if it was manually set to avoid boundary issues - headers = {k: v for k, v in headers.items() if k.lower() != "content-type"} - else: - # No files at all, set Content-Type manually - if "content-type" not in (k.lower() for k in headers): - headers["Content-Type"] = "multipart/form-data" - elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE: - # Set Content-Type for other body types - if "content-type" not in (k.lower() for k in headers): - headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] - - return headers - - def _validate_and_parse_response(self, response: httpx.Response) -> Response: - executor_response = Response(response) - - threshold_size = ( - self._http_request_config.max_binary_size - if executor_response.is_file - else self._http_request_config.max_text_size - ) - if executor_response.size > threshold_size: - raise ResponseSizeError( - f"{'File' if executor_response.is_file else 'Text'} size is too large," - f" max size is {threshold_size / 1024 / 1024:.2f} MB," - f" but current size is {executor_response.readable_size}." - ) - - return executor_response - - def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: - """ - do http request depending on api bundle - """ - _METHOD_MAP: dict[str, Callable[..., httpx.Response]] = { - "get": self._http_client.get, - "head": self._http_client.head, - "post": self._http_client.post, - "put": self._http_client.put, - "delete": self._http_client.delete, - "patch": self._http_client.patch, - } - method_lc = self.method.lower() - if method_lc not in _METHOD_MAP: - raise InvalidHttpMethodError(f"Invalid http method {self.method}") - - request_args: dict[str, Any] = { - "data": self.data, - "files": self.files, - "json": self.json, - "content": self.content, - "headers": headers, - "params": self.params, - "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), - "ssl_verify": self.ssl_verify, - "follow_redirects": True, - } - # request_args = {k: v for k, v in request_args.items() if v is not None} - try: - response = _METHOD_MAP[method_lc]( - url=self.url, - **request_args, - max_retries=self.max_retries, - ) - except self._http_client.max_retries_exceeded_error as e: - raise HttpRequestNodeError(f"Reached maximum retries for URL {self.url}") from e - except self._http_client.request_error as e: - raise HttpRequestNodeError(str(e)) from e - return response - - def invoke(self) -> Response: - # assemble headers - headers = self._assembling_headers() - # do http request - response = self._do_http_request(headers) - # validate response - return self._validate_and_parse_response(response) - - def to_log(self): - url_parts = urlparse(self.url) - path = url_parts.path or "/" - - # Add query parameters - if self.params: - query_string = urlencode(self.params) - path += f"?{query_string}" - elif url_parts.query: - path += f"?{url_parts.query}" - - raw = f"{self.method.upper()} {path} HTTP/1.1\r\n" - raw += f"Host: {url_parts.netloc}\r\n" - - headers = self._assembling_headers() - body = self.node_data.body - boundary = f"----WebKitFormBoundary{_generate_random_string(16)}" - if body: - if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE: - headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] - if body.type == "form-data": - headers["Content-Type"] = f"multipart/form-data; boundary={boundary}" - for k, v in headers.items(): - if self.auth.type == "api-key": - authorization_header = "Authorization" - if self.auth.config and self.auth.config.header: - authorization_header = self.auth.config.header - if k.lower() == authorization_header.lower(): - raw += f"{k}: {'*' * len(v)}\r\n" - continue - raw += f"{k}: {v}\r\n" - - body_string = "" - # Only log actual files if present. - # '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file. - # This prevents logging meaningless placeholder entries. - if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files): - for file_entry in self.files: - # file_entry should be (key, (filename, content, mime_type)), but handle edge cases - if len(file_entry) != 2 or len(file_entry[1]) < 2: - continue # skip malformed entries - key = file_entry[0] - content = file_entry[1][1] - body_string += f"--{boundary}\r\n" - body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' - # decode content safely - # Do not decode binary content; use a placeholder with file metadata instead. - # Includes filename, size, and MIME type for better logging context. - body_string += ( - f"\r\n" - ) - body_string += f"--{boundary}--\r\n" - elif self.node_data.body: - if self.content: - # If content is bytes, do not decode it; show a placeholder with size. - # Provides content size information for binary data without exposing the raw bytes. - if isinstance(self.content, bytes): - body_string = f"" - else: - body_string = self.content - elif self.data and self.node_data.body.type == "x-www-form-urlencoded": - body_string = urlencode(self.data) - elif self.data and self.node_data.body.type == "form-data": - for key, value in self.data.items(): - body_string += f"--{boundary}\r\n" - body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' - body_string += f"{value}\r\n" - body_string += f"--{boundary}--\r\n" - elif self.json: - body_string = json.dumps(self.json) - elif self.node_data.body.type == "raw-text": - if len(self.node_data.body.data) != 1: - raise RequestBodyError("raw-text body type should have exactly one item") - body_string = self.node_data.body.data[0].value - if body_string: - raw += f"Content-Length: {len(body_string)}\r\n" - raw += "\r\n" # Empty line between headers and body - raw += body_string - - return raw - - -def _generate_random_string(n: int) -> str: - """ - Generate a random string of lowercase ASCII letters. - - Args: - n (int): The length of the random string to generate. - - Returns: - str: A random string of lowercase ASCII letters with length n. - - Example: - >>> _generate_random_string(5) - 'abcde' - """ - return "".join(secrets.choice(string.ascii_lowercase) for _ in range(n)) diff --git a/api/graphon/nodes/http_request/node.py b/api/graphon/nodes/http_request/node.py deleted file mode 100644 index 3d74347a7f..0000000000 --- a/api/graphon/nodes/http_request/node.py +++ /dev/null @@ -1,261 +0,0 @@ -import logging -import mimetypes -from collections.abc import Callable, Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod -from graphon.node_events import NodeRunResult -from graphon.nodes.base import variable_template_parser -from graphon.nodes.base.entities import VariableSelector -from graphon.nodes.base.node import Node -from graphon.nodes.http_request.executor import Executor -from graphon.nodes.protocols import ( - FileManagerProtocol, - FileReferenceFactoryProtocol, - HttpClientProtocol, - ToolFileManagerProtocol, -) -from graphon.variables.segments import ArrayFileSegment - -from .config import build_http_request_config, resolve_http_request_config -from .entities import ( - HTTP_REQUEST_CONFIG_FILTER_KEY, - HttpRequestNodeConfig, - HttpRequestNodeData, - HttpRequestNodeTimeout, - Response, -) -from .exc import HttpRequestNodeError, RequestBodyError - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - - -class HttpRequestNode(Node[HttpRequestNodeData]): - node_type = BuiltinNodeTypes.HTTP_REQUEST - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - http_request_config: HttpRequestNodeConfig, - http_client: HttpClientProtocol, - tool_file_manager_factory: Callable[[], ToolFileManagerProtocol], - file_manager: FileManagerProtocol, - file_reference_factory: FileReferenceFactoryProtocol, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - self._http_request_config = http_request_config - self._http_client = http_client - self._tool_file_manager_factory = tool_file_manager_factory - self._file_manager = file_manager - self._file_reference_factory = file_reference_factory - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - if not filters or HTTP_REQUEST_CONFIG_FILTER_KEY not in filters: - http_request_config = build_http_request_config() - else: - http_request_config = resolve_http_request_config(filters) - default_timeout = http_request_config.default_timeout() - return { - "type": "http-request", - "config": { - "method": "get", - "authorization": { - "type": "no-auth", - }, - "body": {"type": "none"}, - "timeout": { - **default_timeout.model_dump(), - "max_connect_timeout": http_request_config.max_connect_timeout, - "max_read_timeout": http_request_config.max_read_timeout, - "max_write_timeout": http_request_config.max_write_timeout, - }, - "ssl_verify": http_request_config.ssl_verify, - }, - "retry_config": { - "max_retries": http_request_config.ssrf_default_max_retries, - "retry_interval": 0.5 * (2**2), - "retry_enabled": True, - }, - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - process_data = {} - try: - http_executor = Executor( - node_data=self.node_data, - timeout=self._get_request_timeout(self.node_data), - variable_pool=self.graph_runtime_state.variable_pool, - http_request_config=self._http_request_config, - # Must be 0 to disable executor-level retries, as the graph engine handles them. - # This is critical to prevent nested retries. - max_retries=0, - ssl_verify=self.node_data.ssl_verify, - http_client=self._http_client, - file_manager=self._file_manager, - ) - process_data["request"] = http_executor.to_log() - - response = http_executor.invoke() - files = self.extract_files(url=http_executor.url, response=response) - if not response.response.is_success and (self.error_strategy or self.retry): - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - outputs={ - "status_code": response.status_code, - "body": response.text if not files.value else "", - "headers": response.headers, - "files": files, - }, - process_data={ - "request": http_executor.to_log(), - }, - error=f"Request failed with status code {response.status_code}", - error_type="HTTPResponseCodeError", - ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "status_code": response.status_code, - "body": response.text if not files.value else "", - "headers": response.headers, - "files": files, - }, - process_data={ - "request": http_executor.to_log(), - }, - ) - except HttpRequestNodeError as e: - logger.warning("http request node %s failed to run: %s", self._node_id, e) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - process_data=process_data, - error_type=type(e).__name__, - ) - - def _get_request_timeout(self, node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: - default_timeout = self._http_request_config.default_timeout() - timeout = node_data.timeout - if timeout is None: - return default_timeout - - return HttpRequestNodeTimeout( - connect=timeout.connect or default_timeout.connect, - read=timeout.read or default_timeout.read, - write=timeout.write or default_timeout.write, - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: HttpRequestNodeData, - ) -> Mapping[str, Sequence[str]]: - selectors: list[VariableSelector] = [] - selectors += variable_template_parser.extract_selectors_from_template(node_data.url) - selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) - selectors += variable_template_parser.extract_selectors_from_template(node_data.params) - if node_data.body: - body_type = node_data.body.type - data = node_data.body.data - match body_type: - case "none": - pass - case "binary": - if len(data) != 1: - raise RequestBodyError("invalid body data, should have only one item") - selector = data[0].file - selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector)) - case "json" | "raw-text": - if len(data) != 1: - raise RequestBodyError("invalid body data, should have only one item") - selectors += variable_template_parser.extract_selectors_from_template(data[0].key) - selectors += variable_template_parser.extract_selectors_from_template(data[0].value) - case "x-www-form-urlencoded": - for item in data: - selectors += variable_template_parser.extract_selectors_from_template(item.key) - selectors += variable_template_parser.extract_selectors_from_template(item.value) - case "form-data": - for item in data: - selectors += variable_template_parser.extract_selectors_from_template(item.key) - if item.type == "text": - selectors += variable_template_parser.extract_selectors_from_template(item.value) - elif item.type == "file": - selectors.append( - VariableSelector(variable="#" + ".".join(item.file) + "#", value_selector=item.file) - ) - - mapping = {} - for selector_iter in selectors: - mapping[node_id + "." + selector_iter.variable] = selector_iter.value_selector - - return mapping - - def extract_files(self, url: str, response: Response) -> ArrayFileSegment: - """ - Extract files from response by checking both Content-Type header and URL - """ - files: list[File] = [] - is_file = response.is_file - content_type = response.content_type - content = response.content - parsed_content_disposition = response.parsed_content_disposition - content_disposition_type = None - - if not is_file: - return ArrayFileSegment(value=[]) - - if parsed_content_disposition: - content_disposition_filename = parsed_content_disposition.get_filename() - if content_disposition_filename: - # If filename is available from content-disposition, use it to guess the content type - content_disposition_type = mimetypes.guess_type(content_disposition_filename)[0] - - # Guess file extension from URL or Content-Type header - filename = url.split("?")[0].split("/")[-1] or "" - mime_type = ( - content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream" - ) - tool_file_manager = self._tool_file_manager_factory() - - tool_file = tool_file_manager.create_file_by_raw( - file_binary=content, - mimetype=mime_type, - ) - - file = self._file_reference_factory.build_from_mapping( - mapping={ - "tool_file_id": tool_file.id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - ) - files.append(file) - - return ArrayFileSegment(value=files) - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled diff --git a/api/graphon/nodes/human_input/__init__.py b/api/graphon/nodes/human_input/__init__.py deleted file mode 100644 index 1789604577..0000000000 --- a/api/graphon/nodes/human_input/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Human Input node implementation. -""" diff --git a/api/graphon/nodes/human_input/entities.py b/api/graphon/nodes/human_input/entities.py deleted file mode 100644 index aa01bde145..0000000000 --- a/api/graphon/nodes/human_input/entities.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Human Input node entities. - -The graph package owns the workflow-facing form schema and keeps it transportable -across runtimes. Dify-specific delivery surface and recipient translation stay -outside `graphon`. -""" - -import re -from collections.abc import Mapping, Sequence -from datetime import datetime, timedelta -from typing import Any, Self - -from pydantic import BaseModel, Field, field_validator, model_validator - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from graphon.variables.consts import SELECTORS_LENGTH - -from .enums import ButtonStyle, FormInputType, PlaceholderType, TimeoutUnit - -_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") - - -class FormInputDefault(BaseModel): - """Default configuration for form inputs.""" - - # NOTE: Ideally, a discriminated union would be used to model - # FormInputDefault. However, the UI requires preserving the previous - # value when switching between `VARIABLE` and `CONSTANT` types. This - # necessitates retaining all fields, making a discriminated union unsuitable. - - type: PlaceholderType - - # The selector of default variable, used when `type` is `VARIABLE`. - selector: Sequence[str] = Field(default_factory=tuple) # - - # The value of the default, used when `type` is `CONSTANT`. - # TODO: How should we express JSON values? - value: str = "" - - @model_validator(mode="after") - def _validate_selector(self) -> Self: - if self.type == PlaceholderType.CONSTANT: - return self - if len(self.selector) < SELECTORS_LENGTH: - raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}") - return self - - -class FormInput(BaseModel): - """Form input definition.""" - - type: FormInputType - output_variable_name: str - default: FormInputDefault | None = None - - -_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") - - -class UserAction(BaseModel): - """User action configuration.""" - - # id is the identifier for this action. - # It also serves as the identifiers of output handle. - # - # The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.) - id: str = Field(max_length=20) - title: str = Field(max_length=20) - button_style: ButtonStyle = ButtonStyle.DEFAULT - - @field_validator("id") - @classmethod - def _validate_id(cls, value: str) -> str: - if not _IDENTIFIER_PATTERN.match(value): - raise ValueError( - f"'{value}' is not a valid identifier. It must start with a letter or underscore, " - f"and contain only letters, numbers, or underscores." - ) - return value - - -class HumanInputNodeData(BaseNodeData): - """Human Input node data.""" - - type: NodeType = BuiltinNodeTypes.HUMAN_INPUT - form_content: str = "" - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) - timeout: int = 36 - timeout_unit: TimeoutUnit = TimeoutUnit.HOUR - - @field_validator("inputs") - @classmethod - def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]: - seen_names: set[str] = set() - for form_input in inputs: - name = form_input.output_variable_name - if name in seen_names: - raise ValueError(f"duplicated output_variable_name '{name}' in inputs") - seen_names.add(name) - return inputs - - @field_validator("user_actions") - @classmethod - def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]: - seen_ids: set[str] = set() - for action in user_actions: - action_id = action.id - if action_id in seen_ids: - raise ValueError(f"duplicated user action id '{action_id}'") - seen_ids.add(action_id) - return user_actions - - def expiration_time(self, start_time: datetime) -> datetime: - if self.timeout_unit == TimeoutUnit.HOUR: - return start_time + timedelta(hours=self.timeout) - elif self.timeout_unit == TimeoutUnit.DAY: - return start_time + timedelta(days=self.timeout) - else: - raise AssertionError("unknown timeout unit.") - - def outputs_field_names(self) -> Sequence[str]: - field_names = [] - for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content): - field_names.append(match.group("field_name")) - return field_names - - def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]: - variable_mappings: dict[str, Sequence[str]] = {} - - def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None: - for selector in selectors: - if len(selector) < SELECTORS_LENGTH: - continue - qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#" - variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH]) - - form_template_parser = VariableTemplateParser(template=self.form_content) - _add_variable_selectors( - [selector.value_selector for selector in form_template_parser.extract_variable_selectors()] - ) - - for input in self.inputs: - default_value = input.default - if default_value is None: - continue - if default_value.type == PlaceholderType.CONSTANT: - continue - default_value_key = ".".join(default_value.selector) - qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#" - variable_mappings[qualified_variable_mapping_key] = default_value.selector - - return variable_mappings - - def find_action_text(self, action_id: str) -> str: - """ - Resolve action display text by id. - """ - for action in self.user_actions: - if action.id == action_id: - return action.title - return action_id - - -class FormDefinition(BaseModel): - form_content: str - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) - rendered_content: str - expiration_time: datetime - - # this is used to store the resolved default values - default_values: dict[str, Any] = Field(default_factory=dict) - - # node_title records the title of the HumanInput node. - node_title: str | None = None - - # display_in_ui controls whether the form should be displayed in UI surfaces. - display_in_ui: bool | None = None - - -class HumanInputSubmissionValidationError(ValueError): - pass - - -def validate_human_input_submission( - *, - inputs: Sequence[FormInput], - user_actions: Sequence[UserAction], - selected_action_id: str, - form_data: Mapping[str, Any], -) -> None: - available_actions = {action.id for action in user_actions} - if selected_action_id not in available_actions: - raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}") - - provided_inputs = set(form_data.keys()) - missing_inputs = [ - form_input.output_variable_name - for form_input in inputs - if form_input.output_variable_name not in provided_inputs - ] - - if missing_inputs: - missing_list = ", ".join(missing_inputs) - raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}") diff --git a/api/graphon/nodes/human_input/enums.py b/api/graphon/nodes/human_input/enums.py deleted file mode 100644 index 3fb0ab4499..0000000000 --- a/api/graphon/nodes/human_input/enums.py +++ /dev/null @@ -1,55 +0,0 @@ -import enum - - -class HumanInputFormStatus(enum.StrEnum): - """Status of a human input form.""" - - # Awaiting submission from any recipient. Forms stay in this state until - # submitted or a timeout rule applies. - WAITING = enum.auto() - # Global timeout reached. The workflow run is stopped and will not resume. - # This is distinct from node-level timeout. - EXPIRED = enum.auto() - # Submitted by a recipient; form data is available and execution resumes - # along the selected action edge. - SUBMITTED = enum.auto() - # Node-level timeout reached. The human input node should emit a timeout - # event and the workflow should resume along the timeout edge. - TIMEOUT = enum.auto() - - -class HumanInputFormKind(enum.StrEnum): - """Kind of a human input form.""" - - RUNTIME = enum.auto() # Form created during workflow execution. - DELIVERY_TEST = enum.auto() # Form created for delivery tests. - - -class ButtonStyle(enum.StrEnum): - """Button styles for user actions.""" - - PRIMARY = enum.auto() - DEFAULT = enum.auto() - ACCENT = enum.auto() - GHOST = enum.auto() - - -class TimeoutUnit(enum.StrEnum): - """Timeout unit for form expiration.""" - - HOUR = enum.auto() - DAY = enum.auto() - - -class FormInputType(enum.StrEnum): - """Form input types.""" - - TEXT_INPUT = enum.auto() - PARAGRAPH = enum.auto() - - -class PlaceholderType(enum.StrEnum): - """Default value types for form inputs.""" - - VARIABLE = enum.auto() - CONSTANT = enum.auto() diff --git a/api/graphon/nodes/human_input/human_input_node.py b/api/graphon/nodes/human_input/human_input_node.py deleted file mode 100644 index fe04022877..0000000000 --- a/api/graphon/nodes/human_input/human_input_node.py +++ /dev/null @@ -1,299 +0,0 @@ -import json -import logging -from collections.abc import Generator, Mapping, Sequence -from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, cast - -from graphon.entities.graph_config import NodeConfigDict -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import ( - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, - NodeRunResult, - PauseRequestedEvent, -) -from graphon.node_events.base import NodeEventBase -from graphon.node_events.node import StreamCompletedEvent -from graphon.nodes.base.node import Node -from graphon.nodes.runtime import HumanInputFormStateProtocol, HumanInputNodeRuntimeProtocol -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter - -from .entities import HumanInputNodeData -from .enums import HumanInputFormStatus, PlaceholderType - -if TYPE_CHECKING: - from graphon.entities.graph_init_params import GraphInitParams - from graphon.runtime.graph_runtime_state import GraphRuntimeState - - -_SELECTED_BRANCH_KEY = "selected_branch" - - -logger = logging.getLogger(__name__) - - -class HumanInputNode(Node[HumanInputNodeData]): - node_type = BuiltinNodeTypes.HUMAN_INPUT - execution_type = NodeExecutionType.BRANCH - - _BRANCH_SELECTION_KEYS: tuple[str, ...] = ( - "edge_source_handle", - "edgeSourceHandle", - "source_handle", - _SELECTED_BRANCH_KEY, - "selectedBranch", - "branch", - "branch_id", - "branchId", - "handle", - ) - - _node_data: HumanInputNodeData - _OUTPUT_FIELD_ACTION_ID = "__action_id" - _OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content" - _TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout" - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - runtime: HumanInputNodeRuntimeProtocol | None = None, - form_repository: object | None = None, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - resolved_runtime = runtime - if resolved_runtime is None: - raise ValueError("runtime is required") - if form_repository is not None: - with_form_repository = getattr(resolved_runtime, "with_form_repository", None) - if callable(with_form_repository): - resolved_runtime = cast(HumanInputNodeRuntimeProtocol, with_form_repository(form_repository)) - self._runtime: HumanInputNodeRuntimeProtocol = resolved_runtime - - @classmethod - def version(cls) -> str: - return "1" - - def _resolve_branch_selection(self) -> str | None: - """Determine the branch handle selected by human input if available.""" - - variable_pool = self.graph_runtime_state.variable_pool - - for key in self._BRANCH_SELECTION_KEYS: - handle = self._extract_branch_handle(variable_pool.get((self.id, key))) - if handle: - return handle - - default_values = self.node_data.default_value_dict - for key in self._BRANCH_SELECTION_KEYS: - handle = self._normalize_branch_value(default_values.get(key)) - if handle: - return handle - - return None - - @staticmethod - def _extract_branch_handle(segment: Any) -> str | None: - if segment is None: - return None - - candidate = getattr(segment, "to_object", None) - raw_value = candidate() if callable(candidate) else getattr(segment, "value", None) - if raw_value is None: - return None - - return HumanInputNode._normalize_branch_value(raw_value) - - @staticmethod - def _normalize_branch_value(value: Any) -> str | None: - if value is None: - return None - - if isinstance(value, str): - stripped = value.strip() - return stripped or None - - if isinstance(value, Mapping): - for key in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"): - candidate = value.get(key) - if isinstance(candidate, str) and candidate: - return candidate - - return None - - def _form_to_pause_event(self, form_entity: HumanInputFormStateProtocol): - required_event = self._human_input_required_event(form_entity) - pause_requested_event = PauseRequestedEvent(reason=required_event) - return pause_requested_event - - def resolve_default_values(self) -> Mapping[str, Any]: - variable_pool = self.graph_runtime_state.variable_pool - resolved_defaults: dict[str, Any] = {} - for input in self._node_data.inputs: - if (default_value := input.default) is None: - continue - if default_value.type == PlaceholderType.CONSTANT: - continue - resolved_value = variable_pool.get(default_value.selector) - if resolved_value is None: - # TODO: How should we handle this? - continue - resolved_defaults[input.output_variable_name] = ( - WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value) - ) - - return resolved_defaults - - def _human_input_required_event(self, form_entity: HumanInputFormStateProtocol) -> HumanInputRequired: - node_data = self._node_data - resolved_default_values = self.resolve_default_values() - return HumanInputRequired( - form_id=form_entity.id, - form_content=form_entity.rendered_content, - inputs=node_data.inputs, - actions=node_data.user_actions, - node_id=self.id, - node_title=node_data.title, - resolved_default_values=resolved_default_values, - ) - - def _run(self) -> Generator[NodeEventBase, None, None]: - """ - Execute the human input node. - - This method will: - 1. Generate a unique form ID - 2. Create form content with variable substitution - 3. Persist the form through the configured repository - 4. Send form via configured delivery methods - 5. Suspend workflow execution - 6. Wait for form submission to resume - """ - form = self._runtime.get_form(node_id=self.id) - if form is None: - form_entity = self._runtime.create_form( - node_id=self.id, - node_data=self._node_data, - rendered_content=self.render_form_content_before_submission(), - resolved_default_values=self.resolve_default_values(), - ) - - logger.info( - "Human Input node suspended workflow for form. node_id=%s, form_id=%s", - self.id, - form_entity.id, - ) - yield self._form_to_pause_event(form_entity) - return - - if form.status in { - HumanInputFormStatus.TIMEOUT, - HumanInputFormStatus.EXPIRED, - } or form.expiration_time <= datetime.now(UTC).replace(tzinfo=None): - yield HumanInputFormTimeoutEvent( - node_title=self._node_data.title, - expiration_time=form.expiration_time, - ) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={self._OUTPUT_FIELD_ACTION_ID: ""}, - edge_source_handle=self._TIMEOUT_HANDLE, - ) - ) - return - - if not form.submitted: - yield self._form_to_pause_event(form) - return - - selected_action_id = form.selected_action_id - if selected_action_id is None: - raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}") - submitted_data = form.submitted_data or {} - outputs: dict[str, Any] = dict(submitted_data) - outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id - rendered_content = self.render_form_content_with_outputs( - form.rendered_content, - outputs, - self._node_data.outputs_field_names(), - ) - outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content - - action_text = self._node_data.find_action_text(selected_action_id) - - yield HumanInputFormFilledEvent( - node_title=self._node_data.title, - rendered_content=rendered_content, - action_id=selected_action_id, - action_text=action_text, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs, - edge_source_handle=selected_action_id, - ) - ) - - def render_form_content_before_submission(self) -> str: - """ - Process form content by substituting variables. - - This method should: - 1. Parse the form_content markdown - 2. Substitute {{#node_name.var_name#}} with actual values - 3. Keep {{#$output.field_name#}} placeholders for form inputs - """ - rendered_form_content = self.graph_runtime_state.variable_pool.convert_template( - self._node_data.form_content, - ) - return rendered_form_content.markdown - - @staticmethod - def render_form_content_with_outputs( - form_content: str, - outputs: Mapping[str, Any], - field_names: Sequence[str], - ) -> str: - """ - Replace {{#$output.xxx#}} placeholders with submitted values. - """ - rendered_content = form_content - for field_name in field_names: - placeholder = "{{#$output." + field_name + "#}}" - value = outputs.get(field_name) - if value is None: - replacement = "" - elif isinstance(value, (dict, list)): - replacement = json.dumps(value, ensure_ascii=False) - else: - replacement = str(value) - rendered_content = rendered_content.replace(placeholder, replacement) - return rendered_content - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: HumanInputNodeData, - ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selectors referenced in form content and input default values. - - This method should parse: - 1. Variables referenced in form_content ({{#node_name.var_name#}}) - 2. Variables referenced in input default values - """ - return node_data.extract_variable_selector_to_variable_mapping(node_id) diff --git a/api/graphon/nodes/if_else/__init__.py b/api/graphon/nodes/if_else/__init__.py deleted file mode 100644 index afa0e8112c..0000000000 --- a/api/graphon/nodes/if_else/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .if_else_node import IfElseNode - -__all__ = ["IfElseNode"] diff --git a/api/graphon/nodes/if_else/entities.py b/api/graphon/nodes/if_else/entities.py deleted file mode 100644 index d59b782747..0000000000 --- a/api/graphon/nodes/if_else/entities.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.utils.condition.entities import Condition - - -class IfElseNodeData(BaseNodeData): - """ - If Else Node Data. - """ - - type: NodeType = BuiltinNodeTypes.IF_ELSE - - class Case(BaseModel): - """ - Case entity representing a single logical condition group - """ - - case_id: str - logical_operator: Literal["and", "or"] - conditions: list[Condition] - - logical_operator: Literal["and", "or"] | None = "and" - conditions: list[Condition] | None = Field(default=None, deprecated=True) - - cases: list[Case] | None = None diff --git a/api/graphon/nodes/if_else/if_else_node.py b/api/graphon/nodes/if_else/if_else_node.py deleted file mode 100644 index 81e934971a..0000000000 --- a/api/graphon/nodes/if_else/if_else_node.py +++ /dev/null @@ -1,124 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Literal - -from typing_extensions import deprecated - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.if_else.entities import IfElseNodeData -from graphon.runtime import VariablePool -from graphon.utils.condition.entities import Condition -from graphon.utils.condition.processor import ConditionProcessor - - -class IfElseNode(Node[IfElseNodeData]): - node_type = BuiltinNodeTypes.IF_ELSE - execution_type = NodeExecutionType.BRANCH - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run node - :return: - """ - node_inputs: dict[str, Sequence[Mapping[str, Any]]] = {"conditions": []} - - process_data: dict[str, list] = {"condition_results": []} - - input_conditions: Sequence[Mapping[str, Any]] = [] - final_result = False - selected_case_id = "false" - condition_processor = ConditionProcessor() - try: - # Check if the new cases structure is used - if self.node_data.cases: - for case in self.node_data.cases: - input_conditions, group_result, final_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=case.conditions, - operator=case.logical_operator, - ) - - process_data["condition_results"].append( - { - "group": case.model_dump(), - "results": group_result, - "final_result": final_result, - } - ) - - # Break if a case passes (logical short-circuit) - if final_result: - selected_case_id = case.case_id # Capture the ID of the passing case - break - - else: - # TODO: Remove this once all graph definitions use the `cases` structure. - # Fallback to the legacy node shape when `cases` are not defined. - input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated] - condition_processor=condition_processor, - variable_pool=self.graph_runtime_state.variable_pool, - conditions=self.node_data.conditions or [], - operator=self.node_data.logical_operator or "and", - ) - - selected_case_id = "true" if final_result else "false" - - process_data["condition_results"].append( - {"group": "default", "results": group_result, "final_result": final_result} - ) - - node_inputs["conditions"] = input_conditions - - except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_data, error=str(e) - ) - - outputs = {"result": final_result, "selected_case_id": selected_case_id} - - data = NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - process_data=process_data, - edge_source_handle=selected_case_id or "false", # Use case ID or 'default' - outputs=outputs, - ) - - return data - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IfElseNodeData, - ) -> Mapping[str, Sequence[str]]: - var_mapping: dict[str, list[str]] = {} - _ = graph_config # Explicitly mark as unused - for case in node_data.cases or []: - for condition in case.conditions: - key = f"{node_id}.#{'.'.join(condition.variable_selector)}#" - var_mapping[key] = condition.variable_selector - - return var_mapping - - -@deprecated("This function is deprecated. You should use the new cases structure.") -def _should_not_use_old_function( - *, - condition_processor: ConditionProcessor, - variable_pool: VariablePool, - conditions: list[Condition], - operator: Literal["and", "or"], -): - return condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=conditions, - operator=operator, - ) diff --git a/api/graphon/nodes/iteration/__init__.py b/api/graphon/nodes/iteration/__init__.py deleted file mode 100644 index 5bb87aaffa..0000000000 --- a/api/graphon/nodes/iteration/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .entities import IterationNodeData -from .iteration_node import IterationNode -from .iteration_start_node import IterationStartNode - -__all__ = ["IterationNode", "IterationNodeData", "IterationStartNode"] diff --git a/api/graphon/nodes/iteration/entities.py b/api/graphon/nodes/iteration/entities.py deleted file mode 100644 index 30b6e4bea8..0000000000 --- a/api/graphon/nodes/iteration/entities.py +++ /dev/null @@ -1,67 +0,0 @@ -from enum import StrEnum -from typing import Any - -from pydantic import Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base import BaseIterationNodeData, BaseIterationState - - -class ErrorHandleMode(StrEnum): - TERMINATED = "terminated" - CONTINUE_ON_ERROR = "continue-on-error" - REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output" - - -class IterationNodeData(BaseIterationNodeData): - """ - Iteration Node Data. - """ - - type: NodeType = BuiltinNodeTypes.ITERATION - parent_loop_id: str | None = None # redundant field, not used currently - iterator_selector: list[str] # variable selector - output_selector: list[str] # output selector - is_parallel: bool = False # open the parallel mode or not - parallel_nums: int = 10 # the numbers of parallel - error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error - flatten_output: bool = True # whether to flatten the output array if all elements are lists - - -class IterationStartNodeData(BaseNodeData): - """ - Iteration Start Node Data. - """ - - type: NodeType = BuiltinNodeTypes.ITERATION_START - - -class IterationState(BaseIterationState): - """ - Iteration State. - """ - - outputs: list[Any] = Field(default_factory=list) - current_output: Any = None - - class MetaData(BaseIterationState.MetaData): - """ - Data. - """ - - iterator_length: int - - def get_last_output(self) -> Any: - """ - Get last output. - """ - if self.outputs: - return self.outputs[-1] - return None - - def get_current_output(self) -> Any: - """ - Get current output. - """ - return self.current_output diff --git a/api/graphon/nodes/iteration/exc.py b/api/graphon/nodes/iteration/exc.py deleted file mode 100644 index 7b6af61b9d..0000000000 --- a/api/graphon/nodes/iteration/exc.py +++ /dev/null @@ -1,26 +0,0 @@ -class IterationNodeError(ValueError): - """Base class for iteration node errors.""" - - -class IteratorVariableNotFoundError(IterationNodeError): - """Raised when the iterator variable is not found.""" - - -class InvalidIteratorValueError(IterationNodeError): - """Raised when the iterator value is invalid.""" - - -class StartNodeIdNotFoundError(IterationNodeError): - """Raised when the start node ID is not found.""" - - -class IterationGraphNotFoundError(IterationNodeError): - """Raised when the iteration graph is not found.""" - - -class IterationIndexNotFoundError(IterationNodeError): - """Raised when the iteration index is not found.""" - - -class ChildGraphAbortedError(IterationNodeError): - """Raised when a child graph aborts and the container must stop immediately.""" diff --git a/api/graphon/nodes/iteration/iteration_node.py b/api/graphon/nodes/iteration/iteration_node.py deleted file mode 100644 index c013739653..0000000000 --- a/api/graphon/nodes/iteration/iteration_node.py +++ /dev/null @@ -1,686 +0,0 @@ -import logging -from collections.abc import Generator, Mapping, Sequence -from concurrent.futures import Future, ThreadPoolExecutor, as_completed -from contextlib import suppress -from datetime import UTC, datetime -from threading import Lock -from typing import TYPE_CHECKING, Any, NewType, cast - -from typing_extensions import TypeIs - -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.enums import ( - BuiltinNodeTypes, - NodeExecutionType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.graph_events import ( - GraphNodeEventBase, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunSucceededEvent, -) -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import ( - IterationFailedEvent, - IterationNextEvent, - IterationStartedEvent, - IterationSucceededEvent, - NodeEventBase, - NodeRunResult, - StreamCompletedEvent, -) -from graphon.nodes.base import LLMUsageTrackingMixin -from graphon.nodes.base.node import Node -from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from graphon.runtime import VariablePool -from graphon.variables import IntegerVariable, NoneSegment -from graphon.variables.segments import ArrayAnySegment, ArraySegment - -from .exc import ( - ChildGraphAbortedError, - InvalidIteratorValueError, - IterationGraphNotFoundError, - IterationIndexNotFoundError, - IterationNodeError, - IteratorVariableNotFoundError, - StartNodeIdNotFoundError, -) - -if TYPE_CHECKING: - from graphon.graph_engine import GraphEngine - -logger = logging.getLogger(__name__) -_DEFAULT_CHILD_ABORT_REASON = "child graph aborted" - -EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) - - -class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): - """ - Iteration Node. - """ - - node_type = BuiltinNodeTypes.ITERATION - execution_type = NodeExecutionType.CONTAINER - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return { - "type": "iteration", - "config": { - "is_parallel": False, - "parallel_nums": 10, - "error_handle_mode": ErrorHandleMode.TERMINATED, - "flatten_output": True, - }, - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # type: ignore - variable = self._get_iterator_variable() - - if self._is_empty_iteration(variable): - yield from self._handle_empty_iteration(variable) - return - - iterator_list_value = self._validate_and_get_iterator_list(variable) - inputs = {"iterator_selector": iterator_list_value} - - self._validate_start_node() - - started_at = datetime.now(UTC).replace(tzinfo=None) - iter_run_map: dict[str, float] = {} - outputs: list[object] = [] - usage_accumulator = [LLMUsage.empty_usage()] - - yield IterationStartedEvent( - start_at=started_at, - inputs=inputs, - metadata={"iteration_length": len(iterator_list_value)}, - ) - - try: - yield from self._execute_iterations( - iterator_list_value=iterator_list_value, - outputs=outputs, - iter_run_map=iter_run_map, - usage_accumulator=usage_accumulator, - ) - - self._accumulate_usage(usage_accumulator[0]) - yield from self._handle_iteration_success( - started_at=started_at, - inputs=inputs, - outputs=outputs, - iterator_list_value=iterator_list_value, - iter_run_map=iter_run_map, - usage=usage_accumulator[0], - ) - except IterationNodeError as e: - self._accumulate_usage(usage_accumulator[0]) - yield from self._handle_iteration_failure( - started_at=started_at, - inputs=inputs, - outputs=outputs, - iterator_list_value=iterator_list_value, - iter_run_map=iter_run_map, - usage=usage_accumulator[0], - error=e, - ) - - def _get_iterator_variable(self) -> ArraySegment | NoneSegment: - variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) - - if not variable: - raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found") - - if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment): - raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") - - return variable - - def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]: - return isinstance(variable, NoneSegment) or len(variable.value) == 0 - - def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]: - # Try our best to preserve the type information. - if isinstance(variable, ArraySegment): - output = variable.model_copy(update={"value": []}) - else: - output = ArrayAnySegment(value=[]) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - # TODO(QuantumGhost): is it possible to compute the type of `output` - # from graph definition? - outputs={"output": output}, - ) - ) - - def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]: - iterator_list_value = variable.to_object() - - if not isinstance(iterator_list_value, list): - raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") - - return cast(list[object], iterator_list_value) - - def _validate_start_node(self) -> None: - if not self.node_data.start_node_id: - raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") - - def _execute_iterations( - self, - iterator_list_value: Sequence[object], - outputs: list[object], - iter_run_map: dict[str, float], - usage_accumulator: list[LLMUsage], - ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: - if self.node_data.is_parallel: - # Parallel mode execution - yield from self._execute_parallel_iterations( - iterator_list_value=iterator_list_value, - outputs=outputs, - iter_run_map=iter_run_map, - usage_accumulator=usage_accumulator, - ) - else: - # Sequential mode execution - for index, item in enumerate(iterator_list_value): - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - yield IterationNextEvent(index=index) - - graph_engine = self._create_graph_engine(index, item) - - # Run the iteration - try: - yield from self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs, - graph_engine=graph_engine, - ) - finally: - self._merge_graph_engine_usage(usage_accumulator=usage_accumulator, graph_engine=graph_engine) - iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() - - def _execute_parallel_iterations( - self, - iterator_list_value: Sequence[object], - outputs: list[object], - iter_run_map: dict[str, float], - usage_accumulator: list[LLMUsage], - ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: - # Initialize outputs list with None values to maintain order - outputs.extend([None] * len(iterator_list_value)) - - # Determine the number of parallel workers - max_workers = min(self.node_data.parallel_nums, len(iterator_list_value)) - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - # Submit all iteration tasks - started_child_engines: dict[int, GraphEngine] = {} - started_child_engines_lock = Lock() - merged_usage_indexes: set[int] = set() - future_to_index: dict[ - Future[ - tuple[ - float, - list[GraphNodeEventBase], - object | None, - LLMUsage, - ] - ], - int, - ] = {} - for index, item in enumerate(iterator_list_value): - yield IterationNextEvent(index=index) - future = executor.submit( - self._execute_tracked_iteration_parallel, - index=index, - item=item, - started_child_engines=started_child_engines, - started_child_engines_lock=started_child_engines_lock, - ) - future_to_index[future] = index - - # Process completed iterations as they finish - for future in as_completed(future_to_index): - index = future_to_index[future] - try: - result = future.result() - ( - iteration_duration, - events, - output_value, - iteration_usage, - ) = result - - # Update outputs at the correct index - outputs[index] = output_value - - # Yield all events from this iteration - yield from events - - # The worker computes duration before we replay buffered events here, - # so slow downstream consumers don't inflate per-iteration timing. - iter_run_map[str(index)] = iteration_duration - - usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage) - merged_usage_indexes.add(index) - - except Exception as e: - if index not in merged_usage_indexes: - self._merge_graph_engine_usage( - usage_accumulator=usage_accumulator, - graph_engine=started_child_engines.get(index), - ) - merged_usage_indexes.add(index) - if isinstance(e, ChildGraphAbortedError): - self._abort_parallel_siblings( - future_to_index=future_to_index, - current_future=future, - started_child_engines=started_child_engines, - reason=str(e) or _DEFAULT_CHILD_ABORT_REASON, - ) - self._drain_parallel_siblings( - future_to_index=future_to_index, - current_future=future, - started_child_engines=started_child_engines, - usage_accumulator=usage_accumulator, - merged_usage_indexes=merged_usage_indexes, - ) - raise e - - # Handle errors based on error_handle_mode - match self.node_data.error_handle_mode: - case ErrorHandleMode.TERMINATED: - # Cancel remaining futures and re-raise - for f in future_to_index: - if f != future: - f.cancel() - raise IterationNodeError(str(e)) - case ErrorHandleMode.CONTINUE_ON_ERROR: - outputs[index] = None - case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: - outputs[index] = None # Will be filtered later - - # Remove None values if in REMOVE_ABNORMAL_OUTPUT mode - if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: - outputs[:] = [output for output in outputs if output is not None] - - @staticmethod - def _merge_graph_engine_usage( - *, - usage_accumulator: list[LLMUsage], - graph_engine: "GraphEngine | None", - ) -> None: - if graph_engine is None: - return - usage_accumulator[0] = IterationNode._merge_usage( - usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage - ) - - def _abort_parallel_siblings( - self, - *, - future_to_index: Mapping[Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], int], - current_future: Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], - started_child_engines: Mapping[int, "GraphEngine"], - reason: str, - ) -> None: - for future, index in future_to_index.items(): - if future == current_future: - continue - - graph_engine = started_child_engines.get(index) - if graph_engine is not None: - graph_engine.request_abort(reason) - - future.cancel() - - def _drain_parallel_siblings( - self, - *, - future_to_index: Mapping[Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], int], - current_future: Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], - started_child_engines: Mapping[int, "GraphEngine"], - usage_accumulator: list[LLMUsage], - merged_usage_indexes: set[int], - ) -> None: - for future, index in future_to_index.items(): - if future == current_future: - continue - if future.cancelled(): - continue - - with suppress(Exception): - future.result() - - if index in merged_usage_indexes: - continue - - self._merge_graph_engine_usage( - usage_accumulator=usage_accumulator, - graph_engine=started_child_engines.get(index), - ) - merged_usage_indexes.add(index) - - def _execute_tracked_iteration_parallel( - self, - *, - index: int, - item: object, - started_child_engines: dict[int, "GraphEngine"], - started_child_engines_lock: Lock, - ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: - graph_engine = self._create_graph_engine(index, item) - with started_child_engines_lock: - started_child_engines[index] = graph_engine - - return self._execute_parallel_iteration_with_graph_engine( - index=index, - graph_engine=graph_engine, - ) - - def _execute_single_iteration_parallel( - self, - index: int, - item: object, - ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: - """Execute a single iteration in parallel mode and return results.""" - graph_engine = self._create_graph_engine(index, item) - return self._execute_parallel_iteration_with_graph_engine(index=index, graph_engine=graph_engine) - - def _execute_parallel_iteration_with_graph_engine( - self, - *, - index: int, - graph_engine: "GraphEngine", - ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: - """Execute a prepared child engine in parallel mode and return results.""" - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - events: list[GraphNodeEventBase] = [] - outputs_temp: list[object] = [] - - # Collect events instead of yielding them directly - for event in self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs_temp, - graph_engine=graph_engine, - ): - events.append(event) - - # Get the output value from the temporary outputs list - output_value = outputs_temp[0] if outputs_temp else None - iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() - - return ( - iteration_duration, - events, - output_value, - graph_engine.graph_runtime_state.llm_usage, - ) - - def _handle_iteration_success( - self, - started_at: datetime, - inputs: dict[str, Sequence[object]], - outputs: list[object], - iterator_list_value: Sequence[object], - iter_run_map: dict[str, float], - *, - usage: LLMUsage, - ) -> Generator[NodeEventBase, None, None]: - # Flatten the list of lists if all outputs are lists - flattened_outputs = self._flatten_outputs_if_needed(outputs) - - yield IterationSucceededEvent( - start_at=started_at, - inputs=inputs, - outputs={"output": flattened_outputs}, - steps=len(iterator_list_value), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, - }, - ) - - # Yield final success event - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": flattened_outputs}, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - ) - - def _flatten_outputs_if_needed(self, outputs: list[object]) -> list[object]: - """ - Flatten the outputs list if all elements are lists. - This maintains backward compatibility with version 1.8.1 behavior. - - If flatten_output is False, returns outputs as-is (nested structure). - If flatten_output is True (default), flattens the list if all elements are lists. - """ - # If flatten_output is disabled, return outputs as-is - if not self.node_data.flatten_output: - return outputs - - if not outputs: - return outputs - - # Check if all non-None outputs are lists - non_none_outputs: list[object] = [output for output in outputs if output is not None] - if not non_none_outputs: - return outputs - - if all(isinstance(output, list) for output in non_none_outputs): - # Flatten the list of lists - flattened: list[Any] = [] - for output in outputs: - if isinstance(output, list): - flattened.extend(output) - elif output is not None: - # This shouldn't happen based on our check, but handle it gracefully - flattened.append(output) - return flattened - - return outputs - - def _handle_iteration_failure( - self, - started_at: datetime, - inputs: dict[str, Sequence[object]], - outputs: list[object], - iterator_list_value: Sequence[object], - iter_run_map: dict[str, float], - *, - usage: LLMUsage, - error: IterationNodeError, - ) -> Generator[NodeEventBase, None, None]: - # Flatten the list of lists if all outputs are lists (even in failure case) - flattened_outputs = self._flatten_outputs_if_needed(outputs) - - yield IterationFailedEvent( - start_at=started_at, - inputs=inputs, - outputs={"output": flattened_outputs}, - steps=len(iterator_list_value), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, - }, - error=str(error), - ) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(error), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IterationNodeData, - ) -> Mapping[str, Sequence[str]]: - variable_mapping: dict[str, Sequence[str]] = { - f"{node_id}.input_selector": node_data.iterator_selector, - } - iteration_node_ids = set() - - # Find all nodes that belong to this loop - nodes = graph_config.get("nodes", []) - for node in nodes: - node_config_data = node.get("data", {}) - if node_config_data.get("iteration_id") == node_id: - in_iteration_node_id = node.get("id") - if in_iteration_node_id: - iteration_node_ids.add(in_iteration_node_id) - - # Get node configs from graph_config instead of non-existent node_id_config_mapping - node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} - for sub_node_id, sub_node_config in node_configs.items(): - if sub_node_config.get("data", {}).get("iteration_id") != node_id: - continue - - # variable selector to variable mapping - try: - typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) - node_type = typed_sub_node_config["data"].type - node_mapping = Node.get_node_type_classes_mapping() - if node_type not in node_mapping: - continue - node_version = str(typed_sub_node_config["data"].version) - node_cls = node_mapping[node_type][node_version] - - sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, config=typed_sub_node_config - ) - sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) - except NotImplementedError: - sub_node_variable_mapping = {} - - # remove iteration variables - sub_node_variable_mapping = { - sub_node_id + "." + key: value - for key, value in sub_node_variable_mapping.items() - if value[0] != node_id - } - - variable_mapping.update(sub_node_variable_mapping) - - # remove variable out from iteration - variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids} - - return variable_mapping - - def _append_iteration_info_to_event( - self, - event: GraphNodeEventBase, - iter_run_index: int, - ): - event.in_iteration_id = self._node_id - iter_metadata = { - WorkflowNodeExecutionMetadataKey.ITERATION_ID: self._node_id, - WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index, - } - - current_metadata = event.node_run_result.metadata - if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata: - event.node_run_result.metadata = {**current_metadata, **iter_metadata} - - def _run_single_iter( - self, - *, - variable_pool: VariablePool, - outputs: list[object], - graph_engine: "GraphEngine", - ) -> Generator[GraphNodeEventBase, None, None]: - rst = graph_engine.run() - # get current iteration index - index_variable = variable_pool.get([self._node_id, "index"]) - if not isinstance(index_variable, IntegerVariable): - raise IterationIndexNotFoundError(f"iteration {self._node_id} current index not found") - current_index = index_variable.value - for event in rst: - if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.ITERATION_START: - continue - - if isinstance(event, GraphNodeEventBase): - self._append_iteration_info_to_event(event=event, iter_run_index=current_index) - yield event - elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)): - result = variable_pool.get(self.node_data.output_selector) - if result is None: - outputs.append(None) - else: - outputs.append(result.to_object()) - return - elif isinstance(event, GraphRunAbortedEvent): - raise ChildGraphAbortedError(event.reason or _DEFAULT_CHILD_ABORT_REASON) - elif isinstance(event, GraphRunFailedEvent): - match self.node_data.error_handle_mode: - case ErrorHandleMode.TERMINATED: - raise IterationNodeError(event.error) - case ErrorHandleMode.CONTINUE_ON_ERROR: - outputs.append(None) - return - case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: - return - - def _create_graph_engine(self, index: int, item: object): - from graphon.entities import GraphInitParams - from graphon.runtime import ChildGraphNotFoundError - - # Create GraphInitParams for child graph execution. - graph_init_params = GraphInitParams( - workflow_id=self.workflow_id, - graph_config=self.graph_config, - run_context=self.run_context, - call_depth=self.workflow_call_depth, - ) - # Create a deep copy of the variable pool for each iteration - variable_pool_copy = self.graph_runtime_state.variable_pool.model_copy(deep=True) - - # append iteration variable (item, index) to variable pool - variable_pool_copy.add([self._node_id, "index"], index) - variable_pool_copy.add([self._node_id, "item"], item) - root_node_id = self.node_data.start_node_id - if root_node_id is None: - raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") - - try: - return self.graph_runtime_state.create_child_engine( - workflow_id=self.workflow_id, - graph_init_params=graph_init_params, - root_node_id=root_node_id, - variable_pool=variable_pool_copy, - ) - except ChildGraphNotFoundError as exc: - raise IterationGraphNotFoundError("iteration graph not found") from exc diff --git a/api/graphon/nodes/iteration/iteration_start_node.py b/api/graphon/nodes/iteration/iteration_start_node.py deleted file mode 100644 index 3a44d3d81d..0000000000 --- a/api/graphon/nodes/iteration/iteration_start_node.py +++ /dev/null @@ -1,22 +0,0 @@ -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.iteration.entities import IterationStartNodeData - - -class IterationStartNode(Node[IterationStartNodeData]): - """ - Iteration Start Node. - """ - - node_type = BuiltinNodeTypes.ITERATION_START - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run the node. - """ - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) diff --git a/api/graphon/nodes/list_operator/__init__.py b/api/graphon/nodes/list_operator/__init__.py deleted file mode 100644 index 1877586ef4..0000000000 --- a/api/graphon/nodes/list_operator/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .node import ListOperatorNode - -__all__ = ["ListOperatorNode"] diff --git a/api/graphon/nodes/list_operator/entities.py b/api/graphon/nodes/list_operator/entities.py deleted file mode 100644 index 0db1c75cdd..0000000000 --- a/api/graphon/nodes/list_operator/entities.py +++ /dev/null @@ -1,71 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum - -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - - -class FilterOperator(StrEnum): - # string conditions - CONTAINS = "contains" - START_WITH = "start with" - END_WITH = "end with" - IS = "is" - IN = "in" - EMPTY = "empty" - NOT_CONTAINS = "not contains" - IS_NOT = "is not" - NOT_IN = "not in" - NOT_EMPTY = "not empty" - # number conditions - EQUAL = "=" - NOT_EQUAL = "≠" - LESS_THAN = "<" - GREATER_THAN = ">" - GREATER_THAN_OR_EQUAL = "≥" - LESS_THAN_OR_EQUAL = "≤" - - -class Order(StrEnum): - ASC = "asc" - DESC = "desc" - - -class FilterCondition(BaseModel): - key: str = "" - comparison_operator: FilterOperator = FilterOperator.CONTAINS - # the value is bool if the filter operator is comparing with - # a boolean constant. - value: str | Sequence[str] | bool = "" - - -class FilterBy(BaseModel): - enabled: bool = False - conditions: Sequence[FilterCondition] = Field(default_factory=list) - - -class OrderByConfig(BaseModel): - enabled: bool = False - key: str = "" - value: Order = Order.ASC - - -class Limit(BaseModel): - enabled: bool = False - size: int = -1 - - -class ExtractConfig(BaseModel): - enabled: bool = False - serial: str = "1" - - -class ListOperatorNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.LIST_OPERATOR - variable: Sequence[str] = Field(default_factory=list) - filter_by: FilterBy - order_by: OrderByConfig - limit: Limit - extract_by: ExtractConfig = Field(default_factory=ExtractConfig) diff --git a/api/graphon/nodes/list_operator/exc.py b/api/graphon/nodes/list_operator/exc.py deleted file mode 100644 index f88aa0be29..0000000000 --- a/api/graphon/nodes/list_operator/exc.py +++ /dev/null @@ -1,16 +0,0 @@ -class ListOperatorError(ValueError): - """Base class for all ListOperator errors.""" - - pass - - -class InvalidFilterValueError(ListOperatorError): - pass - - -class InvalidKeyError(ListOperatorError): - pass - - -class InvalidConditionError(ListOperatorError): - pass diff --git a/api/graphon/nodes/list_operator/node.py b/api/graphon/nodes/list_operator/node.py deleted file mode 100644 index dad17a8f4a..0000000000 --- a/api/graphon/nodes/list_operator/node.py +++ /dev/null @@ -1,345 +0,0 @@ -from collections.abc import Callable, Sequence -from typing import Any, TypeAlias, TypeVar - -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.file import File -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment -from graphon.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment - -from .entities import FilterOperator, ListOperatorNodeData, Order -from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError - -_SUPPORTED_TYPES_TUPLE = ( - ArrayFileSegment, - ArrayNumberSegment, - ArrayStringSegment, - ArrayBooleanSegment, -) -_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment - - -_T = TypeVar("_T") - - -def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]: - """Returns the negation of a given filter function. If the original filter - returns `True` for a value, the negated filter will return `False`, and vice versa. - """ - - def wrapper(value: _T) -> bool: - return not filter_(value) - - return wrapper - - -class ListOperatorNode(Node[ListOperatorNodeData]): - node_type = BuiltinNodeTypes.LIST_OPERATOR - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self): - inputs: dict[str, Sequence[object]] = {} - process_data: dict[str, Sequence[object]] = {} - outputs: dict[str, Any] = {} - - variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) - if variable is None: - error_message = f"Variable not found for selector: {self.node_data.variable}" - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs - ) - if not variable.value: - inputs = {"variable": []} - process_data = {"variable": []} - if isinstance(variable, ArraySegment): - result = variable.model_copy(update={"value": []}) - else: - result = ArrayAnySegment(value=[]) - outputs = {"result": result, "first_record": None, "last_record": None} - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) - if not isinstance(variable, _SUPPORTED_TYPES_TUPLE): - error_message = f"Variable {self.node_data.variable} is not an array type, actual type: {type(variable)}" - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs - ) - - if isinstance(variable, ArrayFileSegment): - inputs = {"variable": [item.to_dict() for item in variable.value]} - process_data["variable"] = [item.to_dict() for item in variable.value] - else: - inputs = {"variable": variable.value} - process_data["variable"] = variable.value - - try: - # Filter - if self.node_data.filter_by.enabled: - variable = self._apply_filter(variable) - - # Extract - if self.node_data.extract_by.enabled: - variable = self._extract_slice(variable) - - # Order - if self.node_data.order_by.enabled: - variable = self._apply_order(variable) - - # Slice - if self.node_data.limit.enabled: - variable = self._apply_slice(variable) - - outputs = { - "result": variable, - "first_record": variable.value[0] if variable.value else None, - "last_record": variable.value[-1] if variable.value else None, - } - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) - except ListOperatorError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) - - def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - filter_func: Callable[[Any], bool] - result: list[Any] = [] - for condition in self.node_data.filter_by.conditions: - if isinstance(variable, ArrayStringSegment): - if not isinstance(condition.value, str): - raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayNumberSegment): - if not isinstance(condition.value, str): - raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value)) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayFileSegment): - if isinstance(condition.value, str): - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - elif isinstance(condition.value, bool): - raise ValueError(f"File filter expects a string value, got {type(condition.value)}") - else: - value = condition.value - filter_func = _get_file_filter_func( - key=condition.key, - condition=condition.comparison_operator, - value=value, - ) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - else: - if not isinstance(condition.value, bool): - raise ValueError(f"Boolean filter expects a boolean value, got {type(condition.value)}") - filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - return variable - - def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)): - result = sorted(variable.value, reverse=self.node_data.order_by.value == Order.DESC) - variable = variable.model_copy(update={"value": result}) - else: - result = _order_file( - order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value - ) - variable = variable.model_copy(update={"value": result}) - - return variable - - def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - result = variable.value[: self.node_data.limit.size] - return variable.model_copy(update={"value": result}) - - def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text) - if value < 1: - raise ValueError(f"Invalid serial index: must be >= 1, got {value}") - if value > len(variable.value): - raise InvalidKeyError(f"Invalid serial index: must be <= {len(variable.value)}, got {value}") - value -= 1 - result = variable.value[value] - return variable.model_copy(update={"value": [result]}) - - -def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]: - match key: - case "size": - return lambda x: x.size - case _: - raise InvalidKeyError(f"Invalid key: {key}") - - -def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: - match key: - case "name": - return lambda x: x.filename or "" - case "type": - return lambda x: str(x.type) - case "extension": - return lambda x: x.extension or "" - case "mime_type": - return lambda x: x.mime_type or "" - case "transfer_method": - return lambda x: str(x.transfer_method) - case "url": - return lambda x: x.remote_url or "" - case "related_id": - return lambda x: x.related_id or "" - case _: - raise InvalidKeyError(f"Invalid key: {key}") - - -def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]: - match condition: - case "contains": - return _contains(value) - case "start with": - return _startswith(value) - case "end with": - return _endswith(value) - case "is": - return _is(value) - case "in": - return _in(value) - case "empty": - return lambda x: x == "" - case "not contains": - return _negation(_contains(value)) - case "is not": - return _negation(_is(value)) - case "not in": - return _negation(_in(value)) - case "not empty": - return lambda x: x != "" - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]: - match condition: - case "in": - return _in(value) - case "not in": - return _negation(_in(value)) - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]: - match condition: - case "=": - return _eq(value) - case "≠": - return _ne(value) - case "<": - return _lt(value) - case "≤": - return _le(value) - case ">": - return _gt(value) - case "≥": - return _ge(value) - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]: - match condition: - case FilterOperator.IS: - return _is(value) - case FilterOperator.IS_NOT: - return _negation(_is(value)) - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: - if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str): - extract_func = _get_file_extract_string_func(key=key) - return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) - if key in {"type", "transfer_method"}: - extract_func = _get_file_extract_string_func(key=key) - return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x)) - elif key == "size" and isinstance(value, str): - extract_number = _get_file_extract_number_func(key=key) - return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_number(x)) - else: - raise InvalidKeyError(f"Invalid key: {key}") - - -def _contains(value: str) -> Callable[[str], bool]: - return lambda x: value in x - - -def _startswith(value: str) -> Callable[[str], bool]: - return lambda x: x.startswith(value) - - -def _endswith(value: str) -> Callable[[str], bool]: - return lambda x: x.endswith(value) - - -def _is(value: _T) -> Callable[[_T], bool]: - return lambda x: x == value - - -def _in(value: str | Sequence[str]) -> Callable[[str], bool]: - return lambda x: x in value - - -def _eq(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x == value - - -def _ne(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x != value - - -def _lt(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x < value - - -def _le(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x <= value - - -def _gt(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x > value - - -def _ge(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x >= value - - -def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]): - extract_func: Callable[[File], Any] - if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url", "related_id"}: - extract_func = _get_file_extract_string_func(key=order_by) - return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC) - elif order_by == "size": - extract_func = _get_file_extract_number_func(key=order_by) - return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC) - else: - raise InvalidKeyError(f"Invalid order key: {order_by}") diff --git a/api/graphon/nodes/llm/__init__.py b/api/graphon/nodes/llm/__init__.py deleted file mode 100644 index f7bc713f63..0000000000 --- a/api/graphon/nodes/llm/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from .entities import ( - LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from .node import LLMNode - -__all__ = [ - "LLMNode", - "LLMNodeChatModelMessage", - "LLMNodeCompletionModelPromptTemplate", - "LLMNodeData", - "ModelConfig", - "VisionConfig", -] diff --git a/api/graphon/nodes/llm/entities.py b/api/graphon/nodes/llm/entities.py deleted file mode 100644 index 196152548c..0000000000 --- a/api/graphon/nodes/llm/entities.py +++ /dev/null @@ -1,100 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Literal - -from pydantic import BaseModel, Field, field_validator - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.model_runtime.entities import ImagePromptMessageContent, LLMMode -from graphon.nodes.base.entities import VariableSelector -from graphon.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig - - -class ModelConfig(BaseModel): - provider: str - name: str - mode: LLMMode - completion_params: dict[str, Any] = Field(default_factory=dict) - - -class ContextConfig(BaseModel): - enabled: bool - variable_selector: list[str] | None = None - - -class VisionConfigOptions(BaseModel): - variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"]) - detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH - - -class VisionConfig(BaseModel): - enabled: bool = False - configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions) - - @field_validator("configs", mode="before") - @classmethod - def convert_none_configs(cls, v: Any): - if v is None: - return VisionConfigOptions() - return v - - -class PromptConfig(BaseModel): - jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list) - - @field_validator("jinja2_variables", mode="before") - @classmethod - def convert_none_jinja2_variables(cls, v: Any): - if v is None: - return [] - return v - - -class LLMNodeChatModelMessage(ChatModelMessage): - text: str = "" - jinja2_text: str | None = None - - -class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): - jinja2_text: str | None = None - - -class LLMNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.LLM - model: ModelConfig - prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - prompt_config: PromptConfig = Field(default_factory=PromptConfig) - memory: MemoryConfig | None = None - context: ContextConfig - vision: VisionConfig = Field(default_factory=VisionConfig) - structured_output: Mapping[str, Any] | None = None - # We used 'structured_output_enabled' in the past, but it's not a good name. - structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") - reasoning_format: Literal["separated", "tagged"] = Field( - # Keep tagged as default for backward compatibility - default="tagged", - description=( - """ - Strategy for handling model reasoning output. - - separated: Return clean text (without tags) + reasoning_content field. - Recommended for new workflows. Enables safe downstream parsing and - workflow variable access: {{#node_id.reasoning_content#}} - - tagged : Return original text (with tags) + reasoning_content field. - Maintains full backward compatibility while still providing reasoning_content - for workflow automation. Frontend thinking panels work as before. - """ - ), - ) - - @field_validator("prompt_config", mode="before") - @classmethod - def convert_none_prompt_config(cls, v: Any): - if v is None: - return PromptConfig() - return v - - @property - def structured_output_enabled(self) -> bool: - return self.structured_output_switch_on and self.structured_output is not None diff --git a/api/graphon/nodes/llm/exc.py b/api/graphon/nodes/llm/exc.py deleted file mode 100644 index 4d16095296..0000000000 --- a/api/graphon/nodes/llm/exc.py +++ /dev/null @@ -1,45 +0,0 @@ -class LLMNodeError(ValueError): - """Base class for LLM Node errors.""" - - -class VariableNotFoundError(LLMNodeError): - """Raised when a required variable is not found.""" - - -class InvalidContextStructureError(LLMNodeError): - """Raised when the context structure is invalid.""" - - -class InvalidVariableTypeError(LLMNodeError): - """Raised when the variable type is invalid.""" - - -class ModelNotExistError(LLMNodeError): - """Raised when the specified model does not exist.""" - - -class LLMModeRequiredError(LLMNodeError): - """Raised when LLM mode is required but not provided.""" - - -class NoPromptFoundError(LLMNodeError): - """Raised when no prompt is found in the LLM configuration.""" - - -class TemplateTypeNotSupportError(LLMNodeError): - def __init__(self, *, type_name: str): - super().__init__(f"Prompt type {type_name} is not supported.") - - -class MemoryRolePrefixRequiredError(LLMNodeError): - """Raised when memory role prefix is required for completion model.""" - - -class FileTypeNotSupportError(LLMNodeError): - def __init__(self, *, type_name: str): - super().__init__(f"{type_name} type is not supported by this model") - - -class UnsupportedPromptContentTypeError(LLMNodeError): - def __init__(self, *, type_name: str): - super().__init__(f"Prompt content type {type_name} is not supported.") diff --git a/api/graphon/nodes/llm/file_saver.py b/api/graphon/nodes/llm/file_saver.py deleted file mode 100644 index 0bedb42f3a..0000000000 --- a/api/graphon/nodes/llm/file_saver.py +++ /dev/null @@ -1,139 +0,0 @@ -import mimetypes -import typing as tp - -from graphon.file import File, FileTransferMethod, FileType -from graphon.file.constants import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE -from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol - - -class LLMFileSaver(tp.Protocol): - """LLMFileSaver is responsible for save multimodal output returned by - LLM. - """ - - def save_binary_string( - self, - data: bytes, - mime_type: str, - file_type: FileType, - extension_override: str | None = None, - ) -> File: - """save_binary_string saves the inline file data returned by LLM. - - Currently (2025-04-30), only some of Google Gemini models will return - multimodal output as inline data. - - :param data: the contents of the file - :param mime_type: the media type of the file, specified by rfc6838 - (https://datatracker.ietf.org/doc/html/rfc6838) - :param file_type: The file type of the inline file. - :param extension_override: Override the auto-detected file extension while saving this file. - - The default value is `None`, which means do not override the file extension and guessing it - from the `mime_type` attribute while saving the file. - - Setting it to values other than `None` means override the file's extension, and - will bypass the extension guessing saving the file. - - Specially, setting it to empty string (`""`) will leave the file extension empty. - - When it is not `None` or empty string (`""`), it should be a string beginning with a - dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py` - and `tar.gz` are not. - """ - raise NotImplementedError() - - def save_remote_url(self, url: str, file_type: FileType) -> File: - """save_remote_url saves the file from a remote url returned by LLM. - - Currently (2025-04-30), no model returns multimodel output as a url. - - :param url: the url of the file. - :param file_type: the file type of the file, check `FileType` enum for reference. - """ - raise NotImplementedError() - - -class FileSaverImpl(LLMFileSaver): - _tool_file_manager: ToolFileManagerProtocol - _file_reference_factory: FileReferenceFactoryProtocol - - def __init__( - self, - *, - tool_file_manager: ToolFileManagerProtocol, - file_reference_factory: FileReferenceFactoryProtocol, - http_client: HttpClientProtocol, - ): - self._tool_file_manager = tool_file_manager - self._file_reference_factory = file_reference_factory - self._http_client = http_client - - def save_remote_url(self, url: str, file_type: FileType) -> File: - http_response = self._http_client.get(url) - http_response.raise_for_status() - data = http_response.content - mime_type_from_header = http_response.headers.get("Content-Type") - mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header) - return self.save_binary_string(data, mime_type, file_type, extension_override=extension) - - def save_binary_string( - self, - data: bytes, - mime_type: str, - file_type: FileType, - extension_override: str | None = None, - ) -> File: - tool_file = self._tool_file_manager.create_file_by_raw( - file_binary=data, - mimetype=mime_type, - ) - extension_override = _validate_extension_override(extension_override) - extension = _get_extension(mime_type, extension_override) - return self._file_reference_factory.build_from_mapping( - mapping={ - "type": file_type, - "transfer_method": FileTransferMethod.TOOL_FILE, - "filename": tool_file.name, - "extension": extension, - "mime_type": mime_type, - "size": len(data), - "tool_file_id": str(tool_file.id), - "related_id": str(tool_file.id), - "storage_key": tool_file.file_key, - } - ) - - -def _get_extension(mime_type: str, extension_override: str | None = None) -> str: - """get_extension return the extension of file. - - If the `extension_override` parameter is set, this function should honor it and - return its value. - """ - if extension_override is not None: - return extension_override - return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION - - -def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]: - """_extract_content_type_and_extension tries to - guess content type of file from url and `Content-Type` header in response. - """ - if content_type_header: - extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION - return content_type_header, extension - content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE - extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION - return content_type, extension - - -def _validate_extension_override(extension_override: str | None) -> str | None: - # `extension_override` is allow to be `None or `""`. - if extension_override is None: - return None - if extension_override == "": - return "" - if not extension_override.startswith("."): - raise ValueError("extension_override should start with '.' if not None or empty.", extension_override) - return extension_override diff --git a/api/graphon/nodes/llm/llm_utils.py b/api/graphon/nodes/llm/llm_utils.py deleted file mode 100644 index 11a1d83a9d..0000000000 --- a/api/graphon/nodes/llm/llm_utils.py +++ /dev/null @@ -1,545 +0,0 @@ -from __future__ import annotations - -import json -import logging -import re -from collections.abc import Mapping, Sequence -from typing import Any - -from graphon.file import FileType, file_manager -from graphon.file.models import File -from graphon.model_runtime.entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentType, - PromptMessageRole, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessageContentUnionTypes, - SystemPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey -from graphon.model_runtime.memory import PromptMessageMemory -from graphon.nodes.base.entities import VariableSelector -from graphon.runtime import VariablePool -from graphon.template_rendering import Jinja2TemplateRenderer -from graphon.variables import ArrayFileSegment, FileSegment -from graphon.variables.segments import ArrayAnySegment, NoneSegment - -from .entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig -from .exc import ( - InvalidVariableTypeError, - MemoryRolePrefixRequiredError, - NoPromptFoundError, - TemplateTypeNotSupportError, -) -from .runtime_protocols import PreparedLLMProtocol - -CONTEXT_PLACEHOLDER = "{{#context#}}" - -logger = logging.getLogger(__name__) - -VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}") -MAX_RESOLVED_VALUE_LENGTH = 1024 - - -def fetch_model_schema(*, model_instance: PreparedLLMProtocol) -> AIModelEntity: - model_schema = model_instance.get_model_schema() - if not model_schema: - raise ValueError(f"Model schema not found for {getattr(model_instance, 'model_name', 'unknown model')}") - return model_schema - - -def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence[File]: - variable = variable_pool.get(selector) - if variable is None: - return [] - elif isinstance(variable, FileSegment): - return [variable.value] - elif isinstance(variable, ArrayFileSegment): - return variable.value - elif isinstance(variable, NoneSegment | ArrayAnySegment): - return [] - raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}") - - -def convert_history_messages_to_text( - *, - history_messages: Sequence[PromptMessage], - human_prefix: str, - ai_prefix: str, -) -> str: - string_messages: list[str] = [] - for message in history_messages: - if message.role == PromptMessageRole.USER: - role = human_prefix - elif message.role == PromptMessageRole.ASSISTANT: - role = ai_prefix - else: - continue - - if isinstance(message.content, list): - content_parts = [] - for content in message.content: - if isinstance(content, TextPromptMessageContent): - content_parts.append(content.data) - elif isinstance(content, ImagePromptMessageContent): - content_parts.append("[image]") - - inner_msg = "\n".join(content_parts) - string_messages.append(f"{role}: {inner_msg}") - else: - string_messages.append(f"{role}: {message.content}") - - return "\n".join(string_messages) - - -def fetch_memory_text( - *, - memory: PromptMessageMemory, - max_token_limit: int, - message_limit: int | None = None, - human_prefix: str = "Human", - ai_prefix: str = "Assistant", -) -> str: - history_messages = memory.get_history_prompt_messages( - max_token_limit=max_token_limit, - message_limit=message_limit, - ) - return convert_history_messages_to_text( - history_messages=history_messages, - human_prefix=human_prefix, - ai_prefix=ai_prefix, - ) - - -def fetch_prompt_messages( - *, - sys_query: str | None = None, - sys_files: Sequence[File], - context: str = "", - memory: PromptMessageMemory | None = None, - model_instance: PreparedLLMProtocol, - prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, - stop: Sequence[str] | None = None, - memory_config: MemoryConfig | None = None, - vision_enabled: bool = False, - vision_detail: ImagePromptMessageContent.DETAIL, - variable_pool: VariablePool, - jinja2_variables: Sequence[VariableSelector], - context_files: list[File] | None = None, - template_renderer: Jinja2TemplateRenderer | None = None, -) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: - prompt_messages: list[PromptMessage] = [] - model_schema = fetch_model_schema(model_instance=model_instance) - - if isinstance(prompt_template, list): - prompt_messages.extend( - handle_list_messages( - messages=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - vision_detail_config=vision_detail, - template_renderer=template_renderer, - ) - ) - - prompt_messages.extend( - handle_memory_chat_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - ) - - if sys_query: - prompt_messages.extend( - handle_list_messages( - messages=[ - LLMNodeChatModelMessage( - text=sys_query, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context="", - jinja2_variables=[], - variable_pool=variable_pool, - vision_detail_config=vision_detail, - template_renderer=template_renderer, - ) - ) - elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - prompt_messages.extend( - handle_completion_template( - template=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - template_renderer=template_renderer, - ) - ) - - memory_text = handle_memory_completion_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - prompt_content = prompt_messages[0].content - if isinstance(prompt_content, str): - prompt_content = str(prompt_content) - if "#histories#" in prompt_content: - prompt_content = prompt_content.replace("#histories#", memory_text) - else: - prompt_content = memory_text + "\n" + prompt_content - prompt_messages[0].content = prompt_content - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - if "#histories#" in content_item.data: - content_item.data = content_item.data.replace("#histories#", memory_text) - else: - content_item.data = memory_text + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - - if sys_query: - if isinstance(prompt_content, str): - prompt_messages[0].content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - content_item.data = sys_query + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - else: - raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) - - _append_file_prompts( - prompt_messages=prompt_messages, - files=sys_files, - vision_enabled=vision_enabled, - vision_detail=vision_detail, - ) - _append_file_prompts( - prompt_messages=prompt_messages, - files=context_files or [], - vision_enabled=vision_enabled, - vision_detail=vision_detail, - ) - - filtered_prompt_messages: list[PromptMessage] = [] - for prompt_message in prompt_messages: - if isinstance(prompt_message.content, list): - prompt_message_content: list[PromptMessageContentUnionTypes] = [] - for content_item in prompt_message.content: - if not model_schema.features: - if content_item.type == PromptMessageContentType.TEXT: - prompt_message_content.append(content_item) - continue - - if ( - ( - content_item.type == PromptMessageContentType.IMAGE - and ModelFeature.VISION not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.DOCUMENT - and ModelFeature.DOCUMENT not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.VIDEO - and ModelFeature.VIDEO not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.AUDIO - and ModelFeature.AUDIO not in model_schema.features - ) - ): - continue - prompt_message_content.append(content_item) - if not prompt_message_content: - continue - if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: - prompt_message.content = prompt_message_content[0].data - else: - prompt_message.content = prompt_message_content - filtered_prompt_messages.append(prompt_message) - elif not prompt_message.is_empty(): - filtered_prompt_messages.append(prompt_message) - - if len(filtered_prompt_messages) == 0: - raise NoPromptFoundError( - "No prompt found in the LLM configuration. Please ensure a prompt is properly configured before proceeding." - ) - - return filtered_prompt_messages, stop - - -def handle_list_messages( - *, - messages: Sequence[LLMNodeChatModelMessage], - context: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - vision_detail_config: ImagePromptMessageContent.DETAIL, - template_renderer: Jinja2TemplateRenderer | None = None, -) -> Sequence[PromptMessage]: - prompt_messages: list[PromptMessage] = [] - for message in messages: - if message.edition_type == "jinja2": - result_text = render_jinja2_message( - template=message.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - template_renderer=template_renderer, - ) - prompt_messages.append( - combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], - role=message.role, - ) - ) - continue - - template = message.text.replace(CONTEXT_PLACEHOLDER, context) - segment_group = variable_pool.convert_template(template) - file_contents: list[PromptMessageContentUnionTypes] = [] - for segment in segment_group.value: - if isinstance(segment, ArrayFileSegment): - for file in segment.value: - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_contents.append( - file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config) - ) - elif isinstance(segment, FileSegment): - file = segment.value - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_contents.append( - file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config) - ) - - if segment_group.text: - prompt_messages.append( - combine_message_content_with_role( - contents=[TextPromptMessageContent(data=segment_group.text)], - role=message.role, - ) - ) - if file_contents: - prompt_messages.append(combine_message_content_with_role(contents=file_contents, role=message.role)) - - return prompt_messages - - -def render_jinja2_message( - *, - template: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - template_renderer: Jinja2TemplateRenderer | None = None, -) -> str: - if not template: - return "" - if template_renderer is None: - raise ValueError("template_renderer is required for jinja2 prompt rendering") - - jinja2_inputs: dict[str, Any] = {} - for jinja2_variable in jinja2_variables: - variable = variable_pool.get(jinja2_variable.value_selector) - jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" - return template_renderer.render_template(template, jinja2_inputs) - - -def handle_completion_template( - *, - template: LLMNodeCompletionModelPromptTemplate, - context: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - template_renderer: Jinja2TemplateRenderer | None = None, -) -> Sequence[PromptMessage]: - if template.edition_type == "jinja2": - result_text = render_jinja2_message( - template=template.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - template_renderer=template_renderer, - ) - else: - template_text = template.text.replace(CONTEXT_PLACEHOLDER, context) - result_text = variable_pool.convert_template(template_text).text - return [ - combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], - role=PromptMessageRole.USER, - ) - ] - - -def combine_message_content_with_role( - *, - contents: str | list[PromptMessageContentUnionTypes] | None = None, - role: PromptMessageRole, -) -> PromptMessage: - match role: - case PromptMessageRole.USER: - return UserPromptMessage(content=contents) - case PromptMessageRole.ASSISTANT: - return AssistantPromptMessage(content=contents) - case PromptMessageRole.SYSTEM: - return SystemPromptMessage(content=contents) - case _: - raise NotImplementedError(f"Role {role} is not supported") - - -def calculate_rest_token( - *, - prompt_messages: list[PromptMessage], - model_instance: PreparedLLMProtocol, -) -> int: - rest_tokens = 2000 - runtime_model_schema = fetch_model_schema(model_instance=model_instance) - runtime_model_parameters = model_instance.parameters - - model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - max_tokens = 0 - for parameter_rule in runtime_model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - runtime_model_parameters.get(parameter_rule.name) - or runtime_model_parameters.get(str(parameter_rule.use_template)) - or 0 - ) - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - -def handle_memory_chat_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - model_instance: PreparedLLMProtocol, -) -> Sequence[PromptMessage]: - if not memory or not memory_config: - return [] - rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance) - return memory.get_history_prompt_messages( - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - ) - - -def handle_memory_completion_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - model_instance: PreparedLLMProtocol, -) -> str: - if not memory or not memory_config: - return "" - - rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance) - if not memory_config.role_prefix: - raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") - - return fetch_memory_text( - memory=memory, - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - human_prefix=memory_config.role_prefix.user, - ai_prefix=memory_config.role_prefix.assistant, - ) - - -def _append_file_prompts( - *, - prompt_messages: list[PromptMessage], - files: Sequence[File], - vision_enabled: bool, - vision_detail: ImagePromptMessageContent.DETAIL, -) -> None: - if not vision_enabled or not files: - return - - file_prompts = [file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in files] - if ( - prompt_messages - and isinstance(prompt_messages[-1], UserPromptMessage) - and isinstance(prompt_messages[-1].content, list) - ): - existing_contents = prompt_messages[-1].content - assert isinstance(existing_contents, list) - prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents) - else: - prompt_messages.append(UserPromptMessage(content=file_prompts)) - - -def _coerce_resolved_value(raw: str) -> int | float | bool | str: - """Try to restore the original type from a resolved template string. - - Variable references are always resolved to text, but completion params may - expect numeric or boolean values (e.g. a variable that holds "0.7" mapped to - the ``temperature`` parameter). This helper attempts a JSON parse so that - ``"0.7"`` → ``0.7``, ``"true"`` → ``True``, etc. Plain strings that are not - valid JSON literals are returned as-is. - """ - stripped = raw.strip() - if not stripped: - return raw - - try: - parsed: object = json.loads(stripped) - except (json.JSONDecodeError, ValueError): - return raw - - if isinstance(parsed, (int, float, bool)): - return parsed - return raw - - -def resolve_completion_params_variables( - completion_params: Mapping[str, Any], - variable_pool: VariablePool, -) -> dict[str, Any]: - """Resolve variable references (``{{#node_id.var#}}``) in string-typed completion params. - - Security notes: - - Resolved values are length-capped to ``MAX_RESOLVED_VALUE_LENGTH`` to - prevent denial-of-service through excessively large variable payloads. - - This follows the same ``VariablePool.convert_template`` pattern used across - Dify (Answer Node, HTTP Request Node, Agent Node, etc.). The downstream - model plugin receives these values as structured JSON key-value pairs — they - are never concatenated into raw HTTP headers or SQL queries. - - Numeric/boolean coercion is applied so that variables holding ``"0.7"`` are - restored to their native type rather than sent as a bare string. - """ - resolved: dict[str, Any] = {} - for key, value in completion_params.items(): - if isinstance(value, str) and VARIABLE_PATTERN.search(value): - segment_group = variable_pool.convert_template(value) - text = segment_group.text - if len(text) > MAX_RESOLVED_VALUE_LENGTH: - logger.warning( - "Resolved value for param '%s' truncated from %d to %d chars", - key, - len(text), - MAX_RESOLVED_VALUE_LENGTH, - ) - text = text[:MAX_RESOLVED_VALUE_LENGTH] - resolved[key] = _coerce_resolved_value(text) - else: - resolved[key] = value - return resolved diff --git a/api/graphon/nodes/llm/node.py b/api/graphon/nodes/llm/node.py deleted file mode 100644 index 4de2a95465..0000000000 --- a/api/graphon/nodes/llm/node.py +++ /dev/null @@ -1,1372 +0,0 @@ -from __future__ import annotations - -import base64 -import io -import json -import logging -import re -import time -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal, cast - -from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ( - BuiltinNodeTypes, - NodeType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import File, FileType, file_manager -from graphon.model_runtime.entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentType, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, - LLMStructuredOutput, - LLMUsage, -) -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessageContentUnionTypes, - PromptMessageRole, - SystemPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from graphon.model_runtime.memory import PromptMessageMemory -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.node_events import ( - ModelInvokeCompletedEvent, - NodeEventBase, - NodeRunResult, - RunRetrieverResourceEvent, - StreamChunkEvent, - StreamCompletedEvent, -) -from graphon.nodes.base.entities import VariableSelector -from graphon.nodes.base.node import Node -from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from graphon.nodes.llm.runtime_protocols import ( - PreparedLLMProtocol, - PromptMessageSerializerProtocol, - RetrieverAttachmentLoaderProtocol, -) -from graphon.nodes.protocols import HttpClientProtocol -from graphon.prompt_entities import CompletionModelPromptTemplate, MemoryConfig -from graphon.runtime import VariablePool -from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError -from graphon.variables import ( - ArrayFileSegment, - ArraySegment, - FileSegment, - NoneSegment, - ObjectSegment, - StringSegment, -) - -from . import llm_utils -from .entities import ( - LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - LLMNodeData, -) -from .exc import ( - InvalidContextStructureError, - InvalidVariableTypeError, - LLMNodeError, - MemoryRolePrefixRequiredError, - NoPromptFoundError, - TemplateTypeNotSupportError, - VariableNotFoundError, -) -from .file_saver import LLMFileSaver - -if TYPE_CHECKING: - from graphon.file.models import File - from graphon.runtime import GraphRuntimeState - -logger = logging.getLogger(__name__) - - -class LLMNode(Node[LLMNodeData]): - node_type = BuiltinNodeTypes.LLM - - # Compiled regex for extracting blocks (with compatibility for attributes) - _THINK_PATTERN = re.compile(r"]*>(.*?)", re.IGNORECASE | re.DOTALL) - - # Instance attributes specific to LLMNode. - # Output variable for file - _file_outputs: list[File] - - _llm_file_saver: LLMFileSaver - _retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None - _prompt_message_serializer: PromptMessageSerializerProtocol - _jinja2_template_renderer: Jinja2TemplateRenderer | None - _model_instance: PreparedLLMProtocol - _memory: PromptMessageMemory | None - _default_query_selector: tuple[str, ...] | None - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - *, - credentials_provider: object | None = None, - model_factory: object | None = None, - model_instance: PreparedLLMProtocol, - http_client: HttpClientProtocol, - memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver, - prompt_message_serializer: PromptMessageSerializerProtocol, - retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None = None, - jinja2_template_renderer: Jinja2TemplateRenderer | None = None, - default_query_selector: Sequence[str] | None = None, - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - # LLM file outputs, used for MultiModal outputs. - self._file_outputs = [] - - _ = credentials_provider, model_factory, http_client - self._model_instance = model_instance - self._memory = memory - - self._llm_file_saver = llm_file_saver - self._prompt_message_serializer = prompt_message_serializer - self._retriever_attachment_loader = retriever_attachment_loader - self._jinja2_template_renderer = jinja2_template_renderer - self._default_query_selector = tuple(default_query_selector) if default_query_selector is not None else None - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator: - node_inputs: dict[str, Any] = {} - process_data: dict[str, Any] = {} - result_text = "" - clean_text = "" - usage = LLMUsage.empty_usage() - finish_reason = None - reasoning_content = None - variable_pool = self.graph_runtime_state.variable_pool - - try: - # init messages template - self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template) - - # fetch variables and fetch values from variable pool - inputs = self._fetch_inputs(node_data=self.node_data) - - # fetch jinja2 inputs - jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data) - - # merge inputs - inputs.update(jinja_inputs) - - # fetch files - files = ( - llm_utils.fetch_files( - variable_pool=variable_pool, - selector=self.node_data.vision.configs.variable_selector, - ) - if self.node_data.vision.enabled - else [] - ) - - if files: - node_inputs["#files#"] = [file.to_dict() for file in files] - - # fetch context value - generator = self._fetch_context(node_data=self.node_data) - context = None - context_files: list[File] = [] - if generator is not None: - for event in generator: - context = event.context - context_files = event.context_files or [] - yield event - if context: - node_inputs["#context#"] = context - - if context_files: - node_inputs["#context_files#"] = [file.model_dump() for file in context_files] - - # fetch model config - model_instance = self._model_instance - # Resolve variable references in string-typed completion params - model_instance.parameters = llm_utils.resolve_completion_params_variables( - model_instance.parameters, variable_pool - ) - model_name = model_instance.model_name - model_provider = model_instance.provider - model_stop = model_instance.stop - - memory = self._memory - - query: str | None = None - if self.node_data.memory: - query = self.node_data.memory.query_prompt_template - if ( - not query - and self._default_query_selector - and (query_variable := variable_pool.get(self._default_query_selector)) - ): - query = query_variable.text - - prompt_messages, stop = LLMNode.fetch_prompt_messages( - sys_query=query, - sys_files=files, - context=context or "", - memory=memory, - model_instance=model_instance, - stop=model_stop, - prompt_template=self.node_data.prompt_template, - memory_config=self.node_data.memory, - vision_enabled=self.node_data.vision.enabled, - vision_detail=self.node_data.vision.configs.detail, - variable_pool=variable_pool, - jinja2_variables=self.node_data.prompt_config.jinja2_variables, - context_files=context_files, - jinja2_template_renderer=self._jinja2_template_renderer, - ) - - # handle invoke result - generator = LLMNode.invoke_llm( - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop, - structured_output_enabled=self.node_data.structured_output_enabled, - structured_output=self.node_data.structured_output, - file_saver=self._llm_file_saver, - file_outputs=self._file_outputs, - node_id=self._node_id, - node_type=self.node_type, - reasoning_format=self.node_data.reasoning_format, - ) - - structured_output: LLMStructuredOutput | None = None - - for event in generator: - if isinstance(event, StreamChunkEvent): - yield event - elif isinstance(event, ModelInvokeCompletedEvent): - # Raw text - result_text = event.text - usage = event.usage - finish_reason = event.finish_reason - reasoning_content = event.reasoning_content or "" - - # For downstream nodes, determine clean text based on reasoning_format - if self.node_data.reasoning_format == "tagged": - # Keep tags for backward compatibility - clean_text = result_text - else: - # Extract clean text from tags - clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format) - - # Process structured output if available from the event. - structured_output = ( - LLMStructuredOutput(structured_output=event.structured_output) - if event.structured_output - else None - ) - - break - elif isinstance(event, LLMStructuredOutput): - structured_output = event - - process_data = { - "model_mode": self.node_data.model.mode, - "prompts": self._prompt_message_serializer.serialize( - model_mode=self.node_data.model.mode, prompt_messages=prompt_messages - ), - "usage": jsonable_encoder(usage), - "finish_reason": finish_reason, - "model_provider": model_provider, - "model_name": model_name, - } - - outputs = { - "text": clean_text, - "reasoning_content": reasoning_content, - "usage": jsonable_encoder(usage), - "finish_reason": finish_reason, - } - if structured_output: - outputs["structured_output"] = structured_output.structured_output - if self._file_outputs: - outputs["files"] = ArrayFileSegment(value=self._file_outputs) - - # Send final chunk event to indicate streaming is complete - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk="", - is_final=True, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - process_data=process_data, - outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - ) - except ValueError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=node_inputs, - process_data=process_data, - error_type=type(e).__name__, - llm_usage=usage, - ) - ) - except Exception as e: - logger.exception("error while executing llm node") - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=node_inputs, - process_data=process_data, - error_type=type(e).__name__, - llm_usage=usage, - ) - ) - - @staticmethod - def invoke_llm( - *, - model_instance: PreparedLLMProtocol, - prompt_messages: Sequence[PromptMessage], - stop: Sequence[str] | None = None, - structured_output_enabled: bool, - structured_output: Mapping[str, Any] | None = None, - file_saver: LLMFileSaver, - file_outputs: list[File], - node_id: str, - node_type: NodeType, - reasoning_format: Literal["separated", "tagged"] = "tagged", - ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: - model_parameters = model_instance.parameters - invoke_model_parameters = dict(model_parameters) - invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None] - if structured_output_enabled: - output_schema = LLMNode.fetch_structured_output_schema( - structured_output=structured_output or {}, - ) - request_start_time = time.perf_counter() - - invoke_result = cast( - LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], - model_instance.invoke_llm_with_structured_output( - prompt_messages=prompt_messages, - json_schema=output_schema, - model_parameters=invoke_model_parameters, - stop=stop, - stream=True, - ), - ) - else: - request_start_time = time.perf_counter() - - invoke_result = cast( - LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], - model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=invoke_model_parameters, - tools=None, - stop=stop, - stream=True, - ), - ) - - return LLMNode.handle_invoke_result( - invoke_result=invoke_result, - file_saver=file_saver, - file_outputs=file_outputs, - node_id=node_id, - node_type=node_type, - model_instance=model_instance, - reasoning_format=reasoning_format, - request_start_time=request_start_time, - ) - - @staticmethod - def handle_invoke_result( - *, - invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], - file_saver: LLMFileSaver, - file_outputs: list[File], - node_id: str, - node_type: NodeType, - model_instance: PreparedLLMProtocol | object, - reasoning_format: Literal["separated", "tagged"] = "tagged", - request_start_time: float | None = None, - ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: - # For blocking mode - if isinstance(invoke_result, LLMResult): - duration = None - if request_start_time is not None: - duration = time.perf_counter() - request_start_time - invoke_result.usage.latency = round(duration, 3) - event = LLMNode.handle_blocking_result( - invoke_result=invoke_result, - saver=file_saver, - file_outputs=file_outputs, - reasoning_format=reasoning_format, - request_latency=duration, - ) - yield event - return - - # For streaming mode - model = "" - prompt_messages: list[PromptMessage] = [] - - usage = LLMUsage.empty_usage() - finish_reason = None - full_text_buffer = io.StringIO() - - # Initialize streaming metrics tracking - start_time = request_start_time if request_start_time is not None else time.perf_counter() - first_token_time = None - has_content = False - - collected_structured_output = None # Collect structured_output from streaming chunks - # Consume the invoke result and handle generator exception - try: - for result in invoke_result: - if isinstance(result, LLMResultChunkWithStructuredOutput): - # Collect structured_output from the chunk - if result.structured_output is not None: - collected_structured_output = dict(result.structured_output) - yield result - if isinstance(result, LLMResultChunk): - contents = result.delta.message.content - for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( - contents=contents, - file_saver=file_saver, - file_outputs=file_outputs, - ): - # Detect first token for TTFT calculation - if text_part and not has_content: - first_token_time = time.perf_counter() - has_content = True - - full_text_buffer.write(text_part) - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk=text_part, - is_final=False, - ) - - # Update the whole metadata - if not model and result.model: - model = result.model - if len(prompt_messages) == 0: - # TODO(QuantumGhost): it seems that this update has no visable effect. - # What's the purpose of the line below? - prompt_messages = list(result.prompt_messages) - if usage.prompt_tokens == 0 and result.delta.usage: - usage = result.delta.usage - if finish_reason is None and result.delta.finish_reason: - finish_reason = result.delta.finish_reason - except Exception as e: - if hasattr(model_instance, "is_structured_output_parse_error") and cast( - PreparedLLMProtocol, model_instance - ).is_structured_output_parse_error(e): - raise LLMNodeError(f"Failed to parse structured output: {e}") from e - if type(e).__name__ == "OutputParserError": - raise LLMNodeError(f"Failed to parse structured output: {e}") from e - raise - - # Extract reasoning content from tags in the main text - full_text = full_text_buffer.getvalue() - - if reasoning_format == "tagged": - # Keep tags in text for backward compatibility - clean_text = full_text - reasoning_content = "" - else: - # Extract clean text and reasoning from tags - clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format) - - # Calculate streaming metrics - end_time = time.perf_counter() - total_duration = end_time - start_time - usage.latency = round(total_duration, 3) - if has_content and first_token_time: - gen_ai_server_time_to_first_token = first_token_time - start_time - llm_streaming_time_to_generate = end_time - first_token_time - usage.time_to_first_token = round(gen_ai_server_time_to_first_token, 3) - usage.time_to_generate = round(llm_streaming_time_to_generate, 3) - - yield ModelInvokeCompletedEvent( - # Use clean_text for separated mode, full_text for tagged mode - text=clean_text if reasoning_format == "separated" else full_text, - usage=usage, - finish_reason=finish_reason, - # Reasoning content for workflow variables and downstream nodes - reasoning_content=reasoning_content, - # Pass structured output if collected from streaming chunks - structured_output=collected_structured_output, - ) - - @staticmethod - def _image_file_to_markdown(file: File, /): - text_chunk = f"![]({file.generate_url()})" - return text_chunk - - @classmethod - def _split_reasoning( - cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged" - ) -> tuple[str, str]: - """ - Split reasoning content from text based on reasoning_format strategy. - - Args: - text: Full text that may contain blocks - reasoning_format: Strategy for handling reasoning content - - "separated": Remove tags and return clean text + reasoning_content field - - "tagged": Keep tags in text, return empty reasoning_content - - Returns: - tuple of (clean_text, reasoning_content) - """ - - if reasoning_format == "tagged": - return text, "" - - # Find all ... blocks (case-insensitive) - matches = cls._THINK_PATTERN.findall(text) - - # Extract reasoning content from all blocks - reasoning_content = "\n".join(match.strip() for match in matches) if matches else "" - - # Remove all ... blocks from original text - clean_text = cls._THINK_PATTERN.sub("", text) - - # Clean up extra whitespace - clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip() - - # Separated mode: always return clean text and reasoning_content - return clean_text, reasoning_content or "" - - def _transform_chat_messages( - self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, / - ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: - if isinstance(messages, LLMNodeCompletionModelPromptTemplate): - if messages.edition_type == "jinja2" and messages.jinja2_text: - messages.text = messages.jinja2_text - - return messages - - for message in messages: - if message.edition_type == "jinja2" and message.jinja2_text: - message.text = message.jinja2_text - - return messages - - def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]: - variables: dict[str, Any] = {} - - if not node_data.prompt_config: - return variables - - for variable_selector in node_data.prompt_config.jinja2_variables or []: - variable_name = variable_selector.variable - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if variable is None: - raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") - - def parse_dict(input_dict: Mapping[str, Any]) -> str: - """ - Parse dict into string - """ - # check if it's a context structure - if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict: - return str(input_dict["content"]) - - # else, parse the dict - try: - return json.dumps(input_dict, ensure_ascii=False) - except Exception: - return str(input_dict) - - if isinstance(variable, ArraySegment): - result = "" - for item in variable.value: - if isinstance(item, dict): - result += parse_dict(item) - else: - result += str(item) - result += "\n" - value = result.strip() - elif isinstance(variable, ObjectSegment): - value = parse_dict(variable.value) - else: - value = variable.text - - variables[variable_name] = value - - return variables - - def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]: - inputs = {} - prompt_template = node_data.prompt_template - - variable_selectors = [] - if isinstance(prompt_template, list): - for prompt in prompt_template: - variable_template_parser = VariableTemplateParser(template=prompt.text) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - elif isinstance(prompt_template, CompletionModelPromptTemplate): - variable_template_parser = VariableTemplateParser(template=prompt_template.text) - variable_selectors = variable_template_parser.extract_variable_selectors() - - for variable_selector in variable_selectors: - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if variable is None: - raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") - if isinstance(variable, NoneSegment): - inputs[variable_selector.variable] = "" - inputs[variable_selector.variable] = variable.to_object() - - memory = node_data.memory - if memory and memory.query_prompt_template: - query_variable_selectors = VariableTemplateParser( - template=memory.query_prompt_template - ).extract_variable_selectors() - for variable_selector in query_variable_selectors: - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if variable is None: - raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") - if isinstance(variable, NoneSegment): - continue - inputs[variable_selector.variable] = variable.to_object() - - return inputs - - def _fetch_context(self, node_data: LLMNodeData): - if not node_data.context.enabled: - return - - if not node_data.context.variable_selector: - return - - context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector) - if context_value_variable: - if isinstance(context_value_variable, StringSegment): - yield RunRetrieverResourceEvent( - retriever_resources=[], context=context_value_variable.value, context_files=[] - ) - elif isinstance(context_value_variable, ArraySegment): - context_str = "" - original_retriever_resource: list[dict[str, Any]] = [] - context_files: list[File] = [] - for item in context_value_variable.value: - if isinstance(item, str): - context_str += item + "\n" - else: - if "content" not in item: - raise InvalidContextStructureError(f"Invalid context structure: {item}") - - if item.get("summary"): - context_str += item["summary"] + "\n" - context_str += item["content"] + "\n" - - retriever_resource = self._convert_to_original_retriever_resource(item) - if retriever_resource: - original_retriever_resource.append(retriever_resource) - segment_id = retriever_resource.get("segment_id") - if not segment_id: - continue - if self._retriever_attachment_loader is not None: - context_files.extend(self._retriever_attachment_loader.load(segment_id=segment_id)) - yield RunRetrieverResourceEvent( - retriever_resources=original_retriever_resource, - context=context_str.strip(), - context_files=context_files, - ) - - def _convert_to_original_retriever_resource(self, context_dict: dict) -> dict[str, Any] | None: - if ( - "metadata" in context_dict - and "_source" in context_dict["metadata"] - and context_dict["metadata"]["_source"] == "knowledge" - ): - metadata = context_dict.get("metadata", {}) - - return { - "position": metadata.get("position"), - "dataset_id": metadata.get("dataset_id"), - "dataset_name": metadata.get("dataset_name"), - "document_id": metadata.get("document_id"), - "document_name": metadata.get("document_name"), - "data_source_type": metadata.get("data_source_type"), - "segment_id": metadata.get("segment_id"), - "retriever_from": metadata.get("retriever_from"), - "score": metadata.get("score"), - "hit_count": metadata.get("segment_hit_count"), - "word_count": metadata.get("segment_word_count"), - "segment_position": metadata.get("segment_position"), - "index_node_hash": metadata.get("segment_index_node_hash"), - "content": context_dict.get("content"), - "page": metadata.get("page"), - "doc_metadata": metadata.get("doc_metadata"), - "files": context_dict.get("files"), - "summary": context_dict.get("summary"), - } - - return None - - @staticmethod - def fetch_prompt_messages( - *, - sys_query: str | None = None, - sys_files: Sequence[File], - context: str = "", - memory: PromptMessageMemory | None = None, - model_instance: PreparedLLMProtocol, - prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, - stop: Sequence[str] | None = None, - memory_config: MemoryConfig | None = None, - vision_enabled: bool = False, - vision_detail: ImagePromptMessageContent.DETAIL, - variable_pool: VariablePool, - jinja2_variables: Sequence[VariableSelector], - context_files: list[File] | None = None, - jinja2_template_renderer: Jinja2TemplateRenderer | None = None, - ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: - prompt_messages: list[PromptMessage] = [] - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - - if isinstance(prompt_template, list): - # For chat model - prompt_messages.extend( - LLMNode.handle_list_messages( - messages=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - vision_detail_config=vision_detail, - jinja2_template_renderer=jinja2_template_renderer, - ) - ) - - # Get memory messages for chat mode - memory_messages = _handle_memory_chat_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - # Extend prompt_messages with memory messages - prompt_messages.extend(memory_messages) - - # Add current query to the prompt messages - if sys_query: - message = LLMNodeChatModelMessage( - text=sys_query, - role=PromptMessageRole.USER, - edition_type="basic", - ) - prompt_messages.extend( - LLMNode.handle_list_messages( - messages=[message], - context="", - jinja2_variables=[], - variable_pool=variable_pool, - vision_detail_config=vision_detail, - jinja2_template_renderer=jinja2_template_renderer, - ) - ) - - elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - # For completion model - prompt_messages.extend( - _handle_completion_template( - template=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - jinja2_template_renderer=jinja2_template_renderer, - ) - ) - - # Get memory text for completion model - memory_text = _handle_memory_completion_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - # Insert histories into the prompt - prompt_content = prompt_messages[0].content - # For issue #11247 - Check if prompt content is a string or a list - if isinstance(prompt_content, str): - prompt_content = str(prompt_content) - if "#histories#" in prompt_content: - prompt_content = prompt_content.replace("#histories#", memory_text) - else: - prompt_content = memory_text + "\n" + prompt_content - prompt_messages[0].content = prompt_content - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - if "#histories#" in content_item.data: - content_item.data = content_item.data.replace("#histories#", memory_text) - else: - content_item.data = memory_text + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - - # Add current query to the prompt message - if sys_query: - if isinstance(prompt_content, str): - prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) - prompt_messages[0].content = prompt_content - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - content_item.data = sys_query + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - else: - raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) - - # The sys_files will be deprecated later - if vision_enabled and sys_files: - file_prompts = [] - for file in sys_files: - file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) - file_prompts.append(file_prompt) - # If last prompt is a user prompt, add files into its contents, - # otherwise append a new user prompt - if ( - len(prompt_messages) > 0 - and isinstance(prompt_messages[-1], UserPromptMessage) - and isinstance(prompt_messages[-1].content, list) - ): - prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) - else: - prompt_messages.append(UserPromptMessage(content=file_prompts)) - - # The context_files - if vision_enabled and context_files: - file_prompts = [] - for file in context_files: - file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) - file_prompts.append(file_prompt) - # If last prompt is a user prompt, add files into its contents, - # otherwise append a new user prompt - if ( - len(prompt_messages) > 0 - and isinstance(prompt_messages[-1], UserPromptMessage) - and isinstance(prompt_messages[-1].content, list) - ): - prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) - else: - prompt_messages.append(UserPromptMessage(content=file_prompts)) - - # Remove empty messages and filter unsupported content - filtered_prompt_messages = [] - for prompt_message in prompt_messages: - if isinstance(prompt_message.content, list): - prompt_message_content: list[PromptMessageContentUnionTypes] = [] - for content_item in prompt_message.content: - # Skip content if features are not defined - if not model_schema.features: - if content_item.type != PromptMessageContentType.TEXT: - continue - prompt_message_content.append(content_item) - continue - - # Skip content if corresponding feature is not supported - if ( - ( - content_item.type == PromptMessageContentType.IMAGE - and ModelFeature.VISION not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.DOCUMENT - and ModelFeature.DOCUMENT not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.VIDEO - and ModelFeature.VIDEO not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.AUDIO - and ModelFeature.AUDIO not in model_schema.features - ) - ): - continue - prompt_message_content.append(content_item) - if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: - prompt_message.content = prompt_message_content[0].data - else: - prompt_message.content = prompt_message_content - if prompt_message.is_empty(): - continue - filtered_prompt_messages.append(prompt_message) - - if len(filtered_prompt_messages) == 0: - raise NoPromptFoundError( - "No prompt found in the LLM configuration. " - "Please ensure a prompt is properly configured before proceeding." - ) - - return filtered_prompt_messages, stop - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: LLMNodeData, - ) -> Mapping[str, Sequence[str]]: - # graph_config is not used in this node type - _ = graph_config # Explicitly mark as unused - prompt_template = node_data.prompt_template - variable_selectors = [] - if isinstance(prompt_template, list): - for prompt in prompt_template: - if prompt.edition_type != "jinja2": - variable_template_parser = VariableTemplateParser(template=prompt.text) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - if prompt_template.edition_type != "jinja2": - variable_template_parser = VariableTemplateParser(template=prompt_template.text) - variable_selectors = variable_template_parser.extract_variable_selectors() - else: - raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") - - variable_mapping: dict[str, Any] = {} - for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - memory = node_data.memory - if memory and memory.query_prompt_template: - query_variable_selectors = VariableTemplateParser( - template=memory.query_prompt_template - ).extract_variable_selectors() - for variable_selector in query_variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - if node_data.context.enabled: - variable_mapping["#context#"] = node_data.context.variable_selector - - if node_data.vision.enabled: - variable_mapping["#files#"] = node_data.vision.configs.variable_selector - - if node_data.prompt_config: - enable_jinja = False - - if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - if prompt_template.edition_type == "jinja2": - enable_jinja = True - else: - for prompt in prompt_template: - if prompt.edition_type == "jinja2": - enable_jinja = True - break - - if enable_jinja: - for variable_selector in node_data.prompt_config.jinja2_variables or []: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} - - return variable_mapping - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return { - "type": "llm", - "config": { - "prompt_templates": { - "chat_model": { - "prompts": [ - {"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"} - ] - }, - "completion_model": { - "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, - "prompt": { - "text": "Here are the chat histories between human and assistant, inside " - " XML tags.\n\n\n{{" - "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", - "edition_type": "basic", - }, - "stop": ["Human:"], - }, - } - }, - } - - @staticmethod - def handle_list_messages( - *, - messages: Sequence[LLMNodeChatModelMessage], - context: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - vision_detail_config: ImagePromptMessageContent.DETAIL, - jinja2_template_renderer: Jinja2TemplateRenderer | None = None, - ) -> Sequence[PromptMessage]: - prompt_messages: list[PromptMessage] = [] - for message in messages: - if message.edition_type == "jinja2": - result_text = _render_jinja2_message( - template=message.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - jinja2_template_renderer=jinja2_template_renderer, - ) - prompt_message = _combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], role=message.role - ) - prompt_messages.append(prompt_message) - else: - # Get segment group from basic message - template = message.text.replace(llm_utils.CONTEXT_PLACEHOLDER, context) - segment_group = variable_pool.convert_template(template) - - # Process segments for images - file_contents = [] - for segment in segment_group.value: - if isinstance(segment, ArrayFileSegment): - for file in segment.value: - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_content = file_manager.to_prompt_message_content( - file, image_detail_config=vision_detail_config - ) - file_contents.append(file_content) - elif isinstance(segment, FileSegment): - file = segment.value - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_content = file_manager.to_prompt_message_content( - file, image_detail_config=vision_detail_config - ) - file_contents.append(file_content) - - # Create message with text from all segments - plain_text = segment_group.text - if plain_text: - prompt_message = _combine_message_content_with_role( - contents=[TextPromptMessageContent(data=plain_text)], role=message.role - ) - prompt_messages.append(prompt_message) - - if file_contents: - # Create message with image contents - prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) - prompt_messages.append(prompt_message) - - return prompt_messages - - @staticmethod - def handle_blocking_result( - *, - invoke_result: LLMResult | LLMResultWithStructuredOutput, - saver: LLMFileSaver, - file_outputs: list[File], - reasoning_format: Literal["separated", "tagged"] = "tagged", - request_latency: float | None = None, - ) -> ModelInvokeCompletedEvent: - buffer = io.StringIO() - for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( - contents=invoke_result.message.content, - file_saver=saver, - file_outputs=file_outputs, - ): - buffer.write(text_part) - - # Extract reasoning content from tags in the main text - full_text = buffer.getvalue() - - if reasoning_format == "tagged": - # Keep tags in text for backward compatibility - clean_text = full_text - reasoning_content = "" - else: - # Extract clean text and reasoning from tags - clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format) - - event = ModelInvokeCompletedEvent( - # Use clean_text for separated mode, full_text for tagged mode - text=clean_text if reasoning_format == "separated" else full_text, - usage=invoke_result.usage, - finish_reason=None, - # Reasoning content for workflow variables and downstream nodes - reasoning_content=reasoning_content, - # Pass structured output if enabled - structured_output=getattr(invoke_result, "structured_output", None), - ) - if request_latency is not None: - event.usage.latency = round(request_latency, 3) - return event - - @staticmethod - def save_multimodal_image_output( - *, - content: ImagePromptMessageContent, - file_saver: LLMFileSaver, - ) -> File: - """_save_multimodal_output saves multi-modal contents generated by LLM plugins. - - There are two kinds of multimodal outputs: - - - Inlined data encoded in base64, which would be saved to storage directly. - - Remote files referenced by an url, which would be downloaded and then saved to storage. - - Currently, only image files are supported. - """ - if content.url != "": - saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE) - else: - saved_file = file_saver.save_binary_string( - data=base64.b64decode(content.base64_data), - mime_type=content.mime_type, - file_type=FileType.IMAGE, - ) - return saved_file - - @staticmethod - def fetch_structured_output_schema( - *, - structured_output: Mapping[str, Any], - ) -> dict[str, Any]: - """ - Fetch the structured output schema from the node data. - - Returns: - dict[str, Any]: The structured output schema - """ - if not structured_output: - raise LLMNodeError("Please provide a valid structured output schema") - structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False) - if not structured_output_schema: - raise LLMNodeError("Please provide a valid structured output schema") - - try: - schema = json.loads(structured_output_schema) - if not isinstance(schema, dict): - raise LLMNodeError("structured_output_schema must be a JSON object") - return schema - except json.JSONDecodeError: - raise LLMNodeError("structured_output_schema is not valid JSON format") - - @staticmethod - def _save_multimodal_output_and_convert_result_to_markdown( - *, - contents: str | list[PromptMessageContentUnionTypes] | None, - file_saver: LLMFileSaver, - file_outputs: list[File], - ) -> Generator[str, None, None]: - """Convert intermediate prompt messages into strings and yield them to the caller. - - If the messages contain non-textual content (e.g., multimedia like images or videos), - it will be saved separately, and the corresponding Markdown representation will - be yielded to the caller. - """ - - # NOTE(QuantumGhost): This function should yield results to the caller immediately - # whenever new content or partial content is available. Avoid any intermediate buffering - # of results. Additionally, do not yield empty strings; instead, yield from an empty list - # if necessary. - if contents is None: - yield from [] - return - if isinstance(contents, str): - yield contents - else: - for item in contents: - if isinstance(item, TextPromptMessageContent): - yield item.data - elif isinstance(item, ImagePromptMessageContent): - file = LLMNode.save_multimodal_image_output( - content=item, - file_saver=file_saver, - ) - file_outputs.append(file) - yield LLMNode._image_file_to_markdown(file) - else: - logger.warning("unknown item type encountered, type=%s", type(item)) - yield str(item) - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled - - @property - def model_instance(self) -> PreparedLLMProtocol: - return self._model_instance - - -def _combine_message_content_with_role( - *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole -): - match role: - case PromptMessageRole.USER: - return UserPromptMessage(content=contents) - case PromptMessageRole.ASSISTANT: - return AssistantPromptMessage(content=contents) - case PromptMessageRole.SYSTEM: - return SystemPromptMessage(content=contents) - case _: - raise NotImplementedError(f"Role {role} is not supported") - - -def _render_jinja2_message( - *, - template: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - jinja2_template_renderer: Jinja2TemplateRenderer | None, -): - if not template: - return "" - - jinja2_inputs = {} - for jinja2_variable in jinja2_variables: - variable = variable_pool.get(jinja2_variable.value_selector) - jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" - if jinja2_template_renderer is None: - raise TemplateRenderError("LLMNode requires an injected jinja2_template_renderer for jinja2 prompts.") - return jinja2_template_renderer.render_template(template, jinja2_inputs) - - -def _calculate_rest_token( - *, - prompt_messages: list[PromptMessage], - model_instance: PreparedLLMProtocol, -) -> int: - rest_tokens = 2000 - runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - runtime_model_parameters = model_instance.parameters - - model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - max_tokens = 0 - for parameter_rule in runtime_model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - runtime_model_parameters.get(parameter_rule.name) - or runtime_model_parameters.get(str(parameter_rule.use_template)) - or 0 - ) - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - -def _handle_memory_chat_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - model_instance: PreparedLLMProtocol, -) -> Sequence[PromptMessage]: - memory_messages: Sequence[PromptMessage] = [] - # Get messages from memory for chat model - if memory and memory_config: - rest_tokens = _calculate_rest_token( - prompt_messages=[], - model_instance=model_instance, - ) - memory_messages = memory.get_history_prompt_messages( - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - ) - return memory_messages - - -def _handle_memory_completion_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - model_instance: PreparedLLMProtocol, -) -> str: - memory_text = "" - # Get history text from memory for completion model - if memory and memory_config: - rest_tokens = _calculate_rest_token( - prompt_messages=[], - model_instance=model_instance, - ) - if not memory_config.role_prefix: - raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") - memory_text = llm_utils.fetch_memory_text( - memory=memory, - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - human_prefix=memory_config.role_prefix.user, - ai_prefix=memory_config.role_prefix.assistant, - ) - return memory_text - - -def _handle_completion_template( - *, - template: LLMNodeCompletionModelPromptTemplate, - context: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - jinja2_template_renderer: Jinja2TemplateRenderer | None = None, -) -> Sequence[PromptMessage]: - """Handle completion template processing outside of LLMNode class. - - Args: - template: The completion model prompt template - context: Context string - jinja2_variables: Variables for jinja2 template rendering - variable_pool: Variable pool for template conversion - - Returns: - Sequence of prompt messages - """ - prompt_messages = [] - if template.edition_type == "jinja2": - result_text = _render_jinja2_message( - template=template.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - jinja2_template_renderer=jinja2_template_renderer, - ) - else: - template_text = template.text.replace(llm_utils.CONTEXT_PLACEHOLDER, context) - result_text = variable_pool.convert_template(template_text).text - prompt_message = _combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER - ) - prompt_messages.append(prompt_message) - return prompt_messages diff --git a/api/graphon/nodes/llm/protocols.py b/api/graphon/nodes/llm/protocols.py deleted file mode 100644 index 65bfd533d1..0000000000 --- a/api/graphon/nodes/llm/protocols.py +++ /dev/null @@ -1,21 +0,0 @@ -from __future__ import annotations - -from typing import Any, Protocol - -from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol - - -class CredentialsProvider(Protocol): - """Port for loading runtime credentials for a provider/model pair.""" - - def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: - """Return credentials for the target provider/model or raise a domain error.""" - ... - - -class ModelFactory(Protocol): - """Port for creating prepared graph-facing LLM runtimes for execution.""" - - def init_model_instance(self, provider_name: str, model_name: str) -> PreparedLLMProtocol: - """Create a prepared LLM runtime that is ready for graph execution.""" - ... diff --git a/api/graphon/nodes/llm/runtime_protocols.py b/api/graphon/nodes/llm/runtime_protocols.py deleted file mode 100644 index dbe415d363..0000000000 --- a/api/graphon/nodes/llm/runtime_protocols.py +++ /dev/null @@ -1,77 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator, Mapping, Sequence -from typing import Any, Protocol - -from graphon.file import File -from graphon.model_runtime.entities import LLMMode, PromptMessage -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, -) -from graphon.model_runtime.entities.message_entities import PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity - - -class PreparedLLMProtocol(Protocol): - """A graph-facing LLM runtime with provider-specific setup already applied.""" - - @property - def provider(self) -> str: ... - - @property - def model_name(self) -> str: ... - - @property - def parameters(self) -> Mapping[str, Any]: ... - - @parameters.setter - def parameters(self, value: Mapping[str, Any]) -> None: ... - - @property - def stop(self) -> Sequence[str] | None: ... - - def get_model_schema(self) -> AIModelEntity: ... - - def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: ... - - def invoke_llm( - self, - *, - prompt_messages: Sequence[PromptMessage], - model_parameters: Mapping[str, Any], - tools: Sequence[PromptMessageTool] | None, - stop: Sequence[str] | None, - stream: bool, - ) -> LLMResult | Generator[LLMResultChunk, None, None]: ... - - def invoke_llm_with_structured_output( - self, - *, - prompt_messages: Sequence[PromptMessage], - json_schema: Mapping[str, Any], - model_parameters: Mapping[str, Any], - stop: Sequence[str] | None, - stream: bool, - ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... - - def is_structured_output_parse_error(self, error: Exception) -> bool: ... - - -class PromptMessageSerializerProtocol(Protocol): - """Port for converting compiled prompt messages into persisted process data.""" - - def serialize( - self, - *, - model_mode: LLMMode, - prompt_messages: Sequence[PromptMessage], - ) -> Any: ... - - -class RetrieverAttachmentLoaderProtocol(Protocol): - """Port for resolving retriever segment attachments into graph file references.""" - - def load(self, *, segment_id: str) -> Sequence[File]: ... diff --git a/api/graphon/nodes/loop/__init__.py b/api/graphon/nodes/loop/__init__.py deleted file mode 100644 index 9fe695607b..0000000000 --- a/api/graphon/nodes/loop/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .entities import LoopNodeData -from .loop_end_node import LoopEndNode -from .loop_node import LoopNode -from .loop_start_node import LoopStartNode - -__all__ = ["LoopEndNode", "LoopNode", "LoopNodeData", "LoopStartNode"] diff --git a/api/graphon/nodes/loop/entities.py b/api/graphon/nodes/loop/entities.py deleted file mode 100644 index e7362769e9..0000000000 --- a/api/graphon/nodes/loop/entities.py +++ /dev/null @@ -1,107 +0,0 @@ -from enum import StrEnum -from typing import Annotated, Any, Literal - -from pydantic import AfterValidator, BaseModel, Field, field_validator - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base import BaseLoopNodeData, BaseLoopState -from graphon.utils.condition.entities import Condition -from graphon.variables.types import SegmentType - -_VALID_VAR_TYPE = frozenset( - [ - SegmentType.STRING, - SegmentType.NUMBER, - SegmentType.OBJECT, - SegmentType.BOOLEAN, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_BOOLEAN, - ] -) - - -def _is_valid_var_type(seg_type: SegmentType) -> SegmentType: - if seg_type not in _VALID_VAR_TYPE: - raise ValueError(...) - return seg_type - - -class LoopVariableData(BaseModel): - """ - Loop Variable Data. - """ - - label: str - var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)] - value_type: Literal["variable", "constant"] - value: Any | list[str] | None = None - - -class LoopNodeData(BaseLoopNodeData): - type: NodeType = BuiltinNodeTypes.LOOP - loop_count: int # Maximum number of loops - break_conditions: list[Condition] # Conditions to break the loop - logical_operator: Literal["and", "or"] - loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData]) - outputs: dict[str, Any] = Field(default_factory=dict) - - @field_validator("outputs", mode="before") - @classmethod - def validate_outputs(cls, v): - if v is None: - return {} - return v - - -class LoopStartNodeData(BaseNodeData): - """ - Loop Start Node Data. - """ - - type: NodeType = BuiltinNodeTypes.LOOP_START - - -class LoopEndNodeData(BaseNodeData): - """ - Loop End Node Data. - """ - - type: NodeType = BuiltinNodeTypes.LOOP_END - - -class LoopState(BaseLoopState): - """ - Loop State. - """ - - outputs: list[Any] = Field(default_factory=list) - current_output: Any = None - - class MetaData(BaseLoopState.MetaData): - """ - Data. - """ - - loop_length: int - - def get_last_output(self) -> Any: - """ - Get last output. - """ - if self.outputs: - return self.outputs[-1] - return None - - def get_current_output(self) -> Any: - """ - Get current output. - """ - return self.current_output - - -class LoopCompletedReason(StrEnum): - LOOP_BREAK = "loop_break" - LOOP_COMPLETED = "loop_completed" diff --git a/api/graphon/nodes/loop/loop_end_node.py b/api/graphon/nodes/loop/loop_end_node.py deleted file mode 100644 index c0562b59c4..0000000000 --- a/api/graphon/nodes/loop/loop_end_node.py +++ /dev/null @@ -1,22 +0,0 @@ -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.loop.entities import LoopEndNodeData - - -class LoopEndNode(Node[LoopEndNodeData]): - """ - Loop End Node. - """ - - node_type = BuiltinNodeTypes.LOOP_END - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run the node. - """ - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) diff --git a/api/graphon/nodes/loop/loop_node.py b/api/graphon/nodes/loop/loop_node.py deleted file mode 100644 index d574e9f7ae..0000000000 --- a/api/graphon/nodes/loop/loop_node.py +++ /dev/null @@ -1,428 +0,0 @@ -import contextlib -import json -import logging -from collections.abc import Callable, Generator, Mapping, Sequence -from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, Literal, cast - -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.enums import ( - BuiltinNodeTypes, - NodeExecutionType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.graph_events import ( - GraphNodeEventBase, - GraphRunAbortedEvent, - GraphRunFailedEvent, - NodeRunSucceededEvent, -) -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import ( - LoopFailedEvent, - LoopNextEvent, - LoopStartedEvent, - LoopSucceededEvent, - NodeEventBase, - NodeRunResult, - StreamCompletedEvent, -) -from graphon.nodes.base import LLMUsageTrackingMixin -from graphon.nodes.base.node import Node -from graphon.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData -from graphon.utils.condition.processor import ConditionProcessor -from graphon.variables import Segment, SegmentType, TypeMismatchError, build_segment_with_type, segment_to_variable - -if TYPE_CHECKING: - from graphon.graph_engine import GraphEngine - -logger = logging.getLogger(__name__) -_DEFAULT_CHILD_ABORT_REASON = "child graph aborted" - - -class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): - """ - Loop Node. - """ - - node_type = BuiltinNodeTypes.LOOP - execution_type = NodeExecutionType.CONTAINER - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator: - """Run the node.""" - # Get inputs - loop_count = self.node_data.loop_count - break_conditions = self.node_data.break_conditions - logical_operator = self.node_data.logical_operator - - inputs = {"loop_count": loop_count} - - if not self.node_data.start_node_id: - raise ValueError(f"field start_node_id in loop {self._node_id} not found") - - root_node_id = self.node_data.start_node_id - - # Initialize loop variables in the original variable pool - loop_variable_selectors = {} - if self.node_data.loop_variables: - value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = { - "constant": lambda var: self._get_segment_for_constant(var.var_type, var.value), - "variable": lambda var: ( - self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None - ), - } - for loop_variable in self.node_data.loop_variables: - if loop_variable.value_type not in value_processor: - raise ValueError( - f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}" - ) - - processed_segment = value_processor[loop_variable.value_type](loop_variable) - if not processed_segment: - raise ValueError(f"Invalid value for loop variable {loop_variable.label}") - variable_selector = [self._node_id, loop_variable.label] - variable = segment_to_variable(segment=processed_segment, selector=variable_selector) - self.graph_runtime_state.variable_pool.add(variable_selector, variable.value) - loop_variable_selectors[loop_variable.label] = variable_selector - inputs[loop_variable.label] = processed_segment.value - - start_at = datetime.now(UTC).replace(tzinfo=None) - condition_processor = ConditionProcessor() - - loop_duration_map: dict[str, float] = {} - single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output - loop_usage = LLMUsage.empty_usage() - loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id) - - # Start Loop event - yield LoopStartedEvent( - start_at=start_at, - inputs=inputs, - metadata={"loop_length": loop_count}, - ) - - try: - reach_break_condition = False - if break_conditions: - with contextlib.suppress(ValueError): - _, _, reach_break_condition = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=break_conditions, - operator=logical_operator, - ) - - if reach_break_condition: - loop_count = 0 - - for i in range(loop_count): - # Clear stale variables from previous loop iterations to avoid streaming old values - self._clear_loop_subgraph_variables(loop_node_ids) - graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id) - - loop_start_time = datetime.now(UTC).replace(tzinfo=None) - try: - reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i) - finally: - loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) - # Track loop duration - loop_duration_map[str(i)] = (datetime.now(UTC).replace(tzinfo=None) - loop_start_time).total_seconds() - - # Accumulate outputs from the sub-graph's response nodes - for key, value in graph_engine.graph_runtime_state.outputs.items(): - if key == "answer": - # Concatenate answer outputs with newline - existing_answer = self.graph_runtime_state.get_output("answer", "") - if existing_answer: - self.graph_runtime_state.set_output("answer", f"{existing_answer}{value}") - else: - self.graph_runtime_state.set_output("answer", value) - else: - # For other outputs, just update - self.graph_runtime_state.set_output(key, value) - - # Collect loop variable values after iteration - single_loop_variable = {} - for key, selector in loop_variable_selectors.items(): - segment = self.graph_runtime_state.variable_pool.get(selector) - single_loop_variable[key] = segment.value if segment else None - - single_loop_variable_map[str(i)] = single_loop_variable - - if reach_break_node: - break - - if break_conditions: - _, _, reach_break_condition = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=break_conditions, - operator=logical_operator, - ) - if reach_break_condition: - break - - yield LoopNextEvent( - index=i + 1, - pre_loop_output=self.node_data.outputs, - ) - - self._accumulate_usage(loop_usage) - # Loop completed successfully - yield LoopSucceededEvent( - start_at=start_at, - inputs=inputs, - outputs=self.node_data.outputs, - steps=loop_count, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - WorkflowNodeExecutionMetadataKey.COMPLETED_REASON: ( - LoopCompletedReason.LOOP_BREAK - if reach_break_condition - else LoopCompletedReason.LOOP_COMPLETED.value - ), - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - outputs=self.node_data.outputs, - inputs=inputs, - llm_usage=loop_usage, - ) - ) - - except Exception as e: - self._accumulate_usage(loop_usage) - yield LoopFailedEvent( - start_at=start_at, - inputs=inputs, - steps=loop_count, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - "completed_reason": "error", - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - error=str(e), - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - llm_usage=loop_usage, - ) - ) - - def _run_single_loop( - self, - *, - graph_engine: "GraphEngine", - current_index: int, - ) -> Generator[NodeEventBase | GraphNodeEventBase, None, bool]: - reach_break_node = False - for event in graph_engine.run(): - if isinstance(event, GraphNodeEventBase): - self._append_loop_info_to_event(event=event, loop_run_index=current_index) - - if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.LOOP_START: - continue - if isinstance(event, GraphNodeEventBase): - yield event - if isinstance(event, NodeRunSucceededEvent) and event.node_type == BuiltinNodeTypes.LOOP_END: - reach_break_node = True - if isinstance(event, GraphRunAbortedEvent): - raise RuntimeError(event.reason or _DEFAULT_CHILD_ABORT_REASON) - if isinstance(event, GraphRunFailedEvent): - raise Exception(event.error) - - for loop_var in self.node_data.loop_variables or []: - key, sel = loop_var.label, [self._node_id, loop_var.label] - segment = self.graph_runtime_state.variable_pool.get(sel) - self.node_data.outputs[key] = segment.value if segment else None - self.node_data.outputs["loop_round"] = current_index + 1 - - return reach_break_node - - def _append_loop_info_to_event( - self, - event: GraphNodeEventBase, - loop_run_index: int, - ): - event.in_loop_id = self._node_id - loop_metadata = { - WorkflowNodeExecutionMetadataKey.LOOP_ID: self._node_id, - WorkflowNodeExecutionMetadataKey.LOOP_INDEX: loop_run_index, - } - - current_metadata = event.node_run_result.metadata - if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata: - event.node_run_result.metadata = {**current_metadata, **loop_metadata} - - def _clear_loop_subgraph_variables(self, loop_node_ids: set[str]) -> None: - """ - Remove variables produced by loop sub-graph nodes from previous iterations. - - Keeping stale variables causes a freshly created response coordinator in the - next iteration to fall back to outdated values when no stream chunks exist. - """ - variable_pool = self.graph_runtime_state.variable_pool - for node_id in loop_node_ids: - variable_pool.remove([node_id]) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: LoopNodeData, - ) -> Mapping[str, Sequence[str]]: - variable_mapping = {} - - # Extract loop node IDs statically from graph_config - - loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id) - - # Get node configs from graph_config - node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} - for sub_node_id, sub_node_config in node_configs.items(): - if sub_node_config.get("data", {}).get("loop_id") != node_id: - continue - - # variable selector to variable mapping - try: - typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) - node_type = typed_sub_node_config["data"].type - node_mapping = Node.get_node_type_classes_mapping() - if node_type not in node_mapping: - continue - node_version = str(typed_sub_node_config["data"].version) - node_cls = node_mapping[node_type][node_version] - - sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, config=typed_sub_node_config - ) - sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) - except NotImplementedError: - sub_node_variable_mapping = {} - - # remove loop variables - sub_node_variable_mapping = { - sub_node_id + "." + key: value - for key, value in sub_node_variable_mapping.items() - if value[0] != node_id - } - - variable_mapping.update(sub_node_variable_mapping) - - for loop_variable in node_data.loop_variables or []: - if loop_variable.value_type == "variable": - assert loop_variable.value is not None, "Loop variable value must be provided for variable type" - # add loop variable to variable mapping - selector = loop_variable.value - variable_mapping[f"{node_id}.{loop_variable.label}"] = selector - - # remove variable out from loop - variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids} - - return variable_mapping - - @classmethod - def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]: - """ - Extract node IDs that belong to a specific loop from graph configuration. - - This method statically analyzes the graph configuration to find all nodes - that are part of the specified loop, without creating actual node instances. - - :param graph_config: the complete graph configuration - :param loop_node_id: the ID of the loop node - :return: set of node IDs that belong to the loop - """ - loop_node_ids = set() - - # Find all nodes that belong to this loop - nodes = graph_config.get("nodes", []) - for node in nodes: - node_data = node.get("data", {}) - if node_data.get("loop_id") == loop_node_id: - node_id = node.get("id") - if node_id: - loop_node_ids.add(node_id) - - return loop_node_ids - - @staticmethod - def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment: - """Get the appropriate segment type for a constant value.""" - # TODO: Refactor for maintainability: - # 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py) - # 2. Consider moving this method to LoopVariableData class for better encapsulation - if not var_type.is_array_type() or var_type == SegmentType.ARRAY_BOOLEAN: - value = original_value - elif var_type in [ - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_STRING, - ]: - if original_value and isinstance(original_value, str): - value = json.loads(original_value) - else: - logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type) - value = [] - else: - raise AssertionError("this statement should be unreachable.") - try: - return build_segment_with_type(var_type, value=value) - except TypeMismatchError as type_exc: - # Attempt to parse the value as a JSON-encoded string, if applicable. - if not isinstance(original_value, str): - raise - try: - value = json.loads(original_value) - except ValueError: - raise type_exc - return build_segment_with_type(var_type, value) - - def _create_graph_engine(self, start_at: datetime, root_node_id: str): - from graphon.entities import GraphInitParams - - # Create GraphInitParams for child graph execution. - graph_init_params = GraphInitParams( - workflow_id=self.workflow_id, - graph_config=self.graph_config, - run_context=self.run_context, - call_depth=self.workflow_call_depth, - ) - - return self.graph_runtime_state.create_child_engine( - workflow_id=self.workflow_id, - graph_init_params=graph_init_params, - root_node_id=root_node_id, - ) diff --git a/api/graphon/nodes/loop/loop_start_node.py b/api/graphon/nodes/loop/loop_start_node.py deleted file mode 100644 index 2b17054ae2..0000000000 --- a/api/graphon/nodes/loop/loop_start_node.py +++ /dev/null @@ -1,22 +0,0 @@ -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.loop.entities import LoopStartNodeData - - -class LoopStartNode(Node[LoopStartNodeData]): - """ - Loop Start Node. - """ - - node_type = BuiltinNodeTypes.LOOP_START - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run the node. - """ - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) diff --git a/api/graphon/nodes/parameter_extractor/__init__.py b/api/graphon/nodes/parameter_extractor/__init__.py deleted file mode 100644 index bdbf19a7d3..0000000000 --- a/api/graphon/nodes/parameter_extractor/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .parameter_extractor_node import ParameterExtractorNode - -__all__ = ["ParameterExtractorNode"] diff --git a/api/graphon/nodes/parameter_extractor/entities.py b/api/graphon/nodes/parameter_extractor/entities.py deleted file mode 100644 index 8fda1b9e79..0000000000 --- a/api/graphon/nodes/parameter_extractor/entities.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import Annotated, Any, Literal - -from pydantic import ( - BaseModel, - BeforeValidator, - Field, - field_validator, -) - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.llm.entities import ModelConfig, VisionConfig -from graphon.prompt_entities import MemoryConfig -from graphon.variables.types import SegmentType - -_OLD_BOOL_TYPE_NAME = "bool" -_OLD_SELECT_TYPE_NAME = "select" - -_VALID_PARAMETER_TYPES = frozenset( - [ - SegmentType.STRING, # "string", - SegmentType.NUMBER, # "number", - SegmentType.BOOLEAN, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_BOOLEAN, - _OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node - _OLD_SELECT_TYPE_NAME, # string type with enumeration choices. - ] -) - - -def _validate_type(parameter_type: str) -> SegmentType: - if parameter_type not in _VALID_PARAMETER_TYPES: - raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.") - - if parameter_type == _OLD_BOOL_TYPE_NAME: - return SegmentType.BOOLEAN - elif parameter_type == _OLD_SELECT_TYPE_NAME: - return SegmentType.STRING - return SegmentType(parameter_type) - - -class ParameterConfig(BaseModel): - """ - Parameter Config. - """ - - name: str - type: Annotated[SegmentType, BeforeValidator(_validate_type)] - options: list[str] | None = None - description: str - required: bool - - @field_validator("name", mode="before") - @classmethod - def validate_name(cls, value) -> str: - if not value: - raise ValueError("Parameter name is required") - if value in {"__reason", "__is_success"}: - raise ValueError("Invalid parameter name, __reason and __is_success are reserved") - return str(value) - - def is_array_type(self) -> bool: - return self.type.is_array_type() - - def element_type(self) -> SegmentType: - """Return the element type of the parameter. - - Raises a ValueError if the parameter's type is not an array type. - """ - element_type = self.type.element_type() - # At this point, self.type is guaranteed to be one of `ARRAY_STRING`, - # `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`. - # - # See: _VALID_PARAMETER_TYPES for reference. - assert element_type is not None, f"the element type should not be None, {self.type=}" - return element_type - - -class ParameterExtractorNodeData(BaseNodeData): - """ - Parameter Extractor Node Data. - """ - - type: NodeType = BuiltinNodeTypes.PARAMETER_EXTRACTOR - model: ModelConfig - query: list[str] - parameters: list[ParameterConfig] - instruction: str | None = None - memory: MemoryConfig | None = None - reasoning_mode: Literal["function_call", "prompt"] - vision: VisionConfig = Field(default_factory=VisionConfig) - - @field_validator("reasoning_mode", mode="before") - @classmethod - def set_reasoning_mode(cls, v) -> str: - return v or "function_call" - - def get_parameter_json_schema(self): - """ - Get parameter json schema. - - :return: parameter json schema - """ - parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []} - - for parameter in self.parameters: - parameter_schema: dict[str, Any] = {"description": parameter.description} - - if parameter.type == SegmentType.STRING: - parameter_schema["type"] = "string" - elif parameter.type.is_array_type(): - parameter_schema["type"] = "array" - element_type = parameter.type.element_type() - if element_type is None: - raise AssertionError("element type should not be None.") - parameter_schema["items"] = {"type": element_type.value} - else: - parameter_schema["type"] = parameter.type - - if parameter.options: - parameter_schema["enum"] = parameter.options - - parameters["properties"][parameter.name] = parameter_schema - - if parameter.required: - parameters["required"].append(parameter.name) - - return parameters diff --git a/api/graphon/nodes/parameter_extractor/exc.py b/api/graphon/nodes/parameter_extractor/exc.py deleted file mode 100644 index faa90313c1..0000000000 --- a/api/graphon/nodes/parameter_extractor/exc.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Any - -from graphon.variables.types import SegmentType - - -class ParameterExtractorNodeError(ValueError): - """Base error for ParameterExtractorNode.""" - - -class InvalidModelTypeError(ParameterExtractorNodeError): - """Raised when the model is not a Large Language Model.""" - - -class ModelSchemaNotFoundError(ParameterExtractorNodeError): - """Raised when the model schema is not found.""" - - -class InvalidInvokeResultError(ParameterExtractorNodeError): - """Raised when the invoke result is invalid.""" - - -class InvalidTextContentTypeError(ParameterExtractorNodeError): - """Raised when the text content type is invalid.""" - - -class InvalidNumberOfParametersError(ParameterExtractorNodeError): - """Raised when the number of parameters is invalid.""" - - -class RequiredParameterMissingError(ParameterExtractorNodeError): - """Raised when a required parameter is missing.""" - - -class InvalidSelectValueError(ParameterExtractorNodeError): - """Raised when a select value is invalid.""" - - -class InvalidNumberValueError(ParameterExtractorNodeError): - """Raised when a number value is invalid.""" - - -class InvalidBoolValueError(ParameterExtractorNodeError): - """Raised when a bool value is invalid.""" - - -class InvalidStringValueError(ParameterExtractorNodeError): - """Raised when a string value is invalid.""" - - -class InvalidArrayValueError(ParameterExtractorNodeError): - """Raised when an array value is invalid.""" - - -class InvalidModelModeError(ParameterExtractorNodeError): - """Raised when the model mode is invalid.""" - - -class InvalidValueTypeError(ParameterExtractorNodeError): - def __init__( - self, - /, - parameter_name: str, - expected_type: SegmentType, - actual_type: SegmentType | None, - value: Any, - ): - message = ( - f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, " - f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}" - ) - super().__init__(message) - self.parameter_name = parameter_name - self.expected_type = expected_type - self.actual_type = actual_type - self.value = value diff --git a/api/graphon/nodes/parameter_extractor/parameter_extractor_node.py b/api/graphon/nodes/parameter_extractor/parameter_extractor_node.py deleted file mode 100644 index 25379e325c..0000000000 --- a/api/graphon/nodes/parameter_extractor/parameter_extractor_node.py +++ /dev/null @@ -1,846 +0,0 @@ -import contextlib -import json -import logging -import uuid -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, cast - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import File -from graphon.model_runtime.entities import ImagePromptMessageContent, LLMMode -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageRole, - PromptMessageTool, - ToolPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType -from graphon.model_runtime.memory import PromptMessageMemory -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.node_events import NodeRunResult -from graphon.nodes.base import variable_template_parser -from graphon.nodes.base.node import Node -from graphon.nodes.llm import LLMNode, llm_utils -from graphon.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate -from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol -from graphon.runtime import VariablePool -from graphon.variables import build_segment_with_type -from graphon.variables.types import ArrayValidation, SegmentType - -from .entities import ParameterExtractorNodeData -from .exc import ( - InvalidModelModeError, - InvalidModelTypeError, - InvalidNumberOfParametersError, - InvalidSelectValueError, - InvalidTextContentTypeError, - InvalidValueTypeError, - ModelSchemaNotFoundError, - ParameterExtractorNodeError, - RequiredParameterMissingError, -) -from .prompts import ( - CHAT_EXAMPLE, - CHAT_GENERATE_JSON_PROMPT, - CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, - COMPLETION_GENERATE_JSON_PROMPT, - FUNCTION_CALLING_EXTRACTOR_EXAMPLE, - FUNCTION_CALLING_EXTRACTOR_NAME, - FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT, - FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE, -) - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - - -def extract_json(text): - """ - From a given JSON started from '{' or '[' extract the complete JSON object. - """ - stack = [] - for i, c in enumerate(text): - if c in {"{", "["}: - stack.append(c) - elif c in {"}", "]"}: - # check if stack is empty - if not stack: - return text[:i] - # check if the last element in stack is matching - if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["): - stack.pop() - if not stack: - return text[: i + 1] - else: - return text[:i] - return None - - -class ParameterExtractorNode(Node[ParameterExtractorNodeData]): - """ - Parameter Extractor Node. - """ - - node_type = BuiltinNodeTypes.PARAMETER_EXTRACTOR - - _model_instance: PreparedLLMProtocol - _prompt_message_serializer: PromptMessageSerializerProtocol - _memory: PromptMessageMemory | None - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - credentials_provider: object | None = None, - model_factory: object | None = None, - model_instance: PreparedLLMProtocol, - memory: PromptMessageMemory | None = None, - prompt_message_serializer: PromptMessageSerializerProtocol, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - _ = credentials_provider, model_factory - self._model_instance = model_instance - self._prompt_message_serializer = prompt_message_serializer - self._memory = memory - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return { - "model": { - "prompt_templates": { - "completion_model": { - "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, - "stop": ["Human:"], - } - } - } - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self): - """ - Run the node. - """ - node_data = self.node_data - variable = self.graph_runtime_state.variable_pool.get(node_data.query) - query = variable.text if variable else "" - - variable_pool = self.graph_runtime_state.variable_pool - - files = ( - llm_utils.fetch_files( - variable_pool=variable_pool, - selector=node_data.vision.configs.variable_selector, - ) - if node_data.vision.enabled - else [] - ) - - model_instance = self._model_instance - # Resolve variable references in string-typed completion params - model_instance.parameters = llm_utils.resolve_completion_params_variables( - model_instance.parameters, variable_pool - ) - try: - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - except ValueError as exc: - raise ModelSchemaNotFoundError("Model schema not found") from exc - if model_schema.model_type != ModelType.LLM: - raise InvalidModelTypeError("Model is not a Large Language Model") - memory = self._memory - - if ( - set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} - and node_data.reasoning_mode == "function_call" - ): - # use function call - prompt_messages, prompt_message_tools = self._generate_function_call_prompt( - node_data=node_data, - query=query, - variable_pool=self.graph_runtime_state.variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=node_data.vision.configs.detail, - ) - else: - # use prompt engineering - prompt_messages = self._generate_prompt_engineering_prompt( - data=node_data, - query=query, - variable_pool=self.graph_runtime_state.variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=node_data.vision.configs.detail, - ) - - prompt_message_tools = [] - - inputs = { - "query": query, - "files": [f.to_dict() for f in files], - "parameters": jsonable_encoder(node_data.parameters), - "instruction": jsonable_encoder(node_data.instruction), - } - - process_data = { - "model_mode": node_data.model.mode, - "prompts": self._prompt_message_serializer.serialize( - model_mode=node_data.model.mode, - prompt_messages=prompt_messages, - ), - "usage": None, - "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), - "tool_call": None, - "model_provider": model_instance.provider, - "model_name": model_instance.model_name, - } - - try: - text, usage, tool_call = self._invoke( - model_instance=model_instance, - prompt_messages=prompt_messages, - tools=prompt_message_tools, - stop=model_instance.stop, - ) - process_data["usage"] = jsonable_encoder(usage) - process_data["tool_call"] = jsonable_encoder(tool_call) - process_data["llm_text"] = text - except ParameterExtractorNodeError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - outputs={"__is_success": 0, "__reason": str(e)}, - error=str(e), - metadata={}, - ) - except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)}, - error=str(e), - metadata={}, - ) - - error = None - - if tool_call: - result = self._extract_json_from_tool_call(tool_call) - else: - result = self._extract_complete_json_response(text) - if not result: - result = self._generate_default_result(node_data) - error = "Failed to extract result from function call or text response, using empty result." - - try: - result = self._validate_result(data=node_data, result=result or {}) - except ParameterExtractorNodeError as e: - error = str(e) - - # transform result into standard format - result = self._transform_result(data=node_data, result=result or {}) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={ - "__is_success": 1 if not error else 0, - "__reason": error, - "__usage": jsonable_encoder(usage), - **result, - }, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - - def _invoke( - self, - model_instance: PreparedLLMProtocol, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - stop: Sequence[str] | None, - ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]: - invoke_result = cast( - LLMResult, - model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=dict(model_instance.parameters), - tools=tools or None, - stop=stop, - stream=False, - ), - ) - - # handle invoke result - - text = invoke_result.message.get_text_content() - if not isinstance(text, str): - raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.") - - usage = invoke_result.usage - tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None - - return text, usage, tool_call - - def _generate_function_call_prompt( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: PreparedLLMProtocol, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: - """ - Generate function call prompt. - """ - query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format( - content=query, structure=json.dumps(node_data.get_parameter_json_schema()) - ) - - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - context="", - ) - prompt_template = self._get_function_calling_prompt_template( - node_data, query, variable_pool, memory, rest_token - ) - prompt_messages = self._compile_prompt_messages( - model_instance=model_instance, - prompt_template=prompt_template, - files=files, - vision_enabled=node_data.vision.enabled, - image_detail_config=vision_detail, - ) - - # find last user message - last_user_message_idx = -1 - for i, prompt_message in enumerate(prompt_messages): - if prompt_message.role == PromptMessageRole.USER: - last_user_message_idx = i - - # add function call messages before last user message - example_messages = [] - for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE: - id = uuid.uuid4().hex - example_messages.extend( - [ - UserPromptMessage(content=example["user"]["query"]), - AssistantPromptMessage( - content=example["assistant"]["text"], - tool_calls=[ - AssistantPromptMessage.ToolCall( - id=id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=example["assistant"]["function_call"]["name"], - arguments=json.dumps(example["assistant"]["function_call"]["parameters"]), - ), - ) - ], - ), - ToolPromptMessage( - content="Great! You have called the function with the correct parameters.", tool_call_id=id - ), - AssistantPromptMessage( - content="I have extracted the parameters, let's move on.", - ), - ] - ) - - prompt_messages = ( - prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] - ) - - # generate tool - tool = PromptMessageTool( - name=FUNCTION_CALLING_EXTRACTOR_NAME, - description="Extract parameters from the natural language text", - parameters=node_data.get_parameter_json_schema(), - ) - - return prompt_messages, [tool] - - def _generate_prompt_engineering_prompt( - self, - data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: PreparedLLMProtocol, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - """ - Generate prompt engineering prompt. - """ - if data.model.mode == LLMMode.COMPLETION: - return self._generate_prompt_engineering_completion_prompt( - node_data=data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=vision_detail, - ) - if data.model.mode == LLMMode.CHAT: - return self._generate_prompt_engineering_chat_prompt( - node_data=data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=vision_detail, - ) - raise InvalidModelModeError(f"Invalid model mode: {data.model.mode}") - - def _generate_prompt_engineering_completion_prompt( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: PreparedLLMProtocol, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - """ - Generate completion prompt. - """ - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - context="", - ) - prompt_template = self._get_prompt_engineering_prompt_template( - node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token - ) - return self._compile_prompt_messages( - model_instance=model_instance, - prompt_template=prompt_template, - files=files, - vision_enabled=node_data.vision.enabled, - image_detail_config=vision_detail, - ) - - def _generate_prompt_engineering_chat_prompt( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: PreparedLLMProtocol, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - """ - Generate chat prompt. - """ - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - context="", - ) - prompt_template = self._get_prompt_engineering_prompt_template( - node_data=node_data, - query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(node_data.get_parameter_json_schema()), text=query - ), - variable_pool=variable_pool, - memory=memory, - max_token_limit=rest_token, - ) - - prompt_messages = self._compile_prompt_messages( - model_instance=model_instance, - prompt_template=prompt_template, - files=files, - vision_enabled=node_data.vision.enabled, - image_detail_config=vision_detail, - ) - - # find last user message - last_user_message_idx = -1 - for i, prompt_message in enumerate(prompt_messages): - if prompt_message.role == PromptMessageRole.USER: - last_user_message_idx = i - - # add example messages before last user message - example_messages = [] - for example in CHAT_EXAMPLE: - example_messages.extend( - [ - UserPromptMessage( - content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(example["user"]["json"]), - text=example["user"]["query"], - ) - ), - AssistantPromptMessage( - content=json.dumps(example["assistant"]["json"]), - ), - ] - ) - - prompt_messages = ( - prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] - ) - - return prompt_messages - - def _validate_result(self, data: ParameterExtractorNodeData, result: dict): - if len(data.parameters) != len(result): - raise InvalidNumberOfParametersError("Invalid number of parameters") - - for parameter in data.parameters: - if parameter.required and parameter.name not in result: - raise RequiredParameterMissingError(f"Parameter {parameter.name} is required") - - param_value = result.get(parameter.name) - if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL): - inferred_type = SegmentType.infer_segment_type(param_value) - raise InvalidValueTypeError( - parameter_name=parameter.name, - expected_type=parameter.type, - actual_type=inferred_type, - value=param_value, - ) - if parameter.type == SegmentType.STRING and parameter.options: - if param_value not in parameter.options: - raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}") - return result - - @staticmethod - def _transform_number(value: int | float | str | bool) -> int | float | None: - """ - Attempts to transform the input into an integer or float. - - Returns: - int or float: The transformed number if the conversion is successful. - None: If the transformation fails. - - Note: - Boolean values `True` and `False` are converted to integers `1` and `0`, respectively. - This behavior ensures compatibility with existing workflows that may use boolean types as integers. - """ - if isinstance(value, bool): - return int(value) - elif isinstance(value, (int, float)): - return value - elif isinstance(value, str): - if "." in value: - try: - return float(value) - except ValueError: - return None - else: - try: - return int(value) - except ValueError: - return None - else: - return None - - def _transform_result(self, data: ParameterExtractorNodeData, result: dict): - """ - Transform result into standard format. - """ - transformed_result: dict[str, Any] = {} - for parameter in data.parameters: - if parameter.name in result: - param_value = result[parameter.name] - # transform value - if parameter.type == SegmentType.NUMBER: - transformed = self._transform_number(param_value) - if transformed is not None: - transformed_result[parameter.name] = transformed - elif parameter.type == SegmentType.BOOLEAN: - if isinstance(result[parameter.name], (bool, int)): - transformed_result[parameter.name] = bool(result[parameter.name]) - # elif isinstance(result[parameter.name], str): - # if result[parameter.name].lower() in ["true", "false"]: - # transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true") - elif parameter.type == SegmentType.STRING: - if isinstance(param_value, str): - transformed_result[parameter.name] = param_value - elif parameter.is_array_type(): - if isinstance(param_value, list): - nested_type = parameter.element_type() - assert nested_type is not None - segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[]) - transformed_result[parameter.name] = segment_value - for item in param_value: - if nested_type == SegmentType.NUMBER: - transformed = self._transform_number(item) - if transformed is not None: - segment_value.value.append(transformed) - elif nested_type == SegmentType.STRING: - if isinstance(item, str): - segment_value.value.append(item) - elif nested_type == SegmentType.OBJECT: - if isinstance(item, dict): - segment_value.value.append(item) - elif nested_type == SegmentType.BOOLEAN: - if isinstance(item, bool): - segment_value.value.append(item) - - if parameter.name not in transformed_result: - if parameter.type.is_array_type(): - transformed_result[parameter.name] = build_segment_with_type( - segment_type=SegmentType(parameter.type), value=[] - ) - elif parameter.type in (SegmentType.STRING, SegmentType.SECRET): - transformed_result[parameter.name] = "" - elif parameter.type == SegmentType.NUMBER: - transformed_result[parameter.name] = 0 - elif parameter.type == SegmentType.BOOLEAN: - transformed_result[parameter.name] = False - else: - raise AssertionError("this statement should be unreachable.") - - return transformed_result - - def _extract_complete_json_response(self, result: str) -> dict | None: - """ - Extract complete json response. - """ - - # extract json from the text - for idx in range(len(result)): - if result[idx] == "{" or result[idx] == "[": - json_str = extract_json(result[idx:]) - if json_str: - with contextlib.suppress(Exception): - return cast(dict, json.loads(json_str)) - logger.info("extra error: %s", result) - return None - - def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None: - """ - Extract json from tool call. - """ - if not tool_call or not tool_call.function.arguments: - return None - - result = tool_call.function.arguments - # extract json from the arguments - for idx in range(len(result)): - if result[idx] == "{" or result[idx] == "[": - json_str = extract_json(result[idx:]) - if json_str: - with contextlib.suppress(Exception): - return cast(dict, json.loads(json_str)) - - logger.info("extra error: %s", result) - return None - - def _generate_default_result(self, data: ParameterExtractorNodeData): - """ - Generate default result. - """ - result: dict[str, Any] = {} - for parameter in data.parameters: - if parameter.type == "number": - result[parameter.name] = 0 - elif parameter.type == "boolean": - result[parameter.name] = False - elif parameter.type in {"string", "select"}: - result[parameter.name] = "" - - return result - - def _get_function_calling_prompt_template( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - memory: PromptMessageMemory | None, - max_token_limit: int = 2000, - ) -> list[LLMNodeChatModelMessage]: - input_text = query - memory_str = "" - instruction = variable_pool.convert_template(node_data.instruction or "").text - - if memory and node_data.memory and node_data.memory.window: - memory_str = llm_utils.fetch_memory_text( - memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size - ) - if node_data.model.mode == LLMMode.CHAT: - system_prompt_messages = LLMNodeChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), - ) - user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text) - return [system_prompt_messages, user_prompt_message] - raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.") - - def _get_prompt_engineering_prompt_template( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - memory: PromptMessageMemory | None, - max_token_limit: int = 2000, - ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: - input_text = query - memory_str = "" - instruction = variable_pool.convert_template(node_data.instruction or "").text - - if memory and node_data.memory and node_data.memory.window: - memory_str = llm_utils.fetch_memory_text( - memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size - ) - if node_data.model.mode == LLMMode.CHAT: - system_prompt_messages = LLMNodeChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction), - ) - user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text) - return [system_prompt_messages, user_prompt_message] - if node_data.model.mode == LLMMode.COMPLETION: - return LLMNodeCompletionModelPromptTemplate( - text=COMPLETION_GENERATE_JSON_PROMPT.format( - histories=memory_str, text=input_text, instruction=instruction - ) - .replace("{γγγ", "") - .replace("}γγγ", "") - .replace("{ structure }", json.dumps(node_data.get_parameter_json_schema())), - ) - raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.") - - def _calculate_rest_token( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: PreparedLLMProtocol, - context: str | None, - ) -> int: - try: - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - except ValueError as exc: - raise ModelSchemaNotFoundError("Model schema not found") from exc - - prompt_template: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: - prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) - else: - prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) - - prompt_messages = self._compile_prompt_messages( - model_instance=model_instance, - prompt_template=prompt_template, - files=[], - vision_enabled=False, - context=context, - ) - rest_tokens = 2000 - model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + 1000 - - max_tokens = 0 - for parameter_rule in model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - model_instance.parameters.get(parameter_rule.name) - or model_instance.parameters.get(parameter_rule.use_template or "") - ) or 0 - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - def _compile_prompt_messages( - self, - *, - model_instance: PreparedLLMProtocol, - prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, - files: Sequence[File], - vision_enabled: bool, - context: str | None = "", - image_detail_config: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - prompt_messages, _ = LLMNode.fetch_prompt_messages( - sys_query="", - sys_files=files, - context=context or "", - memory=None, - model_instance=model_instance, - prompt_template=prompt_template, - stop=model_instance.stop, - memory_config=None, - vision_enabled=vision_enabled, - vision_detail=image_detail_config or ImagePromptMessageContent.DETAIL.HIGH, - variable_pool=self.graph_runtime_state.variable_pool, - jinja2_variables=[], - ) - return list(prompt_messages) - - @property - def model_instance(self) -> PreparedLLMProtocol: - return self._model_instance - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: ParameterExtractorNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} - - if node_data.instruction: - selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) - for selector in selectors: - variable_mapping[selector.variable] = selector.value_selector - - variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} - - return variable_mapping diff --git a/api/graphon/nodes/parameter_extractor/prompts.py b/api/graphon/nodes/parameter_extractor/prompts.py deleted file mode 100644 index 1b29be4418..0000000000 --- a/api/graphon/nodes/parameter_extractor/prompts.py +++ /dev/null @@ -1,184 +0,0 @@ -from typing import Any - -FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters" - -FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy. -### Task -Always call the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function with the correct parameters. Ensure that the information extraction is contextual and aligns with the provided criteria. -### Memory -Here is the chat history between the human and assistant, provided within tags: - -\x7bhistories\x7d - -### Instructions: -Some additional information is provided below. Always adhere to these instructions as closely as possible: - -\x7binstruction\x7d - -Steps: -1. Review the chat history provided within the tags. -2. Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text. -3. Generate a well-formatted output using the defined functions and arguments. -4. Use the `extract_parameter` function to create structured outputs with appropriate parameters. -5. Do not include any XML tags in your output. -### Example -To illustrate, if the task involves extracting a user's name and their request, your function call might look like this: Ensure your output follows a similar structure to examples. -### Final Output -Produce well-formatted function calls in json without XML tags, as shown in the example. -""" # noqa: E501 - -FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from context inside XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside XML tags. - -\x7bcontent\x7d - - - -\x7bstructure\x7d - -""" # noqa: E501 - -FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [ - { - "user": { - "query": "What is the weather today in SF?", - "function": { - "name": FUNCTION_CALLING_EXTRACTOR_NAME, - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The location to get the weather information", - "required": True, - }, - }, - "required": ["location"], - }, - }, - }, - "assistant": { - "text": "I need always call the function with the correct parameters." - " in this case, I need to call the function with the location parameter.", - "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"location": "San Francisco"}}, - }, - }, - { - "user": { - "query": "I want to eat some apple pie.", - "function": { - "name": FUNCTION_CALLING_EXTRACTOR_NAME, - "parameters": { - "type": "object", - "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, - "required": ["food"], - }, - }, - }, - "assistant": { - "text": "I need always call the function with the correct parameters." - " in this case, I need to call the function with the food parameter.", - "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"food": "apple pie"}}, - }, - }, -] - -COMPLETION_GENERATE_JSON_PROMPT = """### Instructions: -Some extra information are provided below, I should always follow the instructions as possible as I can. - -{instruction} - - -### Extract parameter Workflow -I need to extract the following information from the input text. The tag specifies the 'type', 'description' and 'required' of the information to be extracted. - -{{ structure }} - - -Step 1: Carefully read the input and understand the structure of the expected output. -Step 2: Extract relevant parameters from the provided text based on the name and description of object. -Step 3: Structure the extracted parameters to JSON object as specified in . -Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted. - -### Memory -Here are the chat histories between human and assistant, inside XML tags. - -{histories} - - -### Structure -Here is the structure of the expected output, I should always follow the output structure. -{{γγγ - 'properties1': 'relevant text extracted from input', - 'properties2': 'relevant text extracted from input', -}}γγγ - -### Input Text -Inside XML tags, there is a text that I should extract parameters and convert to a JSON object. - -{text} - - -### Answer -I should always output a valid JSON object. Output nothing other than the JSON object. -```JSON -""" # noqa: E501 - -CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object. -The structure of the JSON object you can found in the instructions. - -### Memory -Here are the chat histories between human and assistant, inside XML tags. - -{histories} - - -### Instructions: -Some extra information are provided below, you should always follow the instructions as possible as you can. - -{instructions} - -""" - -CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE = """### Structure -Here is the structure of the JSON object, you should always follow the structure. - -{structure} - - -### Text to be converted to JSON -Inside XML tags, there is a text that you should convert to a JSON object. - -{text} - -""" - -CHAT_EXAMPLE = [ - { - "user": { - "query": "What is the weather today in SF?", - "json": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The location to get the weather information", - "required": True, - } - }, - "required": ["location"], - }, - }, - "assistant": {"text": "I need to output a valid JSON object.", "json": {"location": "San Francisco"}}, - }, - { - "user": { - "query": "I want to eat some apple pie.", - "json": { - "type": "object", - "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, - "required": ["food"], - }, - }, - "assistant": {"text": "I need to output a valid JSON object.", "json": {"food": "apple pie"}}, - }, -] diff --git a/api/graphon/nodes/protocols.py b/api/graphon/nodes/protocols.py deleted file mode 100644 index 4b050c113c..0000000000 --- a/api/graphon/nodes/protocols.py +++ /dev/null @@ -1,46 +0,0 @@ -from collections.abc import Generator, Mapping -from typing import Any, Protocol - -import httpx - -from graphon.file import File - - -class HttpClientProtocol(Protocol): - @property - def max_retries_exceeded_error(self) -> type[Exception]: ... - - @property - def request_error(self) -> type[Exception]: ... - - def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def delete(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def patch(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - -class FileManagerProtocol(Protocol): - def download(self, f: File, /) -> bytes: ... - - -class ToolFileManagerProtocol(Protocol): - def create_file_by_raw( - self, - *, - file_binary: bytes, - mimetype: str, - filename: str | None = None, - ) -> Any: ... - - def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, File | None]: ... - - -class FileReferenceFactoryProtocol(Protocol): - def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File: ... diff --git a/api/graphon/nodes/question_classifier/__init__.py b/api/graphon/nodes/question_classifier/__init__.py deleted file mode 100644 index 4d06b6bea3..0000000000 --- a/api/graphon/nodes/question_classifier/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .entities import QuestionClassifierNodeData -from .question_classifier_node import QuestionClassifierNode - -__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeData"] diff --git a/api/graphon/nodes/question_classifier/entities.py b/api/graphon/nodes/question_classifier/entities.py deleted file mode 100644 index 8d5f117315..0000000000 --- a/api/graphon/nodes/question_classifier/entities.py +++ /dev/null @@ -1,30 +0,0 @@ -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.llm import ModelConfig, VisionConfig -from graphon.prompt_entities import MemoryConfig - - -class ClassConfig(BaseModel): - id: str - name: str - - -class QuestionClassifierNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.QUESTION_CLASSIFIER - query_variable_selector: list[str] - model: ModelConfig - classes: list[ClassConfig] - instruction: str | None = None - memory: MemoryConfig | None = None - vision: VisionConfig = Field(default_factory=VisionConfig) - - @property - def structured_output_enabled(self) -> bool: - # NOTE(QuantumGhost): Temporary workaround for issue #20725 - # (https://github.com/langgenius/dify/issues/20725). - # - # The proper fix would be to make `QuestionClassifierNode` inherit - # from `BaseNode` instead of `LLMNode`. - return False diff --git a/api/graphon/nodes/question_classifier/exc.py b/api/graphon/nodes/question_classifier/exc.py deleted file mode 100644 index 2c6354e2a7..0000000000 --- a/api/graphon/nodes/question_classifier/exc.py +++ /dev/null @@ -1,6 +0,0 @@ -class QuestionClassifierNodeError(ValueError): - """Base class for QuestionClassifierNode errors.""" - - -class InvalidModelTypeError(QuestionClassifierNodeError): - """Raised when the model is not a Large Language Model.""" diff --git a/api/graphon/nodes/question_classifier/question_classifier_node.py b/api/graphon/nodes/question_classifier/question_classifier_node.py deleted file mode 100644 index a30ffbb149..0000000000 --- a/api/graphon/nodes/question_classifier/question_classifier_node.py +++ /dev/null @@ -1,395 +0,0 @@ -import json -import re -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ( - BuiltinNodeTypes, - NodeExecutionType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.model_runtime.entities import LLMMode, LLMUsage, ModelPropertyKey, PromptMessageRole -from graphon.model_runtime.memory import PromptMessageMemory -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.node_events import ModelInvokeCompletedEvent, NodeRunResult -from graphon.nodes.base.entities import VariableSelector -from graphon.nodes.base.node import Node -from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from graphon.nodes.llm import ( - LLMNode, - LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - llm_utils, -) -from graphon.nodes.llm.file_saver import LLMFileSaver -from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol -from graphon.nodes.protocols import HttpClientProtocol -from graphon.template_rendering import Jinja2TemplateRenderer -from graphon.utils.json_in_md_parser import parse_and_check_json_markdown - -from .entities import QuestionClassifierNodeData -from .exc import InvalidModelTypeError -from .template_prompts import ( - QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, - QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2, - QUESTION_CLASSIFIER_COMPLETION_PROMPT, - QUESTION_CLASSIFIER_SYSTEM_PROMPT, - QUESTION_CLASSIFIER_USER_PROMPT_1, - QUESTION_CLASSIFIER_USER_PROMPT_2, - QUESTION_CLASSIFIER_USER_PROMPT_3, -) - -if TYPE_CHECKING: - from graphon.file.models import File - from graphon.runtime import GraphRuntimeState - - -class _PassthroughPromptMessageSerializer: - def serialize(self, *, model_mode: Any, prompt_messages: Sequence[Any]) -> Any: - _ = model_mode - return list(prompt_messages) - - -class QuestionClassifierNode(Node[QuestionClassifierNodeData]): - node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER - execution_type = NodeExecutionType.BRANCH - - _file_outputs: list["File"] - _llm_file_saver: LLMFileSaver - _prompt_message_serializer: PromptMessageSerializerProtocol - _model_instance: PreparedLLMProtocol - _memory: PromptMessageMemory | None - _template_renderer: Jinja2TemplateRenderer - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - credentials_provider: object | None = None, - model_factory: object | None = None, - model_instance: PreparedLLMProtocol, - http_client: HttpClientProtocol, - template_renderer: Jinja2TemplateRenderer, - memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver, - prompt_message_serializer: PromptMessageSerializerProtocol | None = None, - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - # LLM file outputs, used for MultiModal outputs. - self._file_outputs = [] - - _ = credentials_provider, model_factory, http_client - self._model_instance = model_instance - self._memory = memory - self._template_renderer = template_renderer - - self._llm_file_saver = llm_file_saver - self._prompt_message_serializer = prompt_message_serializer or _PassthroughPromptMessageSerializer() - - @classmethod - def version(cls): - return "1" - - def _run(self): - node_data = self.node_data - variable_pool = self.graph_runtime_state.variable_pool - - # extract variables - variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None - query = variable.value if variable else None - variables = {"query": query} - # fetch model instance - model_instance = self._model_instance - # Resolve variable references in string-typed completion params - model_instance.parameters = llm_utils.resolve_completion_params_variables( - model_instance.parameters, variable_pool - ) - memory = self._memory - # fetch instruction - node_data.instruction = node_data.instruction or "" - node_data.instruction = variable_pool.convert_template(node_data.instruction).text - - files = ( - llm_utils.fetch_files( - variable_pool=variable_pool, - selector=node_data.vision.configs.variable_selector, - ) - if node_data.vision.enabled - else [] - ) - - # fetch prompt messages - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query or "", - model_instance=model_instance, - context="", - ) - prompt_template = self._get_prompt_template( - node_data=node_data, - query=query or "", - memory=memory, - max_token_limit=rest_token, - ) - # Some models (e.g. Gemma, Mistral) force roles alternation (user/assistant/user/assistant...). - # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt, - # two consecutive user prompts will be generated, causing model's error. - # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end. - prompt_messages, stop = llm_utils.fetch_prompt_messages( - prompt_template=prompt_template, - sys_query="", - memory=memory, - model_instance=model_instance, - stop=model_instance.stop, - sys_files=files, - vision_enabled=node_data.vision.enabled, - vision_detail=node_data.vision.configs.detail, - variable_pool=variable_pool, - jinja2_variables=[], - template_renderer=self._template_renderer, - ) - - result_text = "" - usage = LLMUsage.empty_usage() - finish_reason = None - - try: - # handle invoke result - generator = LLMNode.invoke_llm( - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop, - structured_output_enabled=False, - structured_output=None, - file_saver=self._llm_file_saver, - file_outputs=self._file_outputs, - node_id=self._node_id, - node_type=self.node_type, - ) - - for event in generator: - if isinstance(event, ModelInvokeCompletedEvent): - result_text = event.text - usage = event.usage - finish_reason = event.finish_reason - break - - rendered_classes = [ - c.model_copy(update={"name": variable_pool.convert_template(c.name).text}) for c in node_data.classes - ] - - category_name = rendered_classes[0].name - category_id = rendered_classes[0].id - if "" in result_text: - result_text = re.sub(r"]*>[\s\S]*?", "", result_text, flags=re.IGNORECASE) - result_text_json = parse_and_check_json_markdown(result_text, []) - # result_text_json = json.loads(result_text.strip('```JSON\n')) - if "category_name" in result_text_json and "category_id" in result_text_json: - category_id_result = result_text_json["category_id"] - classes = rendered_classes - classes_map = {class_.id: class_.name for class_ in classes} - category_ids = [_class.id for _class in classes] - if category_id_result in category_ids: - category_name = classes_map[category_id_result] - category_id = category_id_result - process_data = { - "model_mode": node_data.model.mode, - "prompts": self._prompt_message_serializer.serialize( - model_mode=node_data.model.mode, prompt_messages=prompt_messages - ), - "usage": jsonable_encoder(usage), - "finish_reason": finish_reason, - "model_provider": model_instance.provider, - "model_name": model_instance.model_name, - } - outputs = { - "class_name": category_name, - "class_id": category_id, - "usage": jsonable_encoder(usage), - } - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - process_data=process_data, - outputs=outputs, - edge_source_handle=category_id, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - except ValueError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error=str(e), - error_type=type(e).__name__, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - - @property - def model_instance(self) -> PreparedLLMProtocol: - return self._model_instance - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: QuestionClassifierNodeData, - ) -> Mapping[str, Sequence[str]]: - # graph_config is not used in this node type - variable_mapping = {"query": node_data.query_variable_selector} - variable_selectors: list[VariableSelector] = [] - if node_data.instruction: - variable_template_parser = VariableTemplateParser(template=node_data.instruction) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = list(variable_selector.value_selector) - - variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} - - return variable_mapping - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - """ - Get default config of node. - :param filters: filter by node config parameters (not used in this implementation). - :return: - """ - # filters parameter is not used in this node type - return {"type": "question-classifier", "config": {"instructions": ""}} - - def _calculate_rest_token( - self, - node_data: QuestionClassifierNodeData, - query: str, - model_instance: PreparedLLMProtocol, - context: str | None, - ) -> int: - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - - prompt_template = self._get_prompt_template(node_data, query, None, 2000) - prompt_messages, _ = llm_utils.fetch_prompt_messages( - prompt_template=prompt_template, - sys_query="", - sys_files=[], - context=context or "", - memory=None, - model_instance=model_instance, - stop=model_instance.stop, - memory_config=node_data.memory, - vision_enabled=False, - vision_detail=node_data.vision.configs.detail, - variable_pool=self.graph_runtime_state.variable_pool, - jinja2_variables=[], - template_renderer=self._template_renderer, - ) - rest_tokens = 2000 - - model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - max_tokens = 0 - for parameter_rule in model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - model_instance.parameters.get(parameter_rule.name) - or model_instance.parameters.get(parameter_rule.use_template or "") - ) or 0 - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - def _get_prompt_template( - self, - node_data: QuestionClassifierNodeData, - query: str, - memory: PromptMessageMemory | None, - max_token_limit: int = 2000, - ): - model_mode = LLMMode(node_data.model.mode) - classes = node_data.classes - categories = [] - for class_ in classes: - category = {"category_id": class_.id, "category_name": class_.name} - categories.append(category) - instruction = node_data.instruction or "" - input_text = query - memory_str = "" - if memory: - memory_str = llm_utils.fetch_memory_text( - memory=memory, - max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None, - ) - prompt_messages: list[LLMNodeChatModelMessage] = [] - if model_mode == LLMMode.CHAT: - system_prompt_messages = LLMNodeChatModelMessage( - role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) - ) - prompt_messages.append(system_prompt_messages) - user_prompt_message_1 = LLMNodeChatModelMessage( - role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1 - ) - prompt_messages.append(user_prompt_message_1) - assistant_prompt_message_1 = LLMNodeChatModelMessage( - role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 - ) - prompt_messages.append(assistant_prompt_message_1) - user_prompt_message_2 = LLMNodeChatModelMessage( - role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2 - ) - prompt_messages.append(user_prompt_message_2) - assistant_prompt_message_2 = LLMNodeChatModelMessage( - role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 - ) - prompt_messages.append(assistant_prompt_message_2) - user_prompt_message_3 = LLMNodeChatModelMessage( - role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_3.format( - input_text=input_text, - categories=json.dumps(categories, ensure_ascii=False), - classification_instructions=instruction, - ), - ) - prompt_messages.append(user_prompt_message_3) - return prompt_messages - elif model_mode == LLMMode.COMPLETION: - return LLMNodeCompletionModelPromptTemplate( - text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format( - histories=memory_str, - input_text=input_text, - categories=json.dumps(categories, ensure_ascii=False), - classification_instructions=instruction, - ) - ) - - else: - raise InvalidModelTypeError(f"Model mode {model_mode} not support.") diff --git a/api/graphon/nodes/question_classifier/template_prompts.py b/api/graphon/nodes/question_classifier/template_prompts.py deleted file mode 100644 index a615c32383..0000000000 --- a/api/graphon/nodes/question_classifier/template_prompts.py +++ /dev/null @@ -1,76 +0,0 @@ -QUESTION_CLASSIFIER_SYSTEM_PROMPT = """ -### Job Description', -You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories. -### Task -Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification. -### Format -The input text is in the variable input_text. Categories are specified as a category list with two filed category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy. -### Constraint -DO NOT include anything other than the JSON array in your response. -### Memory -Here are the chat histories between human and assistant, inside XML tags. - -{histories} - -""" # noqa: E501 - -QUESTION_CLASSIFIER_USER_PROMPT_1 = """ - {"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], - "categories": [{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"},{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"},{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"},{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}], - "classification_instructions": ["classify the text based on the feedback provided by customer"]} -""" # noqa: E501 - -QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """ -```json - {"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"], - "category_id": "f5660049-284f-41a7-b301-fd24176a711c", - "category_name": "Customer Service"} -``` -""" - -QUESTION_CLASSIFIER_USER_PROMPT_2 = """ - {"input_text": ["bad service, slow to bring the food"], - "categories": [{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"},{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"},{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}], - "classification_instructions": []} -""" # noqa: E501 - -QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """ -```json - {"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"], - "category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f", - "category_name": "Experience"} -``` -""" - -QUESTION_CLASSIFIER_USER_PROMPT_3 = """ - {{"input_text": ["{input_text}"], - "categories": {categories}, - "classification_instructions": ["{classification_instructions}"]}} -""" - -QUESTION_CLASSIFIER_COMPLETION_PROMPT = """ -### Job Description -You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories. -### Task -Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification. -### Format -The input text is in the variable input_text. Categories are specified as a category list with two filed category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy. -### Constraint -DO NOT include anything other than the JSON array in your response. -### Example -Here is the chat example between human and assistant, inside XML tags. - -User:{{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], "categories": [{{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"}},{{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"}},{{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"}},{{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}}], "classification_instructions": ["classify the text based on the feedback provided by customer"]}} -Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"category_id": "f5660049-284f-41a7-b301-fd24176a711c","category_name": "Customer Service"}} -User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"}},{{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"}},{{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}}], "classification_instructions": []}} -Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}} - -### Memory -Here are the chat histories between human and assistant, inside XML tags. - -{histories} - -### User Input -{{"input_text" : ["{input_text}"], "categories" : {categories},"classification_instruction" : ["{classification_instructions}"]}} -### Assistant Output -""" # noqa: E501 diff --git a/api/graphon/nodes/runtime.py b/api/graphon/nodes/runtime.py deleted file mode 100644 index 650299898c..0000000000 --- a/api/graphon/nodes/runtime.py +++ /dev/null @@ -1,106 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator, Mapping, Sequence -from datetime import datetime -from typing import TYPE_CHECKING, Any, Protocol - -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.nodes.tool_runtime_entities import ( - ToolRuntimeHandle, - ToolRuntimeMessage, - ToolRuntimeParameter, -) - -if TYPE_CHECKING: - from graphon.nodes.human_input.entities import HumanInputNodeData - from graphon.nodes.human_input.enums import HumanInputFormStatus - from graphon.nodes.tool.entities import ToolNodeData - from graphon.runtime import VariablePool - - -class ToolNodeRuntimeProtocol(Protocol): - """Workflow-layer adapter owned by `core.workflow` and consumed by `graphon`. - - The graph package depends only on these DTOs and lets the workflow layer - translate between graph-owned abstractions and `core.tools` internals. - """ - - def get_runtime( - self, - *, - node_id: str, - node_data: ToolNodeData, - variable_pool: VariablePool | None, - ) -> ToolRuntimeHandle: ... - - def get_runtime_parameters( - self, - *, - tool_runtime: ToolRuntimeHandle, - ) -> Sequence[ToolRuntimeParameter]: ... - - def invoke( - self, - *, - tool_runtime: ToolRuntimeHandle, - tool_parameters: Mapping[str, Any], - workflow_call_depth: int, - provider_name: str, - ) -> Generator[ToolRuntimeMessage, None, None]: ... - - def get_usage( - self, - *, - tool_runtime: ToolRuntimeHandle, - ) -> LLMUsage: ... - - def build_file_reference(self, *, mapping: Mapping[str, Any]) -> Any: ... - - def resolve_provider_icons( - self, - *, - provider_name: str, - default_icon: str | None = None, - ) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]: ... - - -class HumanInputNodeRuntimeProtocol(Protocol): - """Workflow-layer adapter for human-input runtime persistence and delivery.""" - - def get_form( - self, - *, - node_id: str, - ) -> HumanInputFormStateProtocol | None: ... - - def create_form( - self, - *, - node_id: str, - node_data: HumanInputNodeData, - rendered_content: str, - resolved_default_values: Mapping[str, Any], - ) -> HumanInputFormStateProtocol: ... - - -class HumanInputFormStateProtocol(Protocol): - @property - def id(self) -> str: ... - - @property - def rendered_content(self) -> str: ... - - @property - def selected_action_id(self) -> str | None: ... - - @property - def submitted_data(self) -> Mapping[str, Any] | None: ... - - @property - def submitted(self) -> bool: ... - - @property - def status(self) -> HumanInputFormStatus: ... - - @property - def expiration_time(self) -> datetime: ... diff --git a/api/graphon/nodes/start/__init__.py b/api/graphon/nodes/start/__init__.py deleted file mode 100644 index 5411780423..0000000000 --- a/api/graphon/nodes/start/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .start_node import StartNode - -__all__ = ["StartNode"] diff --git a/api/graphon/nodes/start/entities.py b/api/graphon/nodes/start/entities.py deleted file mode 100644 index 7df62e1b2b..0000000000 --- a/api/graphon/nodes/start/entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from collections.abc import Sequence - -from pydantic import Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.variables.input_entities import VariableEntity - - -class StartNodeData(BaseNodeData): - """ - Start Node Data - """ - - type: NodeType = BuiltinNodeTypes.START - variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/graphon/nodes/start/start_node.py b/api/graphon/nodes/start/start_node.py deleted file mode 100644 index cb3f4c1e7d..0000000000 --- a/api/graphon/nodes/start/start_node.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Any - -from jsonschema import Draft7Validator, ValidationError - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.start.entities import StartNodeData -from graphon.variables.input_entities import VariableEntityType - - -class StartNode(Node[StartNodeData]): - node_type = BuiltinNodeTypes.START - execution_type = NodeExecutionType.ROOT - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) - self._validate_and_normalize_json_object_inputs(node_inputs) - outputs = dict(self.graph_runtime_state.variable_pool.flatten(unprefixed_node_id=self.id)) - outputs.update(node_inputs) - - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs) - - def _validate_and_normalize_json_object_inputs(self, node_inputs: dict[str, Any]) -> None: - for variable in self.node_data.variables: - if variable.type != VariableEntityType.JSON_OBJECT: - continue - - key = variable.variable - value = node_inputs.get(key) - - if value is None and variable.required: - raise ValueError(f"{key} is required in input form") - - # If no value provided, skip further processing for this key - if not value: - continue - - if not isinstance(value, dict): - raise ValueError(f"JSON object for '{key}' must be an object") - - # Overwrite with normalized dict to ensure downstream consistency - node_inputs[key] = value - - # If schema exists, then validate against it - schema = variable.json_schema - if not schema: - continue - - try: - Draft7Validator(schema).validate(value) - except ValidationError as e: - raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}") diff --git a/api/graphon/nodes/template_transform/__init__.py b/api/graphon/nodes/template_transform/__init__.py deleted file mode 100644 index 43863b9d59..0000000000 --- a/api/graphon/nodes/template_transform/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .template_transform_node import TemplateTransformNode - -__all__ = ["TemplateTransformNode"] diff --git a/api/graphon/nodes/template_transform/entities.py b/api/graphon/nodes/template_transform/entities.py deleted file mode 100644 index a27a57f34f..0000000000 --- a/api/graphon/nodes/template_transform/entities.py +++ /dev/null @@ -1,13 +0,0 @@ -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base.entities import VariableSelector - - -class TemplateTransformNodeData(BaseNodeData): - """ - Template Transform Node Data. - """ - - type: NodeType = BuiltinNodeTypes.TEMPLATE_TRANSFORM - variables: list[VariableSelector] - template: str diff --git a/api/graphon/nodes/template_transform/template_transform_node.py b/api/graphon/nodes/template_transform/template_transform_node.py deleted file mode 100644 index 4206fb0c1a..0000000000 --- a/api/graphon/nodes/template_transform/template_transform_node.py +++ /dev/null @@ -1,119 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.entities import VariableSelector -from graphon.nodes.base.node import Node -from graphon.nodes.template_transform.entities import TemplateTransformNodeData -from graphon.template_rendering import ( - Jinja2TemplateRenderer, - TemplateRenderError, -) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - -DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000 - - -class TemplateTransformNode(Node[TemplateTransformNodeData]): - node_type = BuiltinNodeTypes.TEMPLATE_TRANSFORM - _jinja2_template_renderer: Jinja2TemplateRenderer - _max_output_length: int - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - jinja2_template_renderer: Jinja2TemplateRenderer, - max_output_length: int | None = None, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._jinja2_template_renderer = jinja2_template_renderer - - if max_output_length is not None and max_output_length <= 0: - raise ValueError("max_output_length must be a positive integer") - self._max_output_length = max_output_length or DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - """ - Get default config of node. - :param filters: filter by node config parameters. - :return: - """ - return { - "type": "template-transform", - "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"}, - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - # Get variables - variables: dict[str, Any] = {} - for variable_selector in self.node_data.variables: - variable_name = variable_selector.variable - value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - variables[variable_name] = value.to_object() if value else None - # Run code - try: - rendered = self._jinja2_template_renderer.render_template(self.node_data.template, variables) - except TemplateRenderError as e: - return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) - - if len(rendered) > self._max_output_length: - return NodeRunResult( - inputs=variables, - status=WorkflowNodeExecutionStatus.FAILED, - error=f"Output length exceeds {self._max_output_length} characters", - ) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered} - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: TemplateTransformNodeData | Mapping[str, Any], - ) -> Mapping[str, Sequence[str]]: - _ = graph_config - raw_variables = ( - node_data.variables if isinstance(node_data, TemplateTransformNodeData) else node_data.get("variables", []) - ) - variable_mapping: dict[str, Sequence[str]] = {} - for variable_selector in raw_variables: - if isinstance(variable_selector, VariableSelector): - variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector - continue - - if not isinstance(variable_selector, Mapping): - continue - - variable = variable_selector.get("variable") - value_selector = variable_selector.get("value_selector") - if ( - isinstance(variable, str) - and isinstance(value_selector, Sequence) - and all(isinstance(selector_part, str) for selector_part in value_selector) - ): - variable_mapping[node_id + "." + variable] = list(value_selector) - - return variable_mapping diff --git a/api/graphon/nodes/tool/__init__.py b/api/graphon/nodes/tool/__init__.py deleted file mode 100644 index f4982e655d..0000000000 --- a/api/graphon/nodes/tool/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .tool_node import ToolNode - -__all__ = ["ToolNode"] diff --git a/api/graphon/nodes/tool/entities.py b/api/graphon/nodes/tool/entities.py deleted file mode 100644 index 54e6048033..0000000000 --- a/api/graphon/nodes/tool/entities.py +++ /dev/null @@ -1,101 +0,0 @@ -from enum import StrEnum, auto -from typing import Any, Literal, Union - -from pydantic import BaseModel, field_validator -from pydantic_core.core_schema import ValidationInfo - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - - -class ToolProviderType(StrEnum): - """ - Graph-owned enum for persisted tool provider kinds. - """ - - PLUGIN = auto() - BUILT_IN = "builtin" - WORKFLOW = auto() - API = auto() - APP = auto() - DATASET_RETRIEVAL = "dataset-retrieval" - MCP = auto() - - -class ToolEntity(BaseModel): - provider_id: str - provider_type: ToolProviderType - provider_name: str # redundancy - tool_name: str - tool_label: str # redundancy - tool_configurations: dict[str, Any] - credential_id: str | None = None - plugin_unique_identifier: str | None = None # redundancy - - @field_validator("tool_configurations", mode="before") - @classmethod - def validate_tool_configurations(cls, value, values: ValidationInfo): - if not isinstance(value, dict): - raise ValueError("tool_configurations must be a dictionary") - - for key in values.data.get("tool_configurations", {}): - value = values.data.get("tool_configurations", {}).get(key) - if not isinstance(value, str | int | float | bool): - raise ValueError(f"{key} must be a string") - - return value - - -class ToolNodeData(BaseNodeData, ToolEntity): - type: NodeType = BuiltinNodeTypes.TOOL - - class ToolInput(BaseModel): - # TODO: check this type - value: Union[Any, list[str]] - type: Literal["mixed", "variable", "constant"] - - @field_validator("type", mode="before") - @classmethod - def check_type(cls, value, validation_info: ValidationInfo): - typ = value - value = validation_info.data.get("value") - - if value is None: - return typ - - if typ == "mixed" and not isinstance(value, str): - raise ValueError("value must be a string") - elif typ == "variable": - if not isinstance(value, list): - raise ValueError("value must be a list") - for val in value: - if not isinstance(val, str): - raise ValueError("value must be a list of strings") - elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))): - raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}") - return typ - - tool_parameters: dict[str, ToolInput] - # The version of the tool parameter. - # If this value is None, it indicates this is a previous version - # and requires using the legacy parameter parsing rules. - tool_node_version: str | None = None - - @field_validator("tool_parameters", mode="before") - @classmethod - def filter_none_tool_inputs(cls, value): - if not isinstance(value, dict): - return value - - return { - key: tool_input - for key, tool_input in value.items() - if tool_input is not None and cls._has_valid_value(tool_input) - } - - @staticmethod - def _has_valid_value(tool_input): - """Check if the value is valid""" - if isinstance(tool_input, dict): - return tool_input.get("value") is not None - return getattr(tool_input, "value", None) is not None diff --git a/api/graphon/nodes/tool/exc.py b/api/graphon/nodes/tool/exc.py deleted file mode 100644 index 1a309e1084..0000000000 --- a/api/graphon/nodes/tool/exc.py +++ /dev/null @@ -1,28 +0,0 @@ -class ToolNodeError(ValueError): - """Base exception for tool node errors.""" - - pass - - -class ToolRuntimeResolutionError(ToolNodeError): - """Raised when the workflow layer cannot construct a tool runtime.""" - - pass - - -class ToolRuntimeInvocationError(ToolNodeError): - """Raised when the workflow layer fails while invoking a tool runtime.""" - - pass - - -class ToolParameterError(ToolNodeError): - """Exception raised for errors in tool parameters.""" - - pass - - -class ToolFileError(ToolNodeError): - """Exception raised for errors related to tool files.""" - - pass diff --git a/api/graphon/nodes/tool/tool_node.py b/api/graphon/nodes/tool/tool_node.py deleted file mode 100644 index 57ab8ce5d6..0000000000 --- a/api/graphon/nodes/tool/tool_node.py +++ /dev/null @@ -1,432 +0,0 @@ -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from graphon.nodes.base.node import Node -from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from graphon.nodes.protocols import ToolFileManagerProtocol -from graphon.nodes.runtime import ToolNodeRuntimeProtocol -from graphon.nodes.tool_runtime_entities import ( - ToolRuntimeHandle, - ToolRuntimeMessage, - ToolRuntimeParameter, -) -from graphon.variables.segments import ArrayFileSegment - -from .entities import ToolNodeData -from .exc import ( - ToolFileError, - ToolNodeError, - ToolParameterError, -) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - -class ToolNode(Node[ToolNodeData]): - """ - Tool Node - """ - - node_type = BuiltinNodeTypes.TOOL - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - tool_file_manager_factory: ToolFileManagerProtocol, - runtime: ToolNodeRuntimeProtocol | None = None, - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._tool_file_manager_factory = tool_file_manager_factory - if runtime is None: - raise ValueError("runtime is required") - self._runtime = runtime - - @classmethod - def version(cls) -> str: - return "1" - - def populate_start_event(self, event) -> None: - event.provider_id = self.node_data.provider_id - event.provider_type = self.node_data.provider_type - - def _run(self) -> Generator[NodeEventBase, None, None]: - """ - Run the tool node - """ - # fetch tool icon - tool_info = { - "provider_type": self.node_data.provider_type.value, - "provider_id": self.node_data.provider_id, - "plugin_unique_identifier": self.node_data.plugin_unique_identifier, - } - - # get tool runtime - try: - # This is an issue that caused problems before. - # Logically, we shouldn't use the node_data.version field for judgment - # But for backward compatibility with historical data - # this version field judgment is still preserved here. - variable_pool: VariablePool | None = None - if self.node_data.version != "1" or self.node_data.tool_node_version is not None: - variable_pool = self.graph_runtime_state.variable_pool - tool_runtime = self._runtime.get_runtime( - node_id=self._node_id, - node_data=self.node_data, - variable_pool=variable_pool, - ) - except ToolNodeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs={}, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to get tool runtime: {str(e)}", - error_type=type(e).__name__, - ) - ) - return - - # get parameters - tool_parameters = self._runtime.get_runtime_parameters(tool_runtime=tool_runtime) - parameters = self._generate_parameters( - tool_parameters=tool_parameters, - variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, - ) - parameters_for_log = self._generate_parameters( - tool_parameters=tool_parameters, - variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, - for_log=True, - ) - try: - message_stream = self._runtime.invoke( - tool_runtime=tool_runtime, - tool_parameters=parameters, - workflow_call_depth=self.workflow_call_depth, - provider_name=self.node_data.provider_name, - ) - except ToolNodeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool: {str(e)}", - error_type=type(e).__name__, - ) - ) - return - - try: - # convert tool messages - _ = yield from self._transform_message( - messages=message_stream, - tool_info=tool_info, - parameters_for_log=parameters_for_log, - node_id=self._node_id, - tool_runtime=tool_runtime, - ) - except ToolNodeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=str(e), - error_type=type(e).__name__, - ) - ) - - def _generate_parameters( - self, - *, - tool_parameters: Sequence[ToolRuntimeParameter], - variable_pool: "VariablePool", - node_data: ToolNodeData, - for_log: bool = False, - ) -> dict[str, Any]: - """ - Generate parameters based on the given tool parameters, variable pool, and node data. - - Args: - tool_parameters (Sequence[ToolRuntimeParameter]): The list of tool parameters. - variable_pool (VariablePool): The variable pool containing the variables. - node_data (ToolNodeData): The data associated with the tool node. - - Returns: - Mapping[str, Any]: A dictionary containing the generated parameters. - - """ - tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} - - result: dict[str, Any] = {} - for parameter_name in node_data.tool_parameters: - parameter = tool_parameters_dictionary.get(parameter_name) - if not parameter: - result[parameter_name] = None - continue - tool_input = node_data.tool_parameters[parameter_name] - if tool_input.type == "variable": - variable = variable_pool.get(tool_input.value) - if variable is None: - if parameter.required: - raise ToolParameterError(f"Variable {tool_input.value} does not exist") - continue - parameter_value = variable.value - elif tool_input.type in {"mixed", "constant"}: - segment_group = variable_pool.convert_template(str(tool_input.value)) - parameter_value = segment_group.log if for_log else segment_group.text - else: - raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'") - result[parameter_name] = parameter_value - - return result - - def _transform_message( - self, - messages: Generator[ToolRuntimeMessage, None, None], - tool_info: Mapping[str, Any], - parameters_for_log: dict[str, Any], - node_id: str, - tool_runtime: ToolRuntimeHandle, - **_: Any, - ) -> Generator[NodeEventBase, None, LLMUsage]: - """ - Convert graph-owned tool runtime messages into node outputs. - """ - text = "" - files: list[File] = [] - json: list[dict | list] = [] - - variables: dict[str, Any] = {} - - for message in messages: - if message.type in { - ToolRuntimeMessage.MessageType.IMAGE_LINK, - ToolRuntimeMessage.MessageType.BINARY_LINK, - ToolRuntimeMessage.MessageType.IMAGE, - }: - assert isinstance(message.message, ToolRuntimeMessage.TextMessage) - - url = message.message.text - if message.meta: - transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) - tool_file_id = message.meta.get("tool_file_id") - else: - transfer_method = FileTransferMethod.TOOL_FILE - tool_file_id = None - if not isinstance(tool_file_id, str) or not tool_file_id: - raise ToolFileError("tool message is missing tool_file_id metadata") - - _stream, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) - if not tool_file: - raise ToolFileError(f"tool file {tool_file_id} not found") - if tool_file.mime_type is None: - raise ToolFileError(f"tool file {tool_file_id} is missing mime type") - - file_mapping: dict[str, Any] = { - "tool_file_id": tool_file_id, - "type": get_file_type_by_mime_type(tool_file.mime_type), - "transfer_method": transfer_method, - "url": url, - } - file = self._runtime.build_file_reference(mapping=file_mapping) - files.append(file) - elif message.type == ToolRuntimeMessage.MessageType.BLOB: - # get tool file id - assert isinstance(message.message, ToolRuntimeMessage.TextMessage) - assert message.meta - - tool_file_id = message.meta.get("tool_file_id") - if not isinstance(tool_file_id, str) or not tool_file_id: - raise ToolFileError("tool blob message is missing tool_file_id metadata") - _stream, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) - if not tool_file: - raise ToolFileError(f"tool file {tool_file_id} not exists") - - blob_file_mapping: dict[str, Any] = { - "tool_file_id": tool_file_id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - - files.append(self._runtime.build_file_reference(mapping=blob_file_mapping)) - elif message.type == ToolRuntimeMessage.MessageType.TEXT: - assert isinstance(message.message, ToolRuntimeMessage.TextMessage) - text += message.message.text - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk=message.message.text, - is_final=False, - ) - elif message.type == ToolRuntimeMessage.MessageType.JSON: - assert isinstance(message.message, ToolRuntimeMessage.JsonMessage) - # JSON message handling for tool node - if message.message.json_object: - json.append(message.message.json_object) - elif message.type == ToolRuntimeMessage.MessageType.LINK: - assert isinstance(message.message, ToolRuntimeMessage.TextMessage) - - # Check if this LINK message is a file link - file_obj = (message.meta or {}).get("file") - if isinstance(file_obj, File): - files.append(file_obj) - stream_text = f"File: {message.message.text}\n" - else: - stream_text = f"Link: {message.message.text}\n" - - text += stream_text - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk=stream_text, - is_final=False, - ) - elif message.type == ToolRuntimeMessage.MessageType.VARIABLE: - assert isinstance(message.message, ToolRuntimeMessage.VariableMessage) - variable_name = message.message.variable_name - variable_value = message.message.variable_value - if message.message.stream: - if not isinstance(variable_value, str): - raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.") - if variable_name not in variables: - variables[variable_name] = "" - variables[variable_name] += variable_value - - yield StreamChunkEvent( - selector=[node_id, variable_name], - chunk=variable_value, - is_final=False, - ) - else: - variables[variable_name] = variable_value - elif message.type == ToolRuntimeMessage.MessageType.FILE: - assert message.meta is not None - assert isinstance(message.meta, dict) - # Validate that meta contains a 'file' key - if "file" not in message.meta: - raise ToolNodeError("File message is missing 'file' key in meta") - - # Validate that the file is an instance of File - if not isinstance(message.meta["file"], File): - raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") - files.append(message.meta["file"]) - elif message.type == ToolRuntimeMessage.MessageType.LOG: - assert isinstance(message.message, ToolRuntimeMessage.LogMessage) - if message.message.metadata: - icon = tool_info.get("icon", "") - dict_metadata = dict(message.message.metadata) - if dict_metadata.get("provider"): - icon, icon_dark = self._runtime.resolve_provider_icons( - provider_name=dict_metadata["provider"], - default_icon=icon, - ) - dict_metadata["icon"] = icon - dict_metadata["icon_dark"] = icon_dark - message.message.metadata = dict_metadata - - # Add agent_logs to outputs['json'] to ensure frontend can access thinking process - json_output: list[dict[str, Any] | list[Any]] = [] - - # Step 2: normalize JSON into {"data": [...]}.change json to list[dict] - if json: - json_output.extend(json) - else: - json_output.append({"data": []}) - - # Send final chunk events for all streamed outputs - # Final chunk for text stream - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk="", - is_final=True, - ) - - # Final chunks for any streamed variables - for var_name in variables: - yield StreamChunkEvent( - selector=[self._node_id, var_name], - chunk="", - is_final=True, - ) - - usage = self._runtime.get_usage(tool_runtime=tool_runtime) - - metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { - WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, - } - if isinstance(usage.total_tokens, int) and usage.total_tokens > 0: - metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens - metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price - metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, - metadata=metadata, - inputs=parameters_for_log, - llm_usage=usage, - ) - ) - - return usage - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: ToolNodeData, - ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - _ = graph_config # Explicitly mark as unused - typed_node_data = node_data - result = {} - for parameter_name in typed_node_data.tool_parameters: - input = typed_node_data.tool_parameters[parameter_name] - match input.type: - case "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - case "variable": - selector_key = ".".join(input.value) - result[f"#{selector_key}#"] = input.value - case "constant": - pass - - result = {node_id + "." + key: value for key, value in result.items()} - - return result - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled diff --git a/api/graphon/nodes/tool_runtime_entities.py b/api/graphon/nodes/tool_runtime_entities.py deleted file mode 100644 index 5bb0c16573..0000000000 --- a/api/graphon/nodes/tool_runtime_entities.py +++ /dev/null @@ -1,105 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from enum import StrEnum, auto -from typing import Any - -from pydantic import BaseModel, ConfigDict, Field - - -class _ToolRuntimeModel(BaseModel): - model_config = ConfigDict(extra="forbid") - - -@dataclass(frozen=True, slots=True) -class ToolRuntimeHandle: - """Opaque graph-owned handle for a workflow-layer tool runtime. - - Workflow-specific execution context must stay behind `raw` so the graph - contract does not absorb application-owned concepts. - """ - - raw: object - - -@dataclass(frozen=True, slots=True) -class ToolRuntimeParameter: - """Graph-owned parameter shape used by tool nodes.""" - - name: str - required: bool = False - - -class ToolRuntimeMessage(_ToolRuntimeModel): - """Graph-owned tool invocation message DTO.""" - - class TextMessage(_ToolRuntimeModel): - text: str - - class JsonMessage(_ToolRuntimeModel): - json_object: dict[str, Any] | list[Any] - suppress_output: bool = Field(default=False) - - class BlobMessage(_ToolRuntimeModel): - blob: bytes - - class BlobChunkMessage(_ToolRuntimeModel): - id: str - sequence: int - total_length: int - blob: bytes - end: bool - - class FileMessage(_ToolRuntimeModel): - file_marker: str = Field(default="file_marker") - - class VariableMessage(_ToolRuntimeModel): - variable_name: str - variable_value: dict[str, Any] | list[Any] | str | int | float | bool | None - stream: bool = Field(default=False) - - class LogMessage(_ToolRuntimeModel): - class LogStatus(StrEnum): - START = auto() - ERROR = auto() - SUCCESS = auto() - - id: str - label: str - parent_id: str | None = None - error: str | None = None - status: LogStatus - data: dict[str, Any] - metadata: dict[str, Any] = Field(default_factory=dict) - - class RetrieverResourceMessage(_ToolRuntimeModel): - retriever_resources: list[dict[str, Any]] - context: str - - class MessageType(StrEnum): - TEXT = auto() - IMAGE = auto() - LINK = auto() - BLOB = auto() - JSON = auto() - IMAGE_LINK = auto() - BINARY_LINK = auto() - VARIABLE = auto() - FILE = auto() - LOG = auto() - BLOB_CHUNK = auto() - RETRIEVER_RESOURCES = auto() - - type: MessageType = MessageType.TEXT - message: ( - JsonMessage - | TextMessage - | BlobChunkMessage - | BlobMessage - | LogMessage - | FileMessage - | None - | VariableMessage - | RetrieverResourceMessage - ) - meta: dict[str, Any] | None = None diff --git a/api/graphon/nodes/variable_aggregator/__init__.py b/api/graphon/nodes/variable_aggregator/__init__.py deleted file mode 100644 index 0b6bf2a5b6..0000000000 --- a/api/graphon/nodes/variable_aggregator/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .variable_aggregator_node import VariableAggregatorNode - -__all__ = ["VariableAggregatorNode"] diff --git a/api/graphon/nodes/variable_aggregator/entities.py b/api/graphon/nodes/variable_aggregator/entities.py deleted file mode 100644 index 136fd28f8c..0000000000 --- a/api/graphon/nodes/variable_aggregator/entities.py +++ /dev/null @@ -1,35 +0,0 @@ -from pydantic import BaseModel - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.variables.types import SegmentType - - -class AdvancedSettings(BaseModel): - """ - Advanced setting. - """ - - group_enabled: bool - - class Group(BaseModel): - """ - Group. - """ - - output_type: SegmentType - variables: list[list[str]] - group_name: str - - groups: list[Group] - - -class VariableAggregatorNodeData(BaseNodeData): - """ - Variable Aggregator Node Data. - """ - - type: NodeType = BuiltinNodeTypes.VARIABLE_AGGREGATOR - output_type: str - variables: list[list[str]] - advanced_settings: AdvancedSettings | None = None diff --git a/api/graphon/nodes/variable_aggregator/variable_aggregator_node.py b/api/graphon/nodes/variable_aggregator/variable_aggregator_node.py deleted file mode 100644 index 71b221e196..0000000000 --- a/api/graphon/nodes/variable_aggregator/variable_aggregator_node.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping - -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.variable_aggregator.entities import VariableAggregatorNodeData -from graphon.variables.segments import Segment - - -class VariableAggregatorNode(Node[VariableAggregatorNodeData]): - node_type = BuiltinNodeTypes.VARIABLE_AGGREGATOR - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - # Get variables - outputs: dict[str, Segment | Mapping[str, Segment]] = {} - inputs = {} - - if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: - for selector in self.node_data.variables: - variable = self.graph_runtime_state.variable_pool.get(selector) - if variable is not None: - outputs = {"output": variable} - - inputs = {".".join(selector[1:]): variable.to_object()} - break - else: - for group in self.node_data.advanced_settings.groups: - for selector in group.variables: - variable = self.graph_runtime_state.variable_pool.get(selector) - - if variable is not None: - outputs[group.group_name] = {"output": variable} - inputs[".".join(selector[1:])] = variable.to_object() - break - - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs) diff --git a/api/graphon/nodes/variable_assigner/__init__.py b/api/graphon/nodes/variable_assigner/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/nodes/variable_assigner/common/__init__.py b/api/graphon/nodes/variable_assigner/common/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/nodes/variable_assigner/common/exc.py b/api/graphon/nodes/variable_assigner/common/exc.py deleted file mode 100644 index f8dbedc290..0000000000 --- a/api/graphon/nodes/variable_assigner/common/exc.py +++ /dev/null @@ -1,4 +0,0 @@ -class VariableOperatorNodeError(ValueError): - """Base error type, don't use directly.""" - - pass diff --git a/api/graphon/nodes/variable_assigner/common/helpers.py b/api/graphon/nodes/variable_assigner/common/helpers.py deleted file mode 100644 index 4c30e009f2..0000000000 --- a/api/graphon/nodes/variable_assigner/common/helpers.py +++ /dev/null @@ -1,55 +0,0 @@ -from collections.abc import Mapping, MutableMapping, Sequence -from typing import Any, TypeVar - -from pydantic import BaseModel - -from graphon.variables import Segment -from graphon.variables.consts import SELECTORS_LENGTH -from graphon.variables.types import SegmentType - -# Use double underscore (`__`) prefix for internal variables -# to minimize risk of collision with user-defined variable names. -_UPDATED_VARIABLES_KEY = "__updated_variables" - - -class UpdatedVariable(BaseModel): - name: str - selector: Sequence[str] - value_type: SegmentType - new_value: Any = None - - -_T = TypeVar("_T", bound=MutableMapping[str, Any]) - - -def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable: - if len(selector) < SELECTORS_LENGTH: - raise Exception("selector too short") - _, var_name = selector[:2] - return UpdatedVariable( - name=var_name, - selector=list(selector[:2]), - value_type=seg.value_type, - new_value=seg.value, - ) - - -def set_updated_variables(m: _T, updates: Sequence[UpdatedVariable]) -> _T: - m[_UPDATED_VARIABLES_KEY] = updates - return m - - -def get_updated_variables(m: Mapping[str, Any]) -> Sequence[UpdatedVariable] | None: - updated_values = m.get(_UPDATED_VARIABLES_KEY, None) - if updated_values is None: - return None - result = [] - for items in updated_values: - if isinstance(items, UpdatedVariable): - result.append(items) - elif isinstance(items, dict): - items = UpdatedVariable.model_validate(items) - result.append(items) - else: - raise TypeError(f"Invalid updated variable: {items}, type={type(items)}") - return result diff --git a/api/graphon/nodes/variable_assigner/v1/__init__.py b/api/graphon/nodes/variable_assigner/v1/__init__.py deleted file mode 100644 index 7eb1428e50..0000000000 --- a/api/graphon/nodes/variable_assigner/v1/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .node import VariableAssignerNode - -__all__ = ["VariableAssignerNode"] diff --git a/api/graphon/nodes/variable_assigner/v1/node.py b/api/graphon/nodes/variable_assigner/v1/node.py deleted file mode 100644 index 19ded5f123..0000000000 --- a/api/graphon/nodes/variable_assigner/v1/node.py +++ /dev/null @@ -1,106 +0,0 @@ -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, cast - -from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent, VariableUpdatedEvent -from graphon.nodes.base.node import Node -from graphon.nodes.variable_assigner.common import helpers as common_helpers -from graphon.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from graphon.variables import SegmentType, Variable, VariableBase - -from .node_data import VariableAssignerData, WriteMode - -if TYPE_CHECKING: - from graphon.runtime import GraphRuntimeState - - -class VariableAssignerNode(Node[VariableAssignerData]): - node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: - """ - Check if this Variable Assigner node blocks the output of specific variables. - - Returns True if this node updates any of the requested conversation variables. - """ - assigned_selector = tuple(self.node_data.assigned_variable_selector) - return assigned_selector in variable_selectors - - @classmethod - def version(cls) -> str: - return "1" - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: VariableAssignerData, - ) -> Mapping[str, Sequence[str]]: - mapping = {} - selector_key = ".".join(node_data.assigned_variable_selector) - key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.assigned_variable_selector - - selector_key = ".".join(node_data.input_variable_selector) - key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.input_variable_selector - return mapping - - def _run(self) -> Generator[NodeEventBase, None, None]: - assigned_variable_selector = self.node_data.assigned_variable_selector - # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) - if not isinstance(original_variable, VariableBase): - raise VariableOperatorNodeError("assigned variable not found") - - match self.node_data.write_mode: - case WriteMode.OVER_WRITE: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) - if not income_value: - raise VariableOperatorNodeError("input value not found") - updated_variable = original_variable.model_copy(update={"value": income_value.value}) - - case WriteMode.APPEND: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) - if not income_value: - raise VariableOperatorNodeError("input value not found") - updated_value = original_variable.value + [income_value.value] - updated_variable = original_variable.model_copy(update={"value": updated_value}) - - case WriteMode.CLEAR: - income_value = SegmentType.get_zero_value(original_variable.value_type) - updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) - - updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] - yield VariableUpdatedEvent(variable=cast(Variable, updated_variable)) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={ - "value": income_value.to_object(), - }, - # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, - # we still set `output_variables` as a list to ensure the schema of output is - # compatible with `v2.VariableAssignerNode`. - process_data=common_helpers.set_updated_variables({}, updated_variables), - outputs={}, - ) - ) diff --git a/api/graphon/nodes/variable_assigner/v1/node_data.py b/api/graphon/nodes/variable_assigner/v1/node_data.py deleted file mode 100644 index 4f630bc76c..0000000000 --- a/api/graphon/nodes/variable_assigner/v1/node_data.py +++ /dev/null @@ -1,18 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - - -class WriteMode(StrEnum): - OVER_WRITE = "over-write" - APPEND = "append" - CLEAR = "clear" - - -class VariableAssignerData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.VARIABLE_ASSIGNER - assigned_variable_selector: Sequence[str] - write_mode: WriteMode - input_variable_selector: Sequence[str] diff --git a/api/graphon/nodes/variable_assigner/v2/__init__.py b/api/graphon/nodes/variable_assigner/v2/__init__.py deleted file mode 100644 index 7eb1428e50..0000000000 --- a/api/graphon/nodes/variable_assigner/v2/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .node import VariableAssignerNode - -__all__ = ["VariableAssignerNode"] diff --git a/api/graphon/nodes/variable_assigner/v2/entities.py b/api/graphon/nodes/variable_assigner/v2/entities.py deleted file mode 100644 index d1c68c8e8c..0000000000 --- a/api/graphon/nodes/variable_assigner/v2/entities.py +++ /dev/null @@ -1,28 +0,0 @@ -from collections.abc import Sequence -from typing import Any - -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - -from .enums import InputType, Operation - - -class VariableOperationItem(BaseModel): - variable_selector: Sequence[str] - input_type: InputType - operation: Operation - # NOTE(QuantumGhost): The `value` field serves multiple purposes depending on context: - # - # 1. For CONSTANT input_type: Contains the literal value to be used in the operation. - # 2. For VARIABLE input_type: Initially contains the selector of the source variable. - # 3. During the variable updating procedure: The `value` field is reassigned to hold - # the resolved actual value that will be applied to the target variable. - value: Any = None - - -class VariableAssignerNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.VARIABLE_ASSIGNER - version: str = "2" - items: Sequence[VariableOperationItem] = Field(default_factory=list) diff --git a/api/graphon/nodes/variable_assigner/v2/enums.py b/api/graphon/nodes/variable_assigner/v2/enums.py deleted file mode 100644 index 291b1208d4..0000000000 --- a/api/graphon/nodes/variable_assigner/v2/enums.py +++ /dev/null @@ -1,20 +0,0 @@ -from enum import StrEnum - - -class Operation(StrEnum): - OVER_WRITE = "over-write" - CLEAR = "clear" - APPEND = "append" - EXTEND = "extend" - SET = "set" - ADD = "+=" - SUBTRACT = "-=" - MULTIPLY = "*=" - DIVIDE = "/=" - REMOVE_FIRST = "remove-first" - REMOVE_LAST = "remove-last" - - -class InputType(StrEnum): - VARIABLE = "variable" - CONSTANT = "constant" diff --git a/api/graphon/nodes/variable_assigner/v2/exc.py b/api/graphon/nodes/variable_assigner/v2/exc.py deleted file mode 100644 index 90d7648574..0000000000 --- a/api/graphon/nodes/variable_assigner/v2/exc.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Sequence -from typing import Any - -from graphon.nodes.variable_assigner.common.exc import VariableOperatorNodeError - -from .enums import InputType, Operation - - -class OperationNotSupportedError(VariableOperatorNodeError): - def __init__(self, *, operation: Operation, variable_type: str): - super().__init__(f"Operation {operation} is not supported for type {variable_type}") - - -class InputTypeNotSupportedError(VariableOperatorNodeError): - def __init__(self, *, input_type: InputType, operation: Operation): - super().__init__(f"Input type {input_type} is not supported for operation {operation}") - - -class VariableNotFoundError(VariableOperatorNodeError): - def __init__(self, *, variable_selector: Sequence[str]): - super().__init__(f"Variable {variable_selector} not found") - - -class InvalidInputValueError(VariableOperatorNodeError): - def __init__(self, *, value: Any): - super().__init__(f"Invalid input value {value}") - - -class ConversationIDNotFoundError(VariableOperatorNodeError): - def __init__(self): - super().__init__("conversation_id not found") - - -class InvalidDataError(VariableOperatorNodeError): - def __init__(self, message: str): - super().__init__(message) diff --git a/api/graphon/nodes/variable_assigner/v2/helpers.py b/api/graphon/nodes/variable_assigner/v2/helpers.py deleted file mode 100644 index ebc6c79476..0000000000 --- a/api/graphon/nodes/variable_assigner/v2/helpers.py +++ /dev/null @@ -1,98 +0,0 @@ -from typing import Any - -from graphon.variables import SegmentType - -from .enums import Operation - - -def is_operation_supported(*, variable_type: SegmentType, operation: Operation): - match operation: - case Operation.OVER_WRITE | Operation.CLEAR: - return True - case Operation.SET: - return variable_type in { - SegmentType.OBJECT, - SegmentType.STRING, - SegmentType.NUMBER, - SegmentType.INTEGER, - SegmentType.FLOAT, - SegmentType.BOOLEAN, - } - case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE: - # Only number variable can be added, subtracted, multiplied or divided - return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT} - case Operation.APPEND | Operation.EXTEND | Operation.REMOVE_FIRST | Operation.REMOVE_LAST: - # Only array variable can be appended or extended - # Only array variable can have elements removed - return variable_type.is_array_type() - - -def is_variable_input_supported(*, operation: Operation): - if operation in {Operation.SET, Operation.ADD, Operation.SUBTRACT, Operation.MULTIPLY, Operation.DIVIDE}: - return False - return True - - -def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation): - match variable_type: - case SegmentType.STRING | SegmentType.OBJECT | SegmentType.BOOLEAN: - return operation in {Operation.OVER_WRITE, Operation.SET} - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - return operation in { - Operation.OVER_WRITE, - Operation.SET, - Operation.ADD, - Operation.SUBTRACT, - Operation.MULTIPLY, - Operation.DIVIDE, - } - case _: - return False - - -def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any): - if operation in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST}: - return True - match variable_type: - case SegmentType.STRING: - return isinstance(value, str) - - case SegmentType.BOOLEAN: - return isinstance(value, bool) - - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - if not isinstance(value, int | float): - return False - if operation == Operation.DIVIDE and value == 0: - return False - return True - - case SegmentType.OBJECT: - return isinstance(value, dict) - - # Array & Append - case SegmentType.ARRAY_ANY if operation == Operation.APPEND: - return isinstance(value, str | float | int | dict) - case SegmentType.ARRAY_STRING if operation == Operation.APPEND: - return isinstance(value, str) - case SegmentType.ARRAY_NUMBER if operation == Operation.APPEND: - return isinstance(value, int | float) - case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND: - return isinstance(value, dict) - case SegmentType.ARRAY_BOOLEAN if operation == Operation.APPEND: - return isinstance(value, bool) - - # Array & Extend / Overwrite - case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, str | float | int | dict) for item in value) - case SegmentType.ARRAY_STRING if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, str) for item in value) - case SegmentType.ARRAY_NUMBER if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, int | float) for item in value) - case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, dict) for item in value) - case SegmentType.ARRAY_BOOLEAN if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, bool) for item in value) - - case _: - return False diff --git a/api/graphon/nodes/variable_assigner/v2/node.py b/api/graphon/nodes/variable_assigner/v2/node.py deleted file mode 100644 index 887bd1b604..0000000000 --- a/api/graphon/nodes/variable_assigner/v2/node.py +++ /dev/null @@ -1,257 +0,0 @@ -import json -from collections.abc import Generator, Mapping, MutableMapping, Sequence -from typing import TYPE_CHECKING, Any, cast - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent, VariableUpdatedEvent -from graphon.nodes.base.node import Node -from graphon.nodes.variable_assigner.common import helpers as common_helpers -from graphon.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from graphon.variables import SegmentType, Variable, VariableBase -from graphon.variables.consts import SELECTORS_LENGTH - -from . import helpers -from .entities import VariableAssignerNodeData, VariableOperationItem -from .enums import InputType, Operation -from .exc import ( - InputTypeNotSupportedError, - InvalidDataError, - InvalidInputValueError, - OperationNotSupportedError, - VariableNotFoundError, -) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - - -def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): - selector_str = ".".join(item.variable_selector) - key = f"{node_id}.#{selector_str}#" - mapping[key] = item.variable_selector - - -def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): - # Keep this in sync with the logic in _run methods... - if item.input_type != InputType.VARIABLE: - return - selector = item.value - if not isinstance(selector, list): - raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}") - if len(selector) < SELECTORS_LENGTH: - raise InvalidDataError(f"selector too short, {node_id=}, {item=}") - selector_str = ".".join(selector) - key = f"{node_id}.#{selector_str}#" - mapping[key] = selector - - -class VariableAssignerNode(Node[VariableAssignerNodeData]): - node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: - """ - Check if this Variable Assigner node blocks the output of specific variables. - - Returns True if this node updates any of the requested conversation variables. - """ - # Check each item in this Variable Assigner node - for item in self.node_data.items: - # Convert the item's variable_selector to tuple for comparison - item_selector_tuple = tuple(item.variable_selector) - - # Check if this item updates any of the requested variables - if item_selector_tuple in variable_selectors: - return True - - return False - - @classmethod - def version(cls) -> str: - return "2" - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: VariableAssignerNodeData, - ) -> Mapping[str, Sequence[str]]: - var_mapping: dict[str, Sequence[str]] = {} - for item in node_data.items: - _target_mapping_from_item(var_mapping, node_id, item) - _source_mapping_from_item(var_mapping, node_id, item) - return var_mapping - - def _run(self) -> Generator[NodeEventBase, None, None]: - inputs = self.node_data.model_dump() - process_data: dict[str, Any] = {} - # NOTE: This node has no outputs - updated_variable_selectors: list[Sequence[str]] = [] - # Preserve intra-node read-after-write behavior without mutating the shared pool - # until the engine processes the emitted VariableUpdatedEvent instances. - working_variable_pool = self.graph_runtime_state.variable_pool.model_copy(deep=True) - - try: - for item in self.node_data.items: - variable = working_variable_pool.get(item.variable_selector) - - # ==================== Validation Part - - # Check if variable exists - if not isinstance(variable, VariableBase): - raise VariableNotFoundError(variable_selector=item.variable_selector) - - # Check if operation is supported - if not helpers.is_operation_supported(variable_type=variable.value_type, operation=item.operation): - raise OperationNotSupportedError(operation=item.operation, variable_type=variable.value_type) - - # Check if variable input is supported - if item.input_type == InputType.VARIABLE and not helpers.is_variable_input_supported( - operation=item.operation - ): - raise InputTypeNotSupportedError(input_type=InputType.VARIABLE, operation=item.operation) - - # Check if constant input is supported - if item.input_type == InputType.CONSTANT and not helpers.is_constant_input_supported( - variable_type=variable.value_type, operation=item.operation - ): - raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation) - - # Get value from variable pool - input_value = item.value - if ( - item.input_type == InputType.VARIABLE - and item.operation not in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST} - and item.value is not None - ): - value = working_variable_pool.get(item.value) - if value is None: - raise VariableNotFoundError(variable_selector=item.value) - # Skip if value is NoneSegment - if value.value_type == SegmentType.NONE: - continue - input_value = value.value - - # If set string / bytes / bytearray to object, try convert string to object. - if ( - item.operation == Operation.SET - and variable.value_type == SegmentType.OBJECT - and isinstance(input_value, str | bytes | bytearray) - ): - try: - input_value = json.loads(input_value) - except json.JSONDecodeError: - raise InvalidInputValueError(value=input_value) - - # Check if input value is valid - if not helpers.is_input_value_valid( - variable_type=variable.value_type, operation=item.operation, value=input_value - ): - raise InvalidInputValueError(value=input_value) - - # ==================== Execution Part - - updated_value = self._handle_item( - variable=variable, - operation=item.operation, - value=input_value, - ) - updated_variable = variable.model_copy(update={"value": updated_value}) - working_variable_pool.add(updated_variable.selector, updated_variable) - updated_variable_selectors.append(updated_variable.selector) - except VariableOperatorNodeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - error=str(e), - ) - ) - return - - # The `updated_variable_selectors` is a list contains list[str] which not hashable, - # remove duplicated items while preserving the first update order. - updated_variable_selectors = list(dict.fromkeys(map(tuple, updated_variable_selectors))) - - for selector in updated_variable_selectors: - variable = working_variable_pool.get(selector) - if not isinstance(variable, VariableBase): - raise VariableNotFoundError(variable_selector=selector) - process_data[variable.name] = variable.value - - updated_variables = [ - common_helpers.variable_to_processed_data(selector, seg) - for selector in updated_variable_selectors - if (seg := working_variable_pool.get(selector)) is not None - ] - - process_data = common_helpers.set_updated_variables(process_data, updated_variables) - for selector in updated_variable_selectors: - variable = working_variable_pool.get(selector) - if not isinstance(variable, VariableBase): - raise VariableNotFoundError(variable_selector=selector) - yield VariableUpdatedEvent(variable=cast(Variable, variable)) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={}, - ) - ) - - def _handle_item( - self, - *, - variable: VariableBase, - operation: Operation, - value: Any, - ): - match operation: - case Operation.OVER_WRITE: - return value - case Operation.CLEAR: - return SegmentType.get_zero_value(variable.value_type).to_object() - case Operation.APPEND: - return variable.value + [value] - case Operation.EXTEND: - return variable.value + value - case Operation.SET: - return value - case Operation.ADD: - return variable.value + value - case Operation.SUBTRACT: - return variable.value - value - case Operation.MULTIPLY: - return variable.value * value - case Operation.DIVIDE: - return variable.value / value - case Operation.REMOVE_FIRST: - # If array is empty, do nothing - if not variable.value: - return variable.value - return variable.value[1:] - case Operation.REMOVE_LAST: - # If array is empty, do nothing - if not variable.value: - return variable.value - return variable.value[:-1] diff --git a/api/graphon/prompt_entities.py b/api/graphon/prompt_entities.py deleted file mode 100644 index 2b8b106c6c..0000000000 --- a/api/graphon/prompt_entities.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel - -from graphon.model_runtime.entities.message_entities import PromptMessageRole - - -class ChatModelMessage(BaseModel): - """Graph-owned chat prompt template message.""" - - text: str - role: PromptMessageRole - edition_type: Literal["basic", "jinja2"] | None = None - - -class CompletionModelPromptTemplate(BaseModel): - """Graph-owned completion prompt template.""" - - text: str - edition_type: Literal["basic", "jinja2"] | None = None - - -class MemoryConfig(BaseModel): - """Graph-owned memory configuration for prompt assembly.""" - - class RolePrefix(BaseModel): - """Role labels used when serializing completion-model histories.""" - - user: str - assistant: str - - class WindowConfig(BaseModel): - """History windowing controls.""" - - enabled: bool - size: int | None = None - - role_prefix: RolePrefix | None = None - window: WindowConfig - query_prompt_template: str | None = None - - -__all__ = [ - "ChatModelMessage", - "CompletionModelPromptTemplate", - "MemoryConfig", -] diff --git a/api/graphon/runtime/__init__.py b/api/graphon/runtime/__init__.py deleted file mode 100644 index adca07e59a..0000000000 --- a/api/graphon/runtime/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from .graph_runtime_state import ( - ChildEngineBuilderNotConfiguredError, - ChildEngineError, - ChildGraphNotFoundError, - GraphRuntimeState, -) -from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool -from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper -from .variable_pool import VariablePool, VariableValue - -__all__ = [ - "ChildEngineBuilderNotConfiguredError", - "ChildEngineError", - "ChildGraphNotFoundError", - "GraphRuntimeState", - "ReadOnlyGraphRuntimeState", - "ReadOnlyGraphRuntimeStateWrapper", - "ReadOnlyVariablePool", - "ReadOnlyVariablePoolWrapper", - "VariablePool", - "VariableValue", -] diff --git a/api/graphon/runtime/graph_runtime_state.py b/api/graphon/runtime/graph_runtime_state.py deleted file mode 100644 index 8453830f28..0000000000 --- a/api/graphon/runtime/graph_runtime_state.py +++ /dev/null @@ -1,704 +0,0 @@ -from __future__ import annotations - -import importlib -import json -from collections.abc import Mapping, Sequence -from contextlib import AbstractContextManager, nullcontext -from copy import deepcopy -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Protocol - -from pydantic import BaseModel, Field -from pydantic.json import pydantic_encoder - -from graphon.enums import NodeExecutionType, NodeState, NodeType -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime.variable_pool import VariablePool - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.entities.pause_reason import PauseReason - - -class ReadyQueueProtocol(Protocol): - """Structural interface required from ready queue implementations.""" - - def put(self, item: str) -> None: - """Enqueue the identifier of a node that is ready to run.""" - ... - - def get(self, timeout: float | None = None) -> str: - """Return the next node identifier, blocking until available or timeout expires.""" - ... - - def task_done(self) -> None: - """Signal that the most recently dequeued node has completed processing.""" - ... - - def empty(self) -> bool: - """Return True when the queue contains no pending nodes.""" - ... - - def qsize(self) -> int: - """Approximate the number of pending nodes awaiting execution.""" - ... - - def dumps(self) -> str: - """Serialize the queue contents for persistence.""" - ... - - def loads(self, data: str) -> None: - """Restore the queue contents from a serialized payload.""" - ... - - -class NodeExecutionProtocol(Protocol): - """Structural interface for persisted per-node execution state.""" - - execution_id: str | None - - -class GraphExecutionProtocol(Protocol): - """Structural interface for graph execution aggregate. - - Defines the minimal set of attributes and methods required from a GraphExecution entity - for runtime orchestration and state management. - """ - - workflow_id: str - started: bool - completed: bool - aborted: bool - error: Exception | None - exceptions_count: int - pause_reasons: list[PauseReason] - - @property - def node_executions(self) -> Mapping[str, NodeExecutionProtocol]: - """Return the persisted node execution state keyed by node id.""" - ... - - def start(self) -> None: - """Transition execution into the running state.""" - ... - - def complete(self) -> None: - """Mark execution as successfully completed.""" - ... - - def abort(self, reason: str) -> None: - """Abort execution in response to an external stop request.""" - ... - - def fail(self, error: Exception) -> None: - """Record an unrecoverable error and end execution.""" - ... - - def dumps(self) -> str: - """Serialize execution state into a JSON payload.""" - ... - - def loads(self, data: str) -> None: - """Restore execution state from a previously serialized payload.""" - ... - - -class ResponseStreamCoordinatorProtocol(Protocol): - """Structural interface for response stream coordinator.""" - - def register(self, response_node_id: str) -> None: - """Register a response node so its outputs can be streamed.""" - ... - - def loads(self, data: str) -> None: - """Restore coordinator state from a serialized payload.""" - ... - - def dumps(self) -> str: - """Serialize coordinator state for persistence.""" - ... - - -class NodeProtocol(Protocol): - """Structural interface for graph nodes.""" - - id: str - state: NodeState - execution_type: NodeExecutionType - node_type: ClassVar[NodeType] - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: ... - - -class EdgeProtocol(Protocol): - id: str - state: NodeState - tail: str - head: str - source_handle: str - - -class GraphProtocol(Protocol): - """Structural interface required from graph instances attached to the runtime state.""" - - nodes: Mapping[str, NodeProtocol] - edges: Mapping[str, EdgeProtocol] - root_node: NodeProtocol - - def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ... - - -class ChildGraphEngineBuilderProtocol(Protocol): - def build_child_engine( - self, - *, - workflow_id: str, - graph_init_params: GraphInitParams, - parent_graph_runtime_state: GraphRuntimeState, - root_node_id: str, - variable_pool: VariablePool | None = None, - ) -> Any: ... - - -class ChildEngineError(ValueError): - """Base error type for child-engine creation failures.""" - - -class ChildEngineBuilderNotConfiguredError(ChildEngineError): - """Raised when child-engine creation is requested without a bound builder.""" - - -class ChildGraphNotFoundError(ChildEngineError): - """Raised when the requested child graph entry point cannot be resolved.""" - - -class _GraphStateSnapshot(BaseModel): - """Serializable graph state snapshot for node/edge states.""" - - nodes: dict[str, NodeState] = Field(default_factory=dict) - edges: dict[str, NodeState] = Field(default_factory=dict) - - -@dataclass(slots=True) -class _GraphRuntimeStateSnapshot: - """Immutable view of a serialized runtime state snapshot.""" - - start_at: float - total_tokens: int - node_run_steps: int - llm_usage: LLMUsage - outputs: dict[str, Any] - variable_pool: VariablePool - has_variable_pool: bool - ready_queue_dump: str | None - graph_execution_dump: str | None - response_coordinator_dump: str | None - paused_nodes: tuple[str, ...] - deferred_nodes: tuple[str, ...] - graph_node_states: dict[str, NodeState] - graph_edge_states: dict[str, NodeState] - - -class GraphRuntimeState: - """Mutable runtime state shared across graph execution components. - - `GraphRuntimeState` encapsulates the runtime state of workflow execution, - including scheduling details, variable values, and timing information. - - Values that are initialized prior to workflow execution and remain constant - throughout the execution should be part of `GraphInitParams` instead. - """ - - def __init__( - self, - *, - variable_pool: VariablePool, - start_at: float, - total_tokens: int = 0, - llm_usage: LLMUsage | None = None, - outputs: dict[str, object] | None = None, - node_run_steps: int = 0, - ready_queue: ReadyQueueProtocol | None = None, - graph_execution: GraphExecutionProtocol | None = None, - response_coordinator: ResponseStreamCoordinatorProtocol | None = None, - graph: GraphProtocol | None = None, - execution_context: AbstractContextManager[object] | None = None, - ) -> None: - self._variable_pool = variable_pool - self._start_at = start_at - - if total_tokens < 0: - raise ValueError("total_tokens must be non-negative") - self._total_tokens = total_tokens - - self._llm_usage = (llm_usage or LLMUsage.empty_usage()).model_copy() - self._outputs = deepcopy(outputs) if outputs is not None else {} - - if node_run_steps < 0: - raise ValueError("node_run_steps must be non-negative") - self._node_run_steps = node_run_steps - - self._graph: GraphProtocol | None = None - - self._ready_queue = ready_queue - self._graph_execution = graph_execution - self._response_coordinator = response_coordinator - # Application code injects this when worker threads must restore request - # or framework-local state. It is intentionally excluded from snapshots. - self._execution_context = execution_context if execution_context is not None else nullcontext(None) - self._pending_response_coordinator_dump: str | None = None - self._pending_graph_execution_workflow_id: str | None = None - self._paused_nodes: set[str] = set() - self._deferred_nodes: set[str] = set() - self._child_engine_builder: ChildGraphEngineBuilderProtocol | None = None - - # Node and edges states needed to be restored into - # graph object. - # - # These two fields are non-None only when resuming from a snapshot. - # Once the graph is attached, these two fields will be set to None. - self._pending_graph_node_states: dict[str, NodeState] | None = None - self._pending_graph_edge_states: dict[str, NodeState] | None = None - - if graph is not None: - self.attach_graph(graph) - - # ------------------------------------------------------------------ - # Context binding helpers - # ------------------------------------------------------------------ - def attach_graph(self, graph: GraphProtocol) -> None: - """Attach the materialized graph to the runtime state.""" - if self._graph is not None and self._graph is not graph: - raise ValueError("GraphRuntimeState already attached to a different graph instance") - - self._graph = graph - - if self._response_coordinator is None: - self._response_coordinator = self._build_response_coordinator(graph) - - if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None: - self._response_coordinator.loads(self._pending_response_coordinator_dump) - self._pending_response_coordinator_dump = None - self._apply_pending_graph_state() - - def configure(self, *, graph: GraphProtocol | None = None) -> None: - """Ensure core collaborators are initialized with the provided context.""" - if graph is not None: - self.attach_graph(graph) - - # Ensure collaborators are instantiated - _ = self.ready_queue - _ = self.graph_execution - if self._graph is not None: - _ = self.response_coordinator - - def bind_child_engine_builder(self, builder: ChildGraphEngineBuilderProtocol) -> None: - self._child_engine_builder = builder - - def create_child_engine( - self, - *, - workflow_id: str, - graph_init_params: GraphInitParams, - root_node_id: str, - variable_pool: VariablePool | None = None, - ) -> Any: - """Create a child graph engine that derives its runtime state from the parent.""" - if self._child_engine_builder is None: - raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.") - - return self._child_engine_builder.build_child_engine( - workflow_id=workflow_id, - graph_init_params=graph_init_params, - parent_graph_runtime_state=self, - root_node_id=root_node_id, - variable_pool=variable_pool, - ) - - # ------------------------------------------------------------------ - # Primary collaborators - # ------------------------------------------------------------------ - @property - def variable_pool(self) -> VariablePool: - return self._variable_pool - - @property - def ready_queue(self) -> ReadyQueueProtocol: - if self._ready_queue is None: - self._ready_queue = self._build_ready_queue() - return self._ready_queue - - @property - def graph_execution(self) -> GraphExecutionProtocol: - if self._graph_execution is None: - self._graph_execution = self._build_graph_execution() - return self._graph_execution - - @property - def response_coordinator(self) -> ResponseStreamCoordinatorProtocol: - if self._response_coordinator is None: - if self._graph is None: - raise ValueError("Graph must be attached before accessing response coordinator") - self._response_coordinator = self._build_response_coordinator(self._graph) - return self._response_coordinator - - @property - def execution_context(self) -> AbstractContextManager[object]: - return self._execution_context - - @execution_context.setter - def execution_context(self, value: AbstractContextManager[object] | None) -> None: - self._execution_context = value if value is not None else nullcontext(None) - - # ------------------------------------------------------------------ - # Scalar state - # ------------------------------------------------------------------ - @property - def start_at(self) -> float: - return self._start_at - - @start_at.setter - def start_at(self, value: float) -> None: - self._start_at = value - - @property - def total_tokens(self) -> int: - return self._total_tokens - - @total_tokens.setter - def total_tokens(self, value: int) -> None: - if value < 0: - raise ValueError("total_tokens must be non-negative") - self._total_tokens = value - - @property - def llm_usage(self) -> LLMUsage: - return self._llm_usage.model_copy() - - @llm_usage.setter - def llm_usage(self, value: LLMUsage) -> None: - self._llm_usage = value.model_copy() - - @property - def outputs(self) -> dict[str, Any]: - return deepcopy(self._outputs) - - @outputs.setter - def outputs(self, value: dict[str, Any]) -> None: - self._outputs = deepcopy(value) - - def set_output(self, key: str, value: object) -> None: - self._outputs[key] = deepcopy(value) - - def get_output(self, key: str, default: object = None) -> object: - return deepcopy(self._outputs.get(key, default)) - - def update_outputs(self, updates: dict[str, object]) -> None: - for key, value in updates.items(): - self._outputs[key] = deepcopy(value) - - @property - def node_run_steps(self) -> int: - return self._node_run_steps - - @node_run_steps.setter - def node_run_steps(self, value: int) -> None: - if value < 0: - raise ValueError("node_run_steps must be non-negative") - self._node_run_steps = value - - def increment_node_run_steps(self) -> None: - self._node_run_steps += 1 - - def add_tokens(self, tokens: int) -> None: - if tokens < 0: - raise ValueError("tokens must be non-negative") - self._total_tokens += tokens - - # ------------------------------------------------------------------ - # Serialization - # ------------------------------------------------------------------ - def dumps(self) -> str: - """Serialize runtime state into a JSON string.""" - - snapshot: dict[str, Any] = { - "version": "1.0", - "start_at": self._start_at, - "total_tokens": self._total_tokens, - "node_run_steps": self._node_run_steps, - "llm_usage": self._llm_usage.model_dump(mode="json"), - "outputs": self.outputs, - "variable_pool": self.variable_pool.model_dump(mode="json"), - "ready_queue": self.ready_queue.dumps(), - "graph_execution": self.graph_execution.dumps(), - "paused_nodes": list(self._paused_nodes), - "deferred_nodes": list(self._deferred_nodes), - } - - graph_state = self._snapshot_graph_state() - if graph_state is not None: - snapshot["graph_state"] = graph_state - - if self._response_coordinator is not None and self._graph is not None: - snapshot["response_coordinator"] = self._response_coordinator.dumps() - - return json.dumps(snapshot, default=pydantic_encoder) - - @classmethod - def from_snapshot(cls, data: str | Mapping[str, Any]) -> GraphRuntimeState: - """Restore runtime state from a serialized snapshot.""" - - snapshot = cls._parse_snapshot_payload(data) - - state = cls( - variable_pool=snapshot.variable_pool, - start_at=snapshot.start_at, - total_tokens=snapshot.total_tokens, - llm_usage=snapshot.llm_usage, - outputs=snapshot.outputs, - node_run_steps=snapshot.node_run_steps, - ) - state._apply_snapshot(snapshot) - return state - - def loads(self, data: str | Mapping[str, Any]) -> None: - """Restore runtime state from a serialized snapshot (legacy API).""" - - snapshot = self._parse_snapshot_payload(data) - self._apply_snapshot(snapshot) - - def register_paused_node(self, node_id: str) -> None: - """Record a node that should resume when execution is continued.""" - - self._paused_nodes.add(node_id) - - def get_paused_nodes(self) -> list[str]: - """Retrieve the list of paused nodes without mutating internal state.""" - - return list(self._paused_nodes) - - def consume_paused_nodes(self) -> list[str]: - """Retrieve and clear the list of paused nodes awaiting resume.""" - - nodes = list(self._paused_nodes) - self._paused_nodes.clear() - return nodes - - def register_deferred_node(self, node_id: str) -> None: - """Record a node that became ready during pause and should resume later.""" - - self._deferred_nodes.add(node_id) - - def get_deferred_nodes(self) -> list[str]: - """Retrieve deferred nodes without mutating internal state.""" - - return list(self._deferred_nodes) - - def consume_deferred_nodes(self) -> list[str]: - """Retrieve and clear deferred nodes awaiting resume.""" - - nodes = list(self._deferred_nodes) - self._deferred_nodes.clear() - return nodes - - # ------------------------------------------------------------------ - # Builders - # ------------------------------------------------------------------ - def _build_ready_queue(self) -> ReadyQueueProtocol: - # Import lazily to avoid breaching architecture boundaries enforced by import-linter. - module = importlib.import_module("graphon.graph_engine.ready_queue") - in_memory_cls = module.InMemoryReadyQueue - return in_memory_cls() - - def _build_graph_execution(self) -> GraphExecutionProtocol: - # Lazily import to keep the runtime domain decoupled from graph_engine modules. - module = importlib.import_module("graphon.graph_engine.domain.graph_execution") - graph_execution_cls = module.GraphExecution - workflow_id = self._pending_graph_execution_workflow_id or "" - self._pending_graph_execution_workflow_id = None - return graph_execution_cls(workflow_id=workflow_id) # type: ignore[invalid-return-type] - - def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol: - # Lazily import to keep the runtime domain decoupled from graph_engine modules. - module = importlib.import_module("graphon.graph_engine.response_coordinator") - coordinator_cls = module.ResponseStreamCoordinator - return coordinator_cls(variable_pool=self.variable_pool, graph=graph) - - # ------------------------------------------------------------------ - # Snapshot helpers - # ------------------------------------------------------------------ - @classmethod - def _parse_snapshot_payload(cls, data: str | Mapping[str, Any]) -> _GraphRuntimeStateSnapshot: - payload: dict[str, Any] - if isinstance(data, str): - payload = json.loads(data) - else: - payload = dict(data) - - version = payload.get("version") - if version != "1.0": - raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}") - - start_at = float(payload.get("start_at", 0.0)) - - total_tokens = int(payload.get("total_tokens", 0)) - if total_tokens < 0: - raise ValueError("total_tokens must be non-negative") - - node_run_steps = int(payload.get("node_run_steps", 0)) - if node_run_steps < 0: - raise ValueError("node_run_steps must be non-negative") - - llm_usage_payload = payload.get("llm_usage", {}) - llm_usage = LLMUsage.model_validate(llm_usage_payload) - - outputs_payload = deepcopy(payload.get("outputs", {})) - - variable_pool_payload = payload.get("variable_pool") - has_variable_pool = variable_pool_payload is not None - variable_pool = VariablePool.model_validate(variable_pool_payload) if has_variable_pool else VariablePool() - - ready_queue_payload = payload.get("ready_queue") - graph_execution_payload = payload.get("graph_execution") - response_payload = payload.get("response_coordinator") - paused_nodes_payload = payload.get("paused_nodes", []) - deferred_nodes_payload = payload.get("deferred_nodes", []) - graph_state_payload = payload.get("graph_state", {}) or {} - graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes") - graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges") - - return _GraphRuntimeStateSnapshot( - start_at=start_at, - total_tokens=total_tokens, - node_run_steps=node_run_steps, - llm_usage=llm_usage, - outputs=outputs_payload, - variable_pool=variable_pool, - has_variable_pool=has_variable_pool, - ready_queue_dump=ready_queue_payload, - graph_execution_dump=graph_execution_payload, - response_coordinator_dump=response_payload, - paused_nodes=tuple(map(str, paused_nodes_payload)), - deferred_nodes=tuple(map(str, deferred_nodes_payload)), - graph_node_states=graph_node_states, - graph_edge_states=graph_edge_states, - ) - - def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None: - self._start_at = snapshot.start_at - self._total_tokens = snapshot.total_tokens - self._node_run_steps = snapshot.node_run_steps - self._llm_usage = snapshot.llm_usage.model_copy() - self._outputs = deepcopy(snapshot.outputs) - if snapshot.has_variable_pool or self._variable_pool is None: - self._variable_pool = snapshot.variable_pool - - self._restore_ready_queue(snapshot.ready_queue_dump) - self._restore_graph_execution(snapshot.graph_execution_dump) - self._restore_response_coordinator(snapshot.response_coordinator_dump) - self._paused_nodes = set(snapshot.paused_nodes) - self._deferred_nodes = set(snapshot.deferred_nodes) - self._pending_graph_node_states = snapshot.graph_node_states or None - self._pending_graph_edge_states = snapshot.graph_edge_states or None - self._apply_pending_graph_state() - - def _restore_ready_queue(self, payload: str | None) -> None: - if payload is not None: - self._ready_queue = self._build_ready_queue() - self._ready_queue.loads(payload) - else: - self._ready_queue = None - - def _restore_graph_execution(self, payload: str | None) -> None: - self._graph_execution = None - self._pending_graph_execution_workflow_id = None - - if payload is None: - return - - try: - execution_payload = json.loads(payload) - self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id") - except (json.JSONDecodeError, TypeError, AttributeError): - self._pending_graph_execution_workflow_id = None - - self.graph_execution.loads(payload) - - def _restore_response_coordinator(self, payload: str | None) -> None: - if payload is None: - self._pending_response_coordinator_dump = None - self._response_coordinator = None - return - - if self._graph is not None: - self.response_coordinator.loads(payload) - self._pending_response_coordinator_dump = None - return - - self._pending_response_coordinator_dump = payload - self._response_coordinator = None - - def _snapshot_graph_state(self) -> _GraphStateSnapshot: - graph = self._graph - if graph is None: - if self._pending_graph_node_states is None and self._pending_graph_edge_states is None: - return _GraphStateSnapshot() - return _GraphStateSnapshot( - nodes=self._pending_graph_node_states or {}, - edges=self._pending_graph_edge_states or {}, - ) - - nodes = graph.nodes - edges = graph.edges - if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping): - return _GraphStateSnapshot() - - node_states = {} - for node_id, node in nodes.items(): - if not isinstance(node_id, str): - continue - node_states[node_id] = node.state - - edge_states = {} - for edge_id, edge in edges.items(): - if not isinstance(edge_id, str): - continue - edge_states[edge_id] = edge.state - - return _GraphStateSnapshot(nodes=node_states, edges=edge_states) - - def _apply_pending_graph_state(self) -> None: - if self._graph is None: - return - if self._pending_graph_node_states: - for node_id, state in self._pending_graph_node_states.items(): - node = self._graph.nodes.get(node_id) - if node is None: - continue - node.state = state - if self._pending_graph_edge_states: - for edge_id, state in self._pending_graph_edge_states.items(): - edge = self._graph.edges.get(edge_id) - if edge is None: - continue - edge.state = state - - self._pending_graph_node_states = None - self._pending_graph_edge_states = None - - -def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]: - if not isinstance(payload, Mapping): - return {} - raw_map = payload.get(key, {}) - if not isinstance(raw_map, Mapping): - return {} - result: dict[str, NodeState] = {} - for node_id, raw_state in raw_map.items(): - if not isinstance(node_id, str): - continue - try: - result[node_id] = NodeState(str(raw_state)) - except ValueError: - continue - return result diff --git a/api/graphon/runtime/graph_runtime_state_protocol.py b/api/graphon/runtime/graph_runtime_state_protocol.py deleted file mode 100644 index 856625a5d3..0000000000 --- a/api/graphon/runtime/graph_runtime_state_protocol.py +++ /dev/null @@ -1,79 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Protocol - -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.variables.segments import Segment - - -class ReadOnlyVariablePool(Protocol): - """Read-only interface for VariablePool.""" - - def get(self, selector: Sequence[str], /) -> Segment | None: - """Get a variable value (read-only).""" - ... - - def get_all_by_node(self, node_id: str) -> Mapping[str, object]: - """Get all variables for a node (read-only).""" - ... - - def get_by_prefix(self, prefix: str) -> Mapping[str, object]: - """Get all variables stored under a given node prefix (read-only).""" - ... - - -class ReadOnlyGraphRuntimeState(Protocol): - """ - Read-only view of GraphRuntimeState for layers. - - This protocol defines a read-only interface that prevents layers from - modifying the graph runtime state while still allowing observation. - All methods return defensive copies to ensure immutability. - """ - - @property - def variable_pool(self) -> ReadOnlyVariablePool: - """Get read-only access to the variable pool.""" - ... - - @property - def start_at(self) -> float: - """Get the start time (read-only).""" - ... - - @property - def total_tokens(self) -> int: - """Get the total tokens count (read-only).""" - ... - - @property - def llm_usage(self) -> LLMUsage: - """Get a copy of LLM usage info (read-only).""" - ... - - @property - def outputs(self) -> dict[str, Any]: - """Get a defensive copy of outputs (read-only).""" - ... - - @property - def node_run_steps(self) -> int: - """Get the node run steps count (read-only).""" - ... - - @property - def ready_queue_size(self) -> int: - """Get the number of nodes currently in the ready queue.""" - ... - - @property - def exceptions_count(self) -> int: - """Get the number of node execution exceptions recorded.""" - ... - - def get_output(self, key: str, default: Any = None) -> Any: - """Get a single output value (returns a copy).""" - ... - - def dumps(self) -> str: - """Serialize the runtime state into a JSON snapshot (read-only).""" - ... diff --git a/api/graphon/runtime/read_only_wrappers.py b/api/graphon/runtime/read_only_wrappers.py deleted file mode 100644 index aaef255204..0000000000 --- a/api/graphon/runtime/read_only_wrappers.py +++ /dev/null @@ -1,82 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from copy import deepcopy -from typing import Any - -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.variables.segments import Segment - -from .graph_runtime_state import GraphRuntimeState -from .variable_pool import VariablePool - - -class ReadOnlyVariablePoolWrapper: - """Provide defensive, read-only access to ``VariablePool``.""" - - def __init__(self, variable_pool: VariablePool) -> None: - self._variable_pool = variable_pool - - def get(self, selector: Sequence[str], /) -> Segment | None: - """Return a copy of a variable value if present.""" - value = self._variable_pool.get(selector) - return deepcopy(value) if value is not None else None - - def get_all_by_node(self, node_id: str) -> Mapping[str, object]: - """Return a copy of all variables for the specified node.""" - variables: dict[str, object] = {} - if node_id in self._variable_pool.variable_dictionary: - for key, variable in self._variable_pool.variable_dictionary[node_id].items(): - variables[key] = deepcopy(variable.value) - return variables - - def get_by_prefix(self, prefix: str) -> Mapping[str, object]: - """Return a copy of all variables stored under the given prefix.""" - return self._variable_pool.get_by_prefix(prefix) - - -class ReadOnlyGraphRuntimeStateWrapper: - """Expose a defensive, read-only view of ``GraphRuntimeState``.""" - - def __init__(self, state: GraphRuntimeState) -> None: - self._state = state - self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool) - - @property - def variable_pool(self) -> ReadOnlyVariablePoolWrapper: - return self._variable_pool_wrapper - - @property - def start_at(self) -> float: - return self._state.start_at - - @property - def total_tokens(self) -> int: - return self._state.total_tokens - - @property - def llm_usage(self) -> LLMUsage: - return self._state.llm_usage.model_copy() - - @property - def outputs(self) -> dict[str, Any]: - return deepcopy(self._state.outputs) - - @property - def node_run_steps(self) -> int: - return self._state.node_run_steps - - @property - def ready_queue_size(self) -> int: - return self._state.ready_queue.qsize() - - @property - def exceptions_count(self) -> int: - return self._state.graph_execution.exceptions_count - - def get_output(self, key: str, default: Any = None) -> Any: - return self._state.get_output(key, default) - - def dumps(self) -> str: - """Serialize the underlying runtime state for external persistence.""" - return self._state.dumps() diff --git a/api/graphon/runtime/variable_pool.py b/api/graphon/runtime/variable_pool.py deleted file mode 100644 index b44d1a8abe..0000000000 --- a/api/graphon/runtime/variable_pool.py +++ /dev/null @@ -1,279 +0,0 @@ -from __future__ import annotations - -import re -from collections import defaultdict -from collections.abc import Mapping, Sequence -from copy import deepcopy -from typing import Annotated, Any, Union, cast - -from pydantic import BaseModel, Field, model_validator - -from graphon.file import File, FileAttribute, file_manager -from graphon.variables import Segment, SegmentGroup, VariableBase, build_segment, segment_to_variable -from graphon.variables.consts import SELECTORS_LENGTH -from graphon.variables.segments import FileSegment, ObjectSegment -from graphon.variables.variables import RAGPipelineVariableInput, Variable - -VariableValue = Union[str, int, float, dict[str, object], list[object], File] - -VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") - - -def _default_variable_dictionary() -> defaultdict[str, dict[str, Variable]]: - return defaultdict(dict) - - -class VariablePool(BaseModel): - _SYSTEM_VARIABLE_NODE_ID = "sys" - _ENVIRONMENT_VARIABLE_NODE_ID = "env" - _CONVERSATION_VARIABLE_NODE_ID = "conversation" - _RAG_PIPELINE_VARIABLE_NODE_ID = "rag" - - # Variable dictionary is a dictionary for looking up variables by their selector. - # The first element of the selector is the node id, it's the first-level key in the dictionary. - # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the - # elements of the selector except the first one. - variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field( - description="Variables mapping", - default_factory=_default_variable_dictionary, - ) - system_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) - environment_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) - conversation_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) - rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = Field(default_factory=tuple, exclude=True) - user_inputs: Mapping[str, Any] = Field(default_factory=dict, exclude=True) - - @model_validator(mode="after") - def _load_legacy_bootstrap_inputs(self) -> VariablePool: - """ - Accept legacy constructor kwargs that still appear throughout the workflow - layer while keeping serialized state focused on `variable_dictionary`. - """ - - self._ingest_legacy_variables(self.system_variables, node_id=self._SYSTEM_VARIABLE_NODE_ID) - self._ingest_legacy_variables(self.environment_variables, node_id=self._ENVIRONMENT_VARIABLE_NODE_ID) - self._ingest_legacy_variables(self.conversation_variables, node_id=self._CONVERSATION_VARIABLE_NODE_ID) - self._ingest_legacy_rag_variables(self.rag_pipeline_variables) - - # These kwargs are accepted for compatibility but should not affect the - # stable serialized form or model equality. - self.system_variables = () - self.environment_variables = () - self.conversation_variables = () - self.rag_pipeline_variables = () - self.user_inputs = {} - return self - - def _ingest_legacy_variables(self, variables: Sequence[Variable], *, node_id: str) -> None: - for variable in variables: - selector = [node_id, variable.name] - normalized_variable = variable - if list(variable.selector) != selector: - normalized_variable = variable.model_copy(update={"selector": selector}) - self.add(normalized_variable.selector, normalized_variable) - - def _ingest_legacy_rag_variables(self, rag_pipeline_variables: Sequence[RAGPipelineVariableInput]) -> None: - if not rag_pipeline_variables: - return - - values_by_node_id: defaultdict[str, dict[str, Any]] = defaultdict(dict) - for rag_variable_input in rag_pipeline_variables: - values_by_node_id[rag_variable_input.variable.belong_to_node_id][rag_variable_input.variable.variable] = ( - rag_variable_input.value - ) - - for node_id, value in values_by_node_id.items(): - self.add((self._RAG_PIPELINE_VARIABLE_NODE_ID, node_id), value) - - def add(self, selector: Sequence[str], value: Any, /): - """ - Add a variable to the variable pool. - - This method accepts a selector path and a value, converting the value - to a Variable object if necessary before storing it in the pool. - - Args: - selector: A two-element sequence containing [node_id, variable_name]. - The selector must have exactly 2 elements to be valid. - value: The value to store. Can be a Variable, Segment, or any value - that can be converted to a Segment (str, int, float, dict, list, File). - - Raises: - ValueError: If selector length is not exactly 2 elements. - - Note: - While non-Segment values are currently accepted and automatically - converted, it's recommended to pass Segment or Variable objects directly. - """ - if len(selector) != SELECTORS_LENGTH: - raise ValueError( - f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), " - f"got {len(selector)} elements" - ) - - if isinstance(value, VariableBase): - variable = value - elif isinstance(value, Segment): - variable = segment_to_variable(segment=value, selector=selector) - else: - segment = build_segment(value) - variable = segment_to_variable(segment=segment, selector=selector) - - node_id, name = self._selector_to_keys(selector) - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. - self.variable_dictionary[node_id][name] = cast(Variable, variable) - - @classmethod - def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]: - return selector[0], selector[1] - - def _has(self, selector: Sequence[str]) -> bool: - node_id, name = self._selector_to_keys(selector) - if node_id not in self.variable_dictionary: - return False - if name not in self.variable_dictionary[node_id]: - return False - return True - - def get(self, selector: Sequence[str], /) -> Segment | None: - """ - Retrieve a variable's value from the pool as a Segment. - - This method supports both simple selectors [node_id, variable_name] and - extended selectors that include attribute access for FileSegment and - ObjectSegment types. - - Args: - selector: A sequence with at least 2 elements: - - [node_id, variable_name]: Returns the full segment - - [node_id, variable_name, attr, ...]: Returns a nested value - from FileSegment (e.g., 'url', 'name') or ObjectSegment - - Returns: - The Segment associated with the selector, or None if not found. - Returns None if selector has fewer than 2 elements. - - Raises: - ValueError: If attempting to access an invalid FileAttribute. - """ - if len(selector) < SELECTORS_LENGTH: - return None - - node_id, name = self._selector_to_keys(selector) - node_map = self.variable_dictionary.get(node_id) - if node_map is None: - return None - - segment: Segment | None = node_map.get(name) - - if segment is None: - return None - - if len(selector) == 2: - return segment - - if isinstance(segment, FileSegment): - attr = selector[2] - # Python support `attr in FileAttribute` after 3.12 - if attr not in {item.value for item in FileAttribute}: - return None - attr = FileAttribute(attr) - attr_value = file_manager.get_attr(file=segment.value, attr=attr) - return build_segment(attr_value) - - # Navigate through nested attributes - result: Any = segment - for attr in selector[2:]: - result = self._extract_value(result) - result = self._get_nested_attribute(result, attr) - if result is None: - return None - - # Return result as Segment - return result if isinstance(result, Segment) else build_segment(result) - - def _extract_value(self, obj: Any): - """Extract the actual value from an ObjectSegment.""" - return obj.value if isinstance(obj, ObjectSegment) else obj - - def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None: - """ - Get a nested attribute from a dictionary-like object. - - Args: - obj: The dictionary-like object to search. - attr: The key to look up. - - Returns: - Segment | None: - The corresponding Segment built from the attribute value if the key exists, - otherwise None. - """ - if not isinstance(obj, dict) or attr not in obj: - return None - return build_segment(obj.get(attr)) - - def remove(self, selector: Sequence[str], /): - """ - Remove variables from the variable pool based on the given selector. - - Args: - selector (Sequence[str]): A sequence of strings representing the selector. - - Returns: - None - """ - if not selector: - return - if len(selector) == 1: - self.variable_dictionary[selector[0]] = {} - return - key, hash_key = self._selector_to_keys(selector) - self.variable_dictionary[key].pop(hash_key, None) - - def convert_template(self, template: str, /): - parts = VARIABLE_PATTERN.split(template) - segments: list[Segment] = [] - for part in filter(lambda x: x, parts): - if "." in part and (variable := self.get(part.split("."))): - segments.append(variable) - else: - segments.append(build_segment(part)) - return SegmentGroup(value=segments) - - def get_file(self, selector: Sequence[str], /) -> FileSegment | None: - segment = self.get(selector) - if isinstance(segment, FileSegment): - return segment - return None - - def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: - """Return a copy of all variables stored under the given node prefix.""" - - nodes = self.variable_dictionary.get(prefix) - if not nodes: - return {} - - result: dict[str, object] = {} - for key, variable in nodes.items(): - value = variable.value - result[key] = deepcopy(value) - - return result - - def flatten(self, *, unprefixed_node_id: str | None = None) -> Mapping[str, object]: - """Return a selector-style snapshot of the entire variable pool.""" - - result: dict[str, object] = {} - for node_id, variables in self.variable_dictionary.items(): - for name, variable in variables.items(): - output_name = name if node_id == unprefixed_node_id else f"{node_id}.{name}" - result[output_name] = deepcopy(variable.value) - - return result - - @classmethod - def empty(cls) -> VariablePool: - """Create an empty variable pool.""" - return cls() diff --git a/api/graphon/template_rendering.py b/api/graphon/template_rendering.py deleted file mode 100644 index 0527e58f6d..0000000000 --- a/api/graphon/template_rendering.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from collections.abc import Mapping -from typing import Any - - -class TemplateRenderError(ValueError): - """Raised when rendering a template fails.""" - - -class Jinja2TemplateRenderer(ABC): - """Nominal renderer contract for Jinja2 template rendering in graph nodes.""" - - @abstractmethod - def render_template(self, template: str, variables: Mapping[str, Any]) -> str: - """Render the template into plain text.""" - raise NotImplementedError diff --git a/api/graphon/utils/__init__.py b/api/graphon/utils/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/utils/condition/__init__.py b/api/graphon/utils/condition/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/utils/condition/entities.py b/api/graphon/utils/condition/entities.py deleted file mode 100644 index 77a214571a..0000000000 --- a/api/graphon/utils/condition/entities.py +++ /dev/null @@ -1,49 +0,0 @@ -from collections.abc import Sequence -from typing import Literal - -from pydantic import BaseModel, Field - -SupportedComparisonOperator = Literal[ - # for string or array - "contains", - "not contains", - "start with", - "end with", - "is", - "is not", - "empty", - "not empty", - "in", - "not in", - "all of", - # for number - "=", - "≠", - ">", - "<", - "≥", - "≤", - "null", - "not null", - # for file - "exists", - "not exists", -] - - -class SubCondition(BaseModel): - key: str - comparison_operator: SupportedComparisonOperator - value: str | Sequence[str] | None = None - - -class SubVariableCondition(BaseModel): - logical_operator: Literal["and", "or"] - conditions: list[SubCondition] = Field(default_factory=list) - - -class Condition(BaseModel): - variable_selector: list[str] - comparison_operator: SupportedComparisonOperator - value: str | Sequence[str] | bool | None = None - sub_variable_condition: SubVariableCondition | None = None diff --git a/api/graphon/utils/condition/processor.py b/api/graphon/utils/condition/processor.py deleted file mode 100644 index 03535927cb..0000000000 --- a/api/graphon/utils/condition/processor.py +++ /dev/null @@ -1,504 +0,0 @@ -import json -from collections.abc import Mapping, Sequence -from typing import Literal, NamedTuple - -from graphon.file import FileAttribute, file_manager -from graphon.runtime import VariablePool -from graphon.variables import ArrayFileSegment -from graphon.variables.segments import ArrayBooleanSegment, BooleanSegment - -from .entities import Condition, SubCondition, SupportedComparisonOperator - - -def _convert_to_bool(value: object) -> bool: - if isinstance(value, int): - return bool(value) - - if isinstance(value, str): - loaded = json.loads(value) - if isinstance(loaded, (int, bool)): - return bool(loaded) - - raise TypeError(f"unexpected value: type={type(value)}, value={value}") - - -class ConditionCheckResult(NamedTuple): - inputs: Sequence[Mapping[str, object]] - group_results: Sequence[bool] - final_result: bool - - -class ConditionProcessor: - def process_conditions( - self, - *, - variable_pool: VariablePool, - conditions: Sequence[Condition], - operator: Literal["and", "or"], - ) -> ConditionCheckResult: - input_conditions: list[Mapping[str, object]] = [] - group_results: list[bool] = [] - - for condition in conditions: - variable = variable_pool.get(condition.variable_selector) - if variable is None: - raise ValueError(f"Variable {condition.variable_selector} not found") - - if isinstance(variable, ArrayFileSegment) and condition.comparison_operator in { - "contains", - "not contains", - "all of", - }: - # check sub conditions - if not condition.sub_variable_condition: - raise ValueError("Sub variable is required") - result = _process_sub_conditions( - variable=variable, - sub_conditions=condition.sub_variable_condition.conditions, - operator=condition.sub_variable_condition.logical_operator, - ) - elif condition.comparison_operator in { - "exists", - "not exists", - }: - result = _evaluate_condition( - value=variable.value, - operator=condition.comparison_operator, - expected=None, - ) - else: - actual_value = variable.value if variable else None - expected_value: str | Sequence[str] | bool | list[bool] | None = condition.value - if isinstance(expected_value, str): - expected_value = variable_pool.convert_template(expected_value).text - # Here we need to explicit convet the input string to boolean. - if isinstance(variable, (BooleanSegment, ArrayBooleanSegment)) and expected_value is not None: - # The following two lines is for compatibility with existing workflows. - if isinstance(expected_value, list): - expected_value = [_convert_to_bool(i) for i in expected_value] - else: - expected_value = _convert_to_bool(expected_value) - input_conditions.append( - { - "actual_value": actual_value, - "expected_value": expected_value, - "comparison_operator": condition.comparison_operator, - } - ) - result = _evaluate_condition( - value=actual_value, - operator=condition.comparison_operator, - expected=expected_value, - ) - group_results.append(result) - # Implemented short-circuit evaluation for logical conditions - if (operator == "and" and not result) or (operator == "or" and result): - final_result = result - return ConditionCheckResult(input_conditions, group_results, final_result) - - final_result = all(group_results) if operator == "and" else any(group_results) - return ConditionCheckResult(input_conditions, group_results, final_result) - - -def _evaluate_condition( - *, - operator: SupportedComparisonOperator, - value: object, - expected: str | Sequence[str] | bool | Sequence[bool] | None, -) -> bool: - match operator: - case "contains": - return _assert_contains(value=value, expected=expected) - case "not contains": - return _assert_not_contains(value=value, expected=expected) - case "start with": - return _assert_start_with(value=value, expected=expected) - case "end with": - return _assert_end_with(value=value, expected=expected) - case "is": - return _assert_is(value=value, expected=expected) - case "is not": - return _assert_is_not(value=value, expected=expected) - case "empty": - return _assert_empty(value=value) - case "not empty": - return _assert_not_empty(value=value) - case "=": - return _assert_equal(value=value, expected=expected) - case "≠": - return _assert_not_equal(value=value, expected=expected) - case ">": - return _assert_greater_than(value=value, expected=expected) - case "<": - return _assert_less_than(value=value, expected=expected) - case "≥": - return _assert_greater_than_or_equal(value=value, expected=expected) - case "≤": - return _assert_less_than_or_equal(value=value, expected=expected) - case "null": - return _assert_null(value=value) - case "not null": - return _assert_not_null(value=value) - case "in": - return _assert_in(value=value, expected=expected) - case "not in": - return _assert_not_in(value=value, expected=expected) - case "all of" if isinstance(expected, list): - # Type narrowing: at this point expected is a list, could be list[str] or list[bool] - if all(isinstance(item, str) for item in expected): - # Create a new typed list to satisfy type checker - str_list: list[str] = [item for item in expected if isinstance(item, str)] - return _assert_all_of(value=value, expected=str_list) - elif all(isinstance(item, bool) for item in expected): - # Create a new typed list to satisfy type checker - bool_list: list[bool] = [item for item in expected if isinstance(item, bool)] - return _assert_all_of_bool(value=value, expected=bool_list) - else: - raise ValueError("all of operator expects homogeneous list of strings or booleans") - case "exists": - return _assert_exists(value=value) - case "not exists": - return _assert_not_exists(value=value) - case _: - raise ValueError(f"Unsupported operator: {operator}") - - -def _assert_contains(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(value, (str, list)): - raise ValueError("Invalid actual value type: string or array") - - # Type checking ensures value is str or list at this point - if isinstance(value, str): - if not isinstance(expected, str): - expected = str(expected) - if expected not in value: - return False - else: # value is list - if expected not in value: - return False - return True - - -def _assert_not_contains(*, value: object, expected: object) -> bool: - if not value: - return True - - if not isinstance(value, (str, list)): - raise ValueError("Invalid actual value type: string or array") - - # Type checking ensures value is str or list at this point - if isinstance(value, str): - if not isinstance(expected, str): - expected = str(expected) - if expected in value: - return False - else: # value is list - if expected in value: - return False - return True - - -def _assert_start_with(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(value, str): - raise ValueError("Invalid actual value type: string") - - if not isinstance(expected, str): - raise ValueError("Expected value must be a string for startswith") - if not value.startswith(expected): - return False - return True - - -def _assert_end_with(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(value, str): - raise ValueError("Invalid actual value type: string") - - if not isinstance(expected, str): - raise ValueError("Expected value must be a string for endswith") - if not value.endswith(expected): - return False - return True - - -def _assert_is(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (str, bool)): - raise ValueError("Invalid actual value type: string or boolean") - - if value != expected: - return False - return True - - -def _assert_is_not(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (str, bool)): - raise ValueError("Invalid actual value type: string or boolean") - - if value == expected: - return False - return True - - -def _assert_empty(*, value: object) -> bool: - if not value: - return True - return False - - -def _assert_not_empty(*, value: object) -> bool: - if value: - return True - return False - - -def _normalize_numeric_values(value: int | float, expected: object) -> tuple[int | float, int | float]: - """ - Normalize value and expected to compatible numeric types for comparison. - - Args: - value: The actual numeric value (int or float) - expected: The expected value (int, float, or str) - - Returns: - A tuple of (normalized_value, normalized_expected) with compatible types - - Raises: - ValueError: If expected cannot be converted to a number - """ - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to number") - - # Convert expected to appropriate numeric type - if isinstance(expected, str): - # Try to convert to float first to handle decimal strings - try: - expected_float = float(expected) - except ValueError as e: - raise ValueError(f"Cannot convert '{expected}' to number") from e - - # If value is int and expected is a whole number, keep as int comparison - if isinstance(value, int) and expected_float.is_integer(): - return value, int(expected_float) - else: - # Otherwise convert value to float for comparison - return float(value) if isinstance(value, int) else value, expected_float - elif isinstance(expected, float): - # If expected is already float, convert int value to float - return float(value) if isinstance(value, int) else value, expected - else: - # expected is int - return value, expected - - -def _assert_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float, bool)): - raise ValueError("Invalid actual value type: number or boolean") - - # Handle boolean comparison - if isinstance(value, bool): - if not isinstance(expected, (bool, int, str)): - raise ValueError(f"Cannot convert {type(expected)} to bool") - expected = bool(expected) - elif isinstance(value, int): - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to int") - expected = int(expected) - else: - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to float") - expected = float(expected) - - if value != expected: - return False - return True - - -def _assert_not_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float, bool)): - raise ValueError("Invalid actual value type: number or boolean") - - # Handle boolean comparison - if isinstance(value, bool): - if not isinstance(expected, (bool, int, str)): - raise ValueError(f"Cannot convert {type(expected)} to bool") - expected = bool(expected) - elif isinstance(value, int): - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to int") - expected = int(expected) - else: - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to float") - expected = float(expected) - - if value == expected: - return False - return True - - -def _assert_greater_than(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value > expected - - -def _assert_less_than(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value < expected - - -def _assert_greater_than_or_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value >= expected - - -def _assert_less_than_or_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value <= expected - - -def _assert_null(*, value: object) -> bool: - if value is None: - return True - return False - - -def _assert_not_null(*, value: object) -> bool: - if value is not None: - return True - return False - - -def _assert_in(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(expected, list): - raise ValueError("Invalid expected value type: array") - - if value not in expected: - return False - return True - - -def _assert_not_in(*, value: object, expected: object) -> bool: - if not value: - return True - - if not isinstance(expected, list): - raise ValueError("Invalid expected value type: array") - - if value in expected: - return False - return True - - -def _assert_all_of(*, value: object, expected: Sequence[str]) -> bool: - if not value: - return False - - # Ensure value is a container that supports 'in' operator - if not isinstance(value, (list, tuple, set, str)): - return False - - return all(item in value for item in expected) - - -def _assert_all_of_bool(*, value: object, expected: Sequence[bool]) -> bool: - if not value: - return False - - # Ensure value is a container that supports 'in' operator - if not isinstance(value, (list, tuple, set)): - return False - - return all(item in value for item in expected) - - -def _assert_exists(*, value: object) -> bool: - return value is not None - - -def _assert_not_exists(*, value: object) -> bool: - return value is None - - -def _process_sub_conditions( - variable: ArrayFileSegment, - sub_conditions: Sequence[SubCondition], - operator: Literal["and", "or"], -) -> bool: - files = variable.value - group_results: list[bool] = [] - for condition in sub_conditions: - key = FileAttribute(condition.key) - values = [file_manager.get_attr(file=file, attr=key) for file in files] - expected_value = condition.value - if key == FileAttribute.EXTENSION: - if not isinstance(expected_value, str): - raise TypeError("Expected value must be a string when key is FileAttribute.EXTENSION") - if expected_value and not expected_value.startswith("."): - expected_value = "." + expected_value - - normalized_values: list[object] = [] - for value in values: - if value and isinstance(value, str): - if not value.startswith("."): - value = "." + value - normalized_values.append(value) - values = normalized_values - sub_group_results: list[bool] = [ - _evaluate_condition( - value=value, - operator=condition.comparison_operator, - expected=expected_value, - ) - for value in values - ] - # Determine the result based on the presence of "not" in the comparison operator - result = all(sub_group_results) if "not" in condition.comparison_operator else any(sub_group_results) - group_results.append(result) - return all(group_results) if operator == "and" else any(group_results) diff --git a/api/graphon/utils/json_in_md_parser.py b/api/graphon/utils/json_in_md_parser.py deleted file mode 100644 index 4416b4774b..0000000000 --- a/api/graphon/utils/json_in_md_parser.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -import json - - -class OutputParserError(ValueError): - """Raised when a markdown-wrapped JSON payload cannot be parsed or validated.""" - - -def parse_json_markdown(json_string: str) -> dict | list: - """Extract and parse the first JSON object or array embedded in markdown text.""" - json_string = json_string.strip() - starts = ["```json", "```", "``", "`", "{", "["] - ends = ["```", "``", "`", "}", "]"] - end_index = -1 - start_index = 0 - - for start_marker in starts: - start_index = json_string.find(start_marker) - if start_index != -1: - if json_string[start_index] not in ("{", "["): - start_index += len(start_marker) - break - - if start_index != -1: - for end_marker in ends: - end_index = json_string.rfind(end_marker, start_index) - if end_index != -1: - if json_string[end_index] in ("}", "]"): - end_index += 1 - break - - if start_index == -1 or end_index == -1 or start_index >= end_index: - raise ValueError("could not find json block in the output.") - - extracted_content = json_string[start_index:end_index].strip() - return json.loads(extracted_content) - - -def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: - try: - json_obj = parse_json_markdown(text) - except json.JSONDecodeError as exc: - raise OutputParserError(f"got invalid json object. error: {exc}") from exc - - if isinstance(json_obj, list): - if len(json_obj) == 1 and isinstance(json_obj[0], dict): - json_obj = json_obj[0] - else: - raise OutputParserError(f"got invalid return object. obj:{json_obj}") - - for key in expected_keys: - if key not in json_obj: - raise OutputParserError( - f"got invalid return object. expected key `{key}` to be present, but got {json_obj}" - ) - - return json_obj diff --git a/api/graphon/variable_loader.py b/api/graphon/variable_loader.py deleted file mode 100644 index 03db920d3d..0000000000 --- a/api/graphon/variable_loader.py +++ /dev/null @@ -1,75 +0,0 @@ -import abc -from collections.abc import Mapping, Sequence -from typing import Any, Protocol - -from graphon.runtime import VariablePool -from graphon.variables import VariableBase -from graphon.variables.consts import SELECTORS_LENGTH - - -class VariableLoader(Protocol): - """Interface for loading variables based on selectors. - - A `VariableLoader` is responsible for retrieving additional variables required during the execution - of a single node, which are not provided as user inputs. - - TODO(QuantumGhost): this is a temporally workaround. If we can move the creation of node instance into - `WorkflowService.single_step_run`, we may get rid of this interface. - """ - - @abc.abstractmethod - def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: - """Load variables based on the provided selectors. If the selectors are empty, - this method should return an empty list. - - The order of the returned variables is not guaranteed. If the caller wants to ensure - a specific order, they should sort the returned list themselves. - - :param: selectors: a list of string list, each inner list should have at least two elements: - - the first element is the node ID, - - the second element is the variable name. - :return: a list of VariableBase objects that match the provided selectors. - """ - pass - - -class _DummyVariableLoader(VariableLoader): - """A dummy implementation of VariableLoader that does not load any variables. - Serves as a placeholder when no variable loading is needed. - """ - - def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: - return [] - - -DUMMY_VARIABLE_LOADER = _DummyVariableLoader() - - -def load_into_variable_pool( - variable_loader: VariableLoader, - variable_pool: VariablePool, - variable_mapping: Mapping[str, Sequence[str]], - user_inputs: Mapping[str, Any], -): - # Loading missing variable from draft var here, and set it into - # variable_pool. - variables_to_load: list[list[str]] = [] - for key, selector in variable_mapping.items(): - # NOTE(QuantumGhost): this logic needs to be in sync with - # `WorkflowEntry.mapping_user_inputs_to_variable_pool`. - node_variable_list = key.split(".") - if len(node_variable_list) < 2: - raise ValueError(f"Invalid variable key: {key}. It should have at least two elements.") - if key in user_inputs: - continue - node_variable_key = ".".join(node_variable_list[1:]) - if node_variable_key in user_inputs: - continue - if variable_pool.get(selector) is None: - variables_to_load.append(list(selector)) - loaded = variable_loader.load_variables(variables_to_load) - for var in loaded: - assert len(var.selector) >= SELECTORS_LENGTH, f"Invalid variable {var}" - # Add variable directly to the pool - # The variable pool expects 2-element selectors [node_id, variable_name] - variable_pool.add([var.selector[0], var.selector[1]], var) diff --git a/api/graphon/variables/__init__.py b/api/graphon/variables/__init__.py deleted file mode 100644 index e9beb6cb95..0000000000 --- a/api/graphon/variables/__init__.py +++ /dev/null @@ -1,82 +0,0 @@ -from .factory import ( - TypeMismatchError, - UnsupportedSegmentTypeError, - build_segment, - build_segment_with_type, - segment_to_variable, -) -from .input_entities import VariableEntity, VariableEntityType -from .segment_group import SegmentGroup -from .segments import ( - ArrayAnySegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, -) -from .types import SegmentType -from .variables import ( - ArrayAnyVariable, - ArrayFileVariable, - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - ArrayVariable, - FileVariable, - FloatVariable, - IntegerVariable, - NoneVariable, - ObjectVariable, - SecretVariable, - StringVariable, - Variable, - VariableBase, -) - -__all__ = [ - "ArrayAnySegment", - "ArrayAnyVariable", - "ArrayFileSegment", - "ArrayFileVariable", - "ArrayNumberSegment", - "ArrayNumberVariable", - "ArrayObjectSegment", - "ArrayObjectVariable", - "ArraySegment", - "ArrayStringSegment", - "ArrayStringVariable", - "ArrayVariable", - "FileSegment", - "FileVariable", - "FloatSegment", - "FloatVariable", - "IntegerSegment", - "IntegerVariable", - "NoneSegment", - "NoneVariable", - "ObjectSegment", - "ObjectVariable", - "SecretVariable", - "Segment", - "SegmentGroup", - "SegmentType", - "StringSegment", - "StringVariable", - "TypeMismatchError", - "UnsupportedSegmentTypeError", - "Variable", - "VariableBase", - "VariableEntity", - "VariableEntityType", - "build_segment", - "build_segment_with_type", - "segment_to_variable", -] diff --git a/api/graphon/variables/consts.py b/api/graphon/variables/consts.py deleted file mode 100644 index 8f3f78f740..0000000000 --- a/api/graphon/variables/consts.py +++ /dev/null @@ -1,7 +0,0 @@ -# The minimal selector length for valid variables. -# -# The first element of the selector is the node id, and the second element is the variable name. -# -# If the selector length is more than 2, the remaining parts are the keys / indexes paths used -# to extract part of the variable value. -SELECTORS_LENGTH = 2 diff --git a/api/graphon/variables/exc.py b/api/graphon/variables/exc.py deleted file mode 100644 index 5cf67c3bac..0000000000 --- a/api/graphon/variables/exc.py +++ /dev/null @@ -1,2 +0,0 @@ -class VariableError(ValueError): - pass diff --git a/api/graphon/variables/factory.py b/api/graphon/variables/factory.py deleted file mode 100644 index ac693914a7..0000000000 --- a/api/graphon/variables/factory.py +++ /dev/null @@ -1,202 +0,0 @@ -"""Graph-owned helpers for converting runtime values, segments, and variables. - -These conversions are part of the `graphon` runtime model and must stay -independent from top-level API factory modules so graph nodes and state -containers can operate without importing application-layer packages. -""" - -from collections.abc import Mapping, Sequence -from typing import Any, cast -from uuid import uuid4 - -from graphon.file import File - -from .segments import ( - ArrayAnySegment, - ArrayBooleanSegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - BooleanSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, -) -from .types import SegmentType -from .variables import ( - ArrayAnyVariable, - ArrayBooleanVariable, - ArrayFileVariable, - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - BooleanVariable, - FileVariable, - FloatVariable, - IntegerVariable, - NoneVariable, - ObjectVariable, - StringVariable, - VariableBase, -) - - -class UnsupportedSegmentTypeError(Exception): - pass - - -class TypeMismatchError(Exception): - pass - - -SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[Any]] = { - ArrayAnySegment: ArrayAnyVariable, - ArrayBooleanSegment: ArrayBooleanVariable, - ArrayFileSegment: ArrayFileVariable, - ArrayNumberSegment: ArrayNumberVariable, - ArrayObjectSegment: ArrayObjectVariable, - ArrayStringSegment: ArrayStringVariable, - BooleanSegment: BooleanVariable, - FileSegment: FileVariable, - FloatSegment: FloatVariable, - IntegerSegment: IntegerVariable, - NoneSegment: NoneVariable, - ObjectSegment: ObjectVariable, - StringSegment: StringVariable, -} - - -def build_segment(value: Any, /) -> Segment: - """Build a runtime segment from a Python value.""" - if value is None: - return NoneSegment() - if isinstance(value, Segment): - return value - if isinstance(value, str): - return StringSegment(value=value) - if isinstance(value, bool): - return BooleanSegment(value=value) - if isinstance(value, int): - return IntegerSegment(value=value) - if isinstance(value, float): - return FloatSegment(value=value) - if isinstance(value, dict): - return ObjectSegment(value=value) - if isinstance(value, File): - return FileSegment(value=value) - if isinstance(value, list): - items = [build_segment(item) for item in value] - types = {item.value_type for item in items} - if all(isinstance(item, ArraySegment) for item in items): - return ArrayAnySegment(value=value) - if len(types) != 1: - if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): - return ArrayNumberSegment(value=value) - return ArrayAnySegment(value=value) - - match types.pop(): - case SegmentType.STRING: - return ArrayStringSegment(value=value) - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - return ArrayNumberSegment(value=value) - case SegmentType.BOOLEAN: - return ArrayBooleanSegment(value=value) - case SegmentType.OBJECT: - return ArrayObjectSegment(value=value) - case SegmentType.FILE: - return ArrayFileSegment(value=value) - case SegmentType.NONE: - return ArrayAnySegment(value=value) - case _: - raise ValueError(f"not supported value {value}") - raise ValueError(f"not supported value {value}") - - -_SEGMENT_FACTORY: Mapping[SegmentType, type[Segment]] = { - SegmentType.NONE: NoneSegment, - SegmentType.STRING: StringSegment, - SegmentType.INTEGER: IntegerSegment, - SegmentType.FLOAT: FloatSegment, - SegmentType.FILE: FileSegment, - SegmentType.BOOLEAN: BooleanSegment, - SegmentType.OBJECT: ObjectSegment, - SegmentType.ARRAY_ANY: ArrayAnySegment, - SegmentType.ARRAY_STRING: ArrayStringSegment, - SegmentType.ARRAY_NUMBER: ArrayNumberSegment, - SegmentType.ARRAY_OBJECT: ArrayObjectSegment, - SegmentType.ARRAY_FILE: ArrayFileSegment, - SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment, -} - - -def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: - """Build a segment while enforcing compatibility with the expected runtime type.""" - if value is None: - if segment_type == SegmentType.NONE: - return NoneSegment() - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") - - if isinstance(value, list) and len(value) == 0: - if segment_type == SegmentType.ARRAY_ANY: - return ArrayAnySegment(value=value) - if segment_type == SegmentType.ARRAY_STRING: - return ArrayStringSegment(value=value) - if segment_type == SegmentType.ARRAY_BOOLEAN: - return ArrayBooleanSegment(value=value) - if segment_type == SegmentType.ARRAY_NUMBER: - return ArrayNumberSegment(value=value) - if segment_type == SegmentType.ARRAY_OBJECT: - return ArrayObjectSegment(value=value) - if segment_type == SegmentType.ARRAY_FILE: - return ArrayFileSegment(value=value) - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") - - inferred_type = SegmentType.infer_segment_type(value) - if inferred_type is None: - raise TypeMismatchError( - f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}" - ) - if inferred_type == segment_type: - segment_class = _SEGMENT_FACTORY[segment_type] - return segment_class(value_type=segment_type, value=value) - if segment_type == SegmentType.NUMBER and inferred_type in (SegmentType.INTEGER, SegmentType.FLOAT): - segment_class = _SEGMENT_FACTORY[inferred_type] - return segment_class(value_type=inferred_type, value=value) - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") - - -def segment_to_variable( - *, - segment: Segment, - selector: Sequence[str], - id: str | None = None, - name: str | None = None, - description: str = "", -) -> VariableBase: - """Convert a runtime segment into a runtime variable for storage in the pool.""" - if isinstance(segment, VariableBase): - return segment - name = name or selector[-1] - id = id or str(uuid4()) - - segment_type = type(segment) - if segment_type not in SEGMENT_TO_VARIABLE_MAP: - raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") - - variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] - return cast( - VariableBase, - variable_class( - id=id, - name=name, - description=description, - value=segment.value, - selector=list(selector), - ), - ) diff --git a/api/graphon/variables/input_entities.py b/api/graphon/variables/input_entities.py deleted file mode 100644 index c46ee47714..0000000000 --- a/api/graphon/variables/input_entities.py +++ /dev/null @@ -1,62 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum -from typing import Any - -from jsonschema import Draft7Validator, SchemaError -from pydantic import BaseModel, Field, field_validator - -from graphon.file import FileTransferMethod, FileType - - -class VariableEntityType(StrEnum): - TEXT_INPUT = "text-input" - SELECT = "select" - PARAGRAPH = "paragraph" - NUMBER = "number" - EXTERNAL_DATA_TOOL = "external_data_tool" - FILE = "file" - FILE_LIST = "file-list" - CHECKBOX = "checkbox" - JSON_OBJECT = "json_object" - - -class VariableEntity(BaseModel): - """ - Shared variable entity used by workflow runtime and app configuration. - """ - - # `variable` records the name of the variable in user inputs. - variable: str - label: str - description: str = "" - type: VariableEntityType - required: bool = False - hide: bool = False - default: Any = None - max_length: int | None = None - options: Sequence[str] = Field(default_factory=list) - allowed_file_types: Sequence[FileType] | None = Field(default_factory=list) - allowed_file_extensions: Sequence[str] | None = Field(default_factory=list) - allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list) - json_schema: dict[str, Any] | None = Field(default=None) - - @field_validator("description", mode="before") - @classmethod - def convert_none_description(cls, value: Any) -> str: - return value or "" - - @field_validator("options", mode="before") - @classmethod - def convert_none_options(cls, value: Any) -> Sequence[str]: - return value or [] - - @field_validator("json_schema") - @classmethod - def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None: - if schema is None: - return None - try: - Draft7Validator.check_schema(schema) - except SchemaError as error: - raise ValueError(f"Invalid JSON schema: {error.message}") - return schema diff --git a/api/graphon/variables/segment_group.py b/api/graphon/variables/segment_group.py deleted file mode 100644 index b363255b2c..0000000000 --- a/api/graphon/variables/segment_group.py +++ /dev/null @@ -1,22 +0,0 @@ -from .segments import Segment -from .types import SegmentType - - -class SegmentGroup(Segment): - value_type: SegmentType = SegmentType.GROUP - value: list[Segment] - - @property - def text(self): - return "".join([segment.text for segment in self.value]) - - @property - def log(self): - return "".join([segment.log for segment in self.value]) - - @property - def markdown(self): - return "".join([segment.markdown for segment in self.value]) - - def to_object(self): - return [segment.to_object() for segment in self.value] diff --git a/api/graphon/variables/segments.py b/api/graphon/variables/segments.py deleted file mode 100644 index 8902ddc7e9..0000000000 --- a/api/graphon/variables/segments.py +++ /dev/null @@ -1,253 +0,0 @@ -import json -import sys -from collections.abc import Mapping, Sequence -from typing import Annotated, Any, TypeAlias - -from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator - -from graphon.file import File - -from .types import SegmentType - - -class Segment(BaseModel): - """Segment is runtime type used during the execution of workflow. - - Note: this class is abstract, you should use subclasses of this class instead. - """ - - model_config = ConfigDict(frozen=True) - - value_type: SegmentType - value: Any - - @field_validator("value_type") - @classmethod - def validate_value_type(cls, value): - """ - This validator checks if the provided value is equal to the default value of the 'value_type' field. - If the value is different, a ValueError is raised. - """ - if value != cls.model_fields["value_type"].default: - raise ValueError("Cannot modify 'value_type'") - return value - - @property - def text(self) -> str: - return str(self.value) - - @property - def log(self) -> str: - return str(self.value) - - @property - def markdown(self) -> str: - return str(self.value) - - @property - def size(self) -> int: - """ - Return the size of the value in bytes. - """ - return sys.getsizeof(self.value) - - def to_object(self): - return self.value - - -class NoneSegment(Segment): - value_type: SegmentType = SegmentType.NONE - value: None = None - - @property - def text(self) -> str: - return "" - - @property - def log(self) -> str: - return "" - - @property - def markdown(self) -> str: - return "" - - -class StringSegment(Segment): - value_type: SegmentType = SegmentType.STRING - value: str - - -class FloatSegment(Segment): - value_type: SegmentType = SegmentType.FLOAT - value: float - # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. - # The following tests cannot pass. - # - # def test_float_segment_and_nan(): - # nan = float("nan") - # assert nan != nan - # - # f1 = FloatSegment(value=float("nan")) - # f2 = FloatSegment(value=float("nan")) - # assert f1 != f2 - # - # f3 = FloatSegment(value=nan) - # f4 = FloatSegment(value=nan) - # assert f3 != f4 - - -class IntegerSegment(Segment): - value_type: SegmentType = SegmentType.INTEGER - value: int - - -class ObjectSegment(Segment): - value_type: SegmentType = SegmentType.OBJECT - value: Mapping[str, Any] - - @property - def text(self) -> str: - return json.dumps(self.model_dump()["value"], ensure_ascii=False) - - @property - def log(self) -> str: - return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) - - @property - def markdown(self) -> str: - return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) - - -class ArraySegment(Segment): - @property - def text(self) -> str: - # Return empty string for empty arrays instead of "[]" - if not self.value: - return "" - return super().text - - @property - def markdown(self) -> str: - items = [] - for item in self.value: - items.append(f"- {item}") - return "\n".join(items) - - -class FileSegment(Segment): - value_type: SegmentType = SegmentType.FILE - value: File - - @property - def markdown(self) -> str: - return self.value.markdown - - @property - def log(self) -> str: - return "" - - @property - def text(self) -> str: - return "" - - -class BooleanSegment(Segment): - value_type: SegmentType = SegmentType.BOOLEAN - value: bool - - -class ArrayAnySegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_ANY - value: Sequence[Any] - - -class ArrayStringSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_STRING - value: Sequence[str] - - @property - def text(self) -> str: - # Return empty string for empty arrays instead of "[]" - if not self.value: - return "" - return json.dumps(self.value, ensure_ascii=False) - - -class ArrayNumberSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_NUMBER - value: Sequence[float | int] - - -class ArrayObjectSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_OBJECT - value: Sequence[Mapping[str, Any]] - - -class ArrayFileSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_FILE - value: Sequence[File] - - @property - def markdown(self) -> str: - items = [] - for item in self.value: - items.append(item.markdown) - return "\n".join(items) - - @property - def log(self) -> str: - return "" - - @property - def text(self) -> str: - return "" - - -class ArrayBooleanSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_BOOLEAN - value: Sequence[bool] - - -def get_segment_discriminator(v: Any) -> SegmentType | None: - if isinstance(v, Segment): - return v.value_type - elif isinstance(v, dict): - value_type = v.get("value_type") - if value_type is None: - return None - try: - seg_type = SegmentType(value_type) - except ValueError: - return None - return seg_type - else: - # return None if the discriminator value isn't found - return None - - -# The `SegmentUnion`` type is used to enable serialization and deserialization with Pydantic. -# Use `Segment` for type hinting when serialization is not required. -# -# Note: -# - All variants in `SegmentUnion` must inherit from the `Segment` class. -# - The union must include all non-abstract subclasses of `Segment`, except: -# - `SegmentGroup`, which is not added to the variable pool. -# - `VariableBase` and its subclasses, which are handled by `Variable`. -SegmentUnion: TypeAlias = Annotated[ - ( - Annotated[NoneSegment, Tag(SegmentType.NONE)] - | Annotated[StringSegment, Tag(SegmentType.STRING)] - | Annotated[FloatSegment, Tag(SegmentType.FLOAT)] - | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)] - | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)] - | Annotated[FileSegment, Tag(SegmentType.FILE)] - | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)] - | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)] - | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)] - | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)] - | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)] - | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)] - | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)] - ), - Discriminator(get_segment_discriminator), -] diff --git a/api/graphon/variables/types.py b/api/graphon/variables/types.py deleted file mode 100644 index 949a693ad2..0000000000 --- a/api/graphon/variables/types.py +++ /dev/null @@ -1,273 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from enum import StrEnum -from typing import TYPE_CHECKING, Any - -from graphon.file.models import File - -if TYPE_CHECKING: - from graphon.variables.segments import Segment - - -class ArrayValidation(StrEnum): - """Strategy for validating array elements. - - Note: - The `NONE` and `FIRST` strategies are primarily for compatibility purposes. - Avoid using them in new code whenever possible. - """ - - # Skip element validation (only check array container) - NONE = "none" - - # Validate the first element (if array is non-empty) - FIRST = "first" - - # Validate all elements in the array. - ALL = "all" - - -class SegmentType(StrEnum): - NUMBER = "number" - INTEGER = "integer" - FLOAT = "float" - STRING = "string" - OBJECT = "object" - SECRET = "secret" - - FILE = "file" - BOOLEAN = "boolean" - - ARRAY_ANY = "array[any]" - ARRAY_STRING = "array[string]" - ARRAY_NUMBER = "array[number]" - ARRAY_OBJECT = "array[object]" - ARRAY_FILE = "array[file]" - ARRAY_BOOLEAN = "array[boolean]" - - NONE = "none" - - GROUP = "group" - - def is_array_type(self) -> bool: - return self in _ARRAY_TYPES - - @classmethod - def infer_segment_type(cls, value: Any) -> SegmentType | None: - """ - Attempt to infer the `SegmentType` based on the Python type of the `value` parameter. - - Returns `None` if no appropriate `SegmentType` can be determined for the given `value`. - For example, this may occur if the input is a generic Python object of type `object`. - """ - - if isinstance(value, list): - elem_types: set[SegmentType] = set() - for i in value: - segment_type = cls.infer_segment_type(i) - if segment_type is None: - return None - - elem_types.add(segment_type) - - if len(elem_types) != 1: - if elem_types.issubset(_NUMERICAL_TYPES): - return SegmentType.ARRAY_NUMBER - return SegmentType.ARRAY_ANY - elif all(i.is_array_type() for i in elem_types): - return SegmentType.ARRAY_ANY - match elem_types.pop(): - case SegmentType.STRING: - return SegmentType.ARRAY_STRING - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - return SegmentType.ARRAY_NUMBER - case SegmentType.OBJECT: - return SegmentType.ARRAY_OBJECT - case SegmentType.FILE: - return SegmentType.ARRAY_FILE - case SegmentType.NONE: - return SegmentType.ARRAY_ANY - case SegmentType.BOOLEAN: - return SegmentType.ARRAY_BOOLEAN - case _: - # This should be unreachable. - raise ValueError(f"not supported value {value}") - if value is None: - return SegmentType.NONE - # Important: The check for `bool` must precede the check for `int`, - # as `bool` is a subclass of `int` in Python's type hierarchy. - elif isinstance(value, bool): - return SegmentType.BOOLEAN - elif isinstance(value, int): - return SegmentType.INTEGER - elif isinstance(value, float): - return SegmentType.FLOAT - elif isinstance(value, str): - return SegmentType.STRING - elif isinstance(value, dict): - return SegmentType.OBJECT - elif isinstance(value, File): - return SegmentType.FILE - else: - return None - - def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool: - if not isinstance(value, list): - return False - # Skip element validation if array is empty - if len(value) == 0: - return True - if self == SegmentType.ARRAY_ANY: - return True - element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self] - - if array_validation == ArrayValidation.NONE: - return True - elif array_validation == ArrayValidation.FIRST: - return element_type.is_valid(value[0]) - else: - return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value) - - def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.ALL) -> bool: - """ - Check if a value matches the segment type. - Users of `SegmentType` should call this method, instead of using - `isinstance` manually. - - Args: - value: The value to validate - array_validation: Validation strategy for array types (ignored for non-array types) - - Returns: - True if the value matches the type under the given validation strategy - """ - if self.is_array_type(): - return self._validate_array(value, array_validation) - # Important: The check for `bool` must precede the check for `int`, - # as `bool` is a subclass of `int` in Python's type hierarchy. - elif self == SegmentType.BOOLEAN: - return isinstance(value, bool) - elif self in [SegmentType.INTEGER, SegmentType.FLOAT, SegmentType.NUMBER]: - return isinstance(value, (int, float)) - elif self == SegmentType.STRING: - return isinstance(value, str) - elif self == SegmentType.OBJECT: - return isinstance(value, dict) - elif self == SegmentType.SECRET: - return isinstance(value, str) - elif self == SegmentType.FILE: - return isinstance(value, File) - elif self == SegmentType.NONE: - return value is None - elif self == SegmentType.GROUP: - from .segment_group import SegmentGroup - from .segments import Segment - - if isinstance(value, SegmentGroup): - return all(isinstance(item, Segment) for item in value.value) - - if isinstance(value, list): - return all(isinstance(item, Segment) for item in value) - - return False - else: - raise AssertionError("this statement should be unreachable.") - - @staticmethod - def cast_value(value: Any, type_: SegmentType): - # Cast Python's `bool` type to `int` when the runtime type requires - # an integer or number. - # - # This ensures compatibility with existing workflows that may use `bool` as - # `int`, since in Python's type system, `bool` is a subtype of `int`. - # - # This function exists solely to maintain compatibility with existing workflows. - # It should not be used to compromise the integrity of the runtime type system. - # No additional casting rules should be introduced to this function. - - if type_ in ( - SegmentType.INTEGER, - SegmentType.NUMBER, - ) and isinstance(value, bool): - return int(value) - if type_ == SegmentType.ARRAY_NUMBER and all(isinstance(i, bool) for i in value): - return [int(i) for i in value] - return value - - def exposed_type(self) -> SegmentType: - """Returns the type exposed to the frontend. - - The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here. - """ - if self in (SegmentType.INTEGER, SegmentType.FLOAT): - return SegmentType.NUMBER - return self - - def element_type(self) -> SegmentType | None: - """Return the element type of the current segment type, or `None` if the element type is undefined. - - Raises: - ValueError: If the current segment type is not an array type. - - Note: - For certain array types, such as `SegmentType.ARRAY_ANY`, their element types are not defined - by the runtime system. In such cases, this method will return `None`. - """ - if not self.is_array_type(): - raise ValueError(f"element_type is only supported by array type, got {self}") - return _ARRAY_ELEMENT_TYPES_MAPPING.get(self) - - @staticmethod - def get_zero_value(t: SegmentType) -> Segment: - # Lazy import to avoid circular dependency between segment types and factory helpers. - from graphon.variables.factory import build_segment, build_segment_with_type - - match t: - case ( - SegmentType.ARRAY_OBJECT - | SegmentType.ARRAY_ANY - | SegmentType.ARRAY_STRING - | SegmentType.ARRAY_NUMBER - | SegmentType.ARRAY_BOOLEAN - ): - return build_segment_with_type(t, []) - case SegmentType.OBJECT: - return build_segment({}) - case SegmentType.STRING: - return build_segment("") - case SegmentType.INTEGER: - return build_segment(0) - case SegmentType.FLOAT: - return build_segment(0.0) - case SegmentType.NUMBER: - return build_segment(0) - case SegmentType.BOOLEAN: - return build_segment(False) - case _: - raise ValueError(f"unsupported variable type: {t}") - - -_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = { - # ARRAY_ANY does not have corresponding element type. - SegmentType.ARRAY_STRING: SegmentType.STRING, - SegmentType.ARRAY_NUMBER: SegmentType.NUMBER, - SegmentType.ARRAY_OBJECT: SegmentType.OBJECT, - SegmentType.ARRAY_FILE: SegmentType.FILE, - SegmentType.ARRAY_BOOLEAN: SegmentType.BOOLEAN, -} - -_ARRAY_TYPES = frozenset( - list(_ARRAY_ELEMENT_TYPES_MAPPING.keys()) - + [ - SegmentType.ARRAY_ANY, - ] -) - -_NUMERICAL_TYPES = frozenset( - [ - SegmentType.NUMBER, - SegmentType.INTEGER, - SegmentType.FLOAT, - ] -) diff --git a/api/graphon/variables/utils.py b/api/graphon/variables/utils.py deleted file mode 100644 index 8e738f8fd5..0000000000 --- a/api/graphon/variables/utils.py +++ /dev/null @@ -1,33 +0,0 @@ -from collections.abc import Iterable, Sequence -from typing import Any - -import orjson - -from .segment_group import SegmentGroup -from .segments import ArrayFileSegment, FileSegment, Segment - - -def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]: - selectors = [node_id, name] - if paths: - selectors.extend(paths) - return selectors - - -def segment_orjson_default(o: Any): - """Default function for orjson serialization of Segment types""" - if isinstance(o, ArrayFileSegment): - return [v.model_dump() for v in o.value] - elif isinstance(o, FileSegment): - return o.value.model_dump() - elif isinstance(o, SegmentGroup): - return [segment_orjson_default(seg) for seg in o.value] - elif isinstance(o, Segment): - return o.value - raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable") - - -def dumps_with_segments(obj: Any, ensure_ascii: bool = False) -> str: - """JSON dumps with segment support using orjson""" - option = orjson.OPT_NON_STR_KEYS - return orjson.dumps(obj, default=segment_orjson_default, option=option).decode("utf-8") diff --git a/api/graphon/variables/variables.py b/api/graphon/variables/variables.py deleted file mode 100644 index af866283da..0000000000 --- a/api/graphon/variables/variables.py +++ /dev/null @@ -1,172 +0,0 @@ -from collections.abc import Sequence -from typing import Annotated, Any, TypeAlias -from uuid import uuid4 - -from pydantic import BaseModel, Discriminator, Field, Tag - -from .segments import ( - ArrayAnySegment, - ArrayBooleanSegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - BooleanSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, - get_segment_discriminator, -) -from .types import SegmentType - - -def _obfuscated_token(token: str) -> str: - if not token: - return token - if len(token) <= 8: - return "*" * 20 - return token[:6] + "*" * 12 + token[-2:] - - -class VariableBase(Segment): - """ - A variable is a segment that has a name. - - It is mainly used to store segments and their selector in VariablePool. - - Note: this class is abstract, you should use subclasses of this class instead. - """ - - id: str = Field( - default_factory=lambda: str(uuid4()), - description="Unique identity for variable.", - ) - name: str - description: str = Field(default="", description="Description of the variable.") - selector: Sequence[str] = Field(default_factory=list) - - -class StringVariable(StringSegment, VariableBase): - pass - - -class FloatVariable(FloatSegment, VariableBase): - pass - - -class IntegerVariable(IntegerSegment, VariableBase): - pass - - -class ObjectVariable(ObjectSegment, VariableBase): - pass - - -class ArrayVariable(ArraySegment, VariableBase): - pass - - -class ArrayAnyVariable(ArrayAnySegment, ArrayVariable): - pass - - -class ArrayStringVariable(ArrayStringSegment, ArrayVariable): - pass - - -class ArrayNumberVariable(ArrayNumberSegment, ArrayVariable): - pass - - -class ArrayObjectVariable(ArrayObjectSegment, ArrayVariable): - pass - - -class SecretVariable(StringVariable): - value_type: SegmentType = SegmentType.SECRET - - @property - def log(self) -> str: - return _obfuscated_token(self.value) - - -class NoneVariable(NoneSegment, VariableBase): - value_type: SegmentType = SegmentType.NONE - value: None = None - - -class FileVariable(FileSegment, VariableBase): - pass - - -class BooleanVariable(BooleanSegment, VariableBase): - pass - - -class ArrayFileVariable(ArrayFileSegment, ArrayVariable): - pass - - -class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable): - pass - - -class RAGPipelineVariable(BaseModel): - belong_to_node_id: str = Field(description="belong to which node id, shared means public") - type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list") - label: str = Field(description="label") - description: str | None = Field(description="description", default="") - variable: str = Field(description="variable key", default="") - max_length: int | None = Field( - description="max length, applicable to text-input, paragraph, and file-list", default=0 - ) - default_value: Any = Field(description="default value", default="") - placeholder: str | None = Field(description="placeholder", default="") - unit: str | None = Field(description="unit, applicable to Number", default="") - tooltips: str | None = Field(description="helpful text", default="") - allowed_file_types: list[str] | None = Field( - description="image, document, audio, video, custom.", default_factory=list - ) - allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list) - allowed_file_upload_methods: list[str] | None = Field( - description="remote_url, local_file, tool_file.", default_factory=list - ) - required: bool = Field(description="optional, default false", default=False) - options: list[str] | None = Field(default_factory=list) - - -class RAGPipelineVariableInput(BaseModel): - variable: RAGPipelineVariable - value: Any - - -# The `Variable` type is used to enable serialization and deserialization with Pydantic. -# Use `VariableBase` for type hinting when serialization is not required. -# -# Note: -# - All variants in `Variable` must inherit from the `VariableBase` class. -# - The union must include all non-abstract subclasses of `VariableBase`. -Variable: TypeAlias = Annotated[ - ( - Annotated[NoneVariable, Tag(SegmentType.NONE)] - | Annotated[StringVariable, Tag(SegmentType.STRING)] - | Annotated[FloatVariable, Tag(SegmentType.FLOAT)] - | Annotated[IntegerVariable, Tag(SegmentType.INTEGER)] - | Annotated[ObjectVariable, Tag(SegmentType.OBJECT)] - | Annotated[FileVariable, Tag(SegmentType.FILE)] - | Annotated[BooleanVariable, Tag(SegmentType.BOOLEAN)] - | Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)] - | Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)] - | Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)] - | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)] - | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)] - | Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)] - | Annotated[SecretVariable, Tag(SegmentType.SECRET)] - ), - Discriminator(get_segment_discriminator), -] diff --git a/api/graphon/workflow_type_encoder.py b/api/graphon/workflow_type_encoder.py deleted file mode 100644 index 7cdc83ebdb..0000000000 --- a/api/graphon/workflow_type_encoder.py +++ /dev/null @@ -1,49 +0,0 @@ -from collections.abc import Mapping -from decimal import Decimal -from typing import Any, overload - -from pydantic import BaseModel - -from graphon.file.models import File -from graphon.variables import Segment - - -class WorkflowRuntimeTypeConverter: - @overload - def to_json_encodable(self, value: Mapping[str, Any]) -> Mapping[str, Any]: ... - @overload - def to_json_encodable(self, value: None) -> None: ... - - def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: - """Convert runtime values to JSON-serializable structures.""" - - result = self.value_to_json_encodable_recursive(value) - if isinstance(result, Mapping) or result is None: - return result - return {} - - def value_to_json_encodable_recursive(self, value: Any): - if value is None: - return value - if isinstance(value, (bool, int, str, float)): - return value - if isinstance(value, Decimal): - # Convert Decimal to float for JSON serialization - return float(value) - if isinstance(value, Segment): - return self.value_to_json_encodable_recursive(value.value) - if isinstance(value, File): - return value.to_dict() - if isinstance(value, BaseModel): - return value.model_dump(mode="json") - if isinstance(value, dict): - res = {} - for k, v in value.items(): - res[k] = self.value_to_json_encodable_recursive(v) - return res - if isinstance(value, list): - res_list = [] - for item in value: - res_list.append(self.value_to_json_encodable_recursive(item)) - return res_list - return value diff --git a/api/libs/helper.py b/api/libs/helper.py index b1815859a5..a7b3da77ff 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -16,14 +16,14 @@ from zoneinfo import available_timezones from flask import Response, stream_with_context from flask_restx import fields +from graphon.file import helpers as file_helpers +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel from pydantic.functional_validators import AfterValidator from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from extensions.ext_redis import redis_client -from graphon.file import helpers as file_helpers -from graphon.model_runtime.utils.encoders import jsonable_encoder if TYPE_CHECKING: from models import Account diff --git a/api/models/human_input.py b/api/models/human_input.py index b4c7a634b6..79c5d62f6a 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -3,11 +3,11 @@ from enum import StrEnum from typing import Annotated, Literal, Self, final import sqlalchemy as sa +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship from core.workflow.human_input_compat import DeliveryMethodType -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.helper import generate_string from .base import Base, DefaultFieldsMixin diff --git a/api/models/model.py b/api/models/model.py index bcb142db56..b03cb7711f 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -14,6 +14,9 @@ from uuid import uuid4 import sqlalchemy as sa from flask import request from flask_login import UserMixin # type: ignore[import-untyped] +from graphon.enums import WorkflowExecutionStatus +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from graphon.file import helpers as file_helpers from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column from typing_extensions import TypedDict @@ -22,9 +25,6 @@ from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file from extensions.storage.storage_type import StorageType -from graphon.enums import WorkflowExecutionStatus -from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType -from graphon.file import helpers as file_helpers from libs.helper import generate_string # type: ignore[import-not-found] from libs.uuid_utils import uuidv7 from models.utils.file_input_compat import build_file_from_input_mapping diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py index dee1cc507a..f71583c1cd 100644 --- a/api/models/utils/file_input_compat.py +++ b/api/models/utils/file_input_compat.py @@ -4,9 +4,10 @@ from collections.abc import Callable, Mapping from functools import lru_cache from typing import Any -from core.workflow.file_reference import parse_file_reference from graphon.file import File, FileTransferMethod +from core.workflow.file_reference import parse_file_reference + @lru_cache(maxsize=1) def _get_file_access_controller(): diff --git a/api/models/workflow.py b/api/models/workflow.py index 0557e2e890..f8868cb73c 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -8,6 +8,19 @@ from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast from uuid import uuid4 import sqlalchemy as sa +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import ( + BuiltinNodeTypes, + NodeType, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import File +from graphon.file.constants import maybe_file_object +from graphon.variables import utils as variable_utils +from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from sqlalchemy import ( DateTime, Index, @@ -31,19 +44,6 @@ from core.workflow.variable_prefixes import ( ) from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from graphon.enums import ( - BuiltinNodeTypes, - NodeType, - WorkflowExecutionStatus, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file.constants import maybe_file_object -from graphon.file.models import File -from graphon.variables import utils as variable_utils -from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 @@ -53,10 +53,11 @@ if TYPE_CHECKING: from .model import AppMode, UploadFile +from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase + from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter from factories import variable_factory -from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase from libs import helper from .account import Account @@ -1466,8 +1467,6 @@ class WorkflowDraftVariable(Base): # From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than # 80 chars. - # - # ref: api/graphon/entities/variable_pool.py:18 name: Mapped[str] = mapped_column(sa.String(255), nullable=False) description: Mapped[str] = mapped_column( sa.String(255), diff --git a/api/pyproject.toml b/api/pyproject.toml index b1f1f4bb2e..9c94474cdf 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -28,11 +28,11 @@ dependencies = [ "google-auth-httplib2==0.3.0", "google-cloud-aiplatform>=1.123.0", "googleapis-common-protos>=1.65.0", + "graphon>=0.1.2", "gunicorn~=25.1.0", "httpx[socks]~=0.28.0", "jieba==0.42.1", "json-repair>=0.55.1", - "jsonschema>=4.25.1", "langfuse~=2.51.3", "langsmith~=0.7.16", "markdown~=3.10.2", @@ -63,7 +63,6 @@ dependencies = [ "psycopg2-binary~=2.9.6", "pycryptodome==3.23.0", "pydantic~=2.12.5", - "pydantic-extra-types~=2.11.0", "pydantic-settings~=2.13.1", "pyjwt~=2.12.0", "pypdfium2==5.6.0", @@ -81,7 +80,6 @@ dependencies = [ "unstructured[docx,epub,md,ppt,pptx]~=0.21.5", "pypandoc~=1.13", "yarl~=1.23.0", - "webvtt-py~=0.5.1", "sseclient-py~=1.9.0", "httpx-sse~=0.4.0", "sendgrid~=6.12.3", @@ -130,7 +128,6 @@ dev = [ "types-defusedxml~=0.7.0", "types-deprecated~=1.3.1", "types-docutils~=0.22.3", - "types-jsonschema~=4.26.0", "types-flask-cors~=6.0.0", "types-flask-migrate~=4.1.0", "types-gevent~=25.9.0", diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index cf002df2a9..43f604c2de 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -118,34 +118,7 @@ enterprise/telemetry/exporter.py enterprise/telemetry/id_generator.py enterprise/telemetry/metric_handler.py enterprise/telemetry/telemetry_log.py -graphon/entities/workflow_execution.py -graphon/file/file_manager.py -graphon/graph_engine/error_handler.py -graphon/graph_engine/layers/execution_limits.py -graphon/nodes/agent/agent_node.py -graphon/nodes/base/node.py -graphon/nodes/code/code_node.py -graphon/nodes/datasource/datasource_node.py -graphon/nodes/document_extractor/node.py -graphon/nodes/human_input/human_input_node.py -graphon/nodes/if_else/if_else_node.py -graphon/nodes/iteration/iteration_node.py -graphon/nodes/knowledge_index/knowledge_index_node.py core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py -graphon/nodes/list_operator/node.py -graphon/nodes/llm/node.py -graphon/nodes/loop/loop_node.py -graphon/nodes/parameter_extractor/parameter_extractor_node.py -graphon/nodes/question_classifier/question_classifier_node.py -graphon/nodes/start/start_node.py -graphon/nodes/template_transform/template_transform_node.py -graphon/nodes/tool/tool_node.py -graphon/nodes/trigger_plugin/trigger_event_node.py -graphon/nodes/trigger_schedule/trigger_schedule_node.py -graphon/nodes/trigger_webhook/node.py -graphon/nodes/variable_aggregator/variable_aggregator_node.py -graphon/nodes/variable_assigner/v1/node.py -graphon/nodes/variable_assigner/v2/node.py extensions/logstore/repositories/logstore_api_workflow_run_repository.py extensions/otel/instrumentation.py extensions/otel/runtime.py diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index ffc17e92cf..1a2a539c80 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -38,11 +38,11 @@ from collections.abc import Callable, Sequence from datetime import datetime from typing import Protocol +from graphon.entities.pause_reason import PauseReason +from graphon.enums import WorkflowType from sqlalchemy.orm import Session from core.repositories.factory import WorkflowExecutionRepository -from graphon.entities.pause_reason import PauseReason -from graphon.enums import WorkflowType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 44735eb769..d5c6a203b1 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -10,11 +10,11 @@ from collections.abc import Sequence from datetime import datetime from typing import Protocol, cast +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload from repositories.api_workflow_node_execution_repository import ( DifyAPIWorkflowNodeExecutionRepository, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 5bb0c74ada..413936b542 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -28,15 +28,15 @@ from decimal import Decimal from typing import Any, cast import sqlalchemy as sa +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.nodes.human_input.entities import FormDefinition from pydantic import ValidationError from sqlalchemy import and_, delete, func, null, or_, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage -from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from graphon.enums import WorkflowExecutionStatus, WorkflowType -from graphon.nodes.human_input.entities import FormDefinition from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date from libs.infinite_scroll_pagination import InfiniteScrollPagination diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py index 67f8795d3f..feba5f7eb6 100644 --- a/api/repositories/sqlalchemy_execution_extra_content_repository.py +++ b/api/repositories/sqlalchemy_execution_extra_content_repository.py @@ -7,6 +7,9 @@ from collections import defaultdict from collections.abc import Sequence from typing import Any +from graphon.nodes.human_input.entities import FormDefinition +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode from sqlalchemy import select from sqlalchemy.orm import Session, selectinload, sessionmaker @@ -18,9 +21,6 @@ from core.entities.execution_extra_content import ( from core.entities.execution_extra_content import ( HumanInputContent as HumanInputContentDomainModel, ) -from graphon.nodes.human_input.entities import FormDefinition -from graphon.nodes.human_input.enums import HumanInputFormStatus -from graphon.nodes.human_input.human_input_node import HumanInputNode from models.execution_extra_content import ( ExecutionExtraContent as ExecutionExtraContentModel, ) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 643a2a2a84..dd73e10374 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -11,6 +11,12 @@ from uuid import uuid4 import yaml from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from packaging import version from packaging.version import parse as parse_version from pydantic import BaseModel, Field @@ -30,12 +36,6 @@ from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerSc from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_redis import redis_client from factories import variable_factory -from graphon.enums import BuiltinNodeTypes -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData -from graphon.nodes.tool.entities import ToolNodeData from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode from models.model import AppModelConfig, AppModelConfigDict, IconType diff --git a/api/services/app_service.py b/api/services/app_service.py index 9413a93fc4..e9aeb6c43d 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -4,6 +4,8 @@ from typing import Any, TypedDict, cast import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from configs import dify_config from constants.model_template import default_app_templates @@ -14,8 +16,6 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created, app_was_deleted, app_was_updated from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account diff --git a/api/services/app_task_service.py b/api/services/app_task_service.py index 6e9d6b1c73..0842e9d3e7 100644 --- a/api/services/app_task_service.py +++ b/api/services/app_task_service.py @@ -5,10 +5,11 @@ like stopping tasks, handling both legacy Redis flag mechanism and new GraphEngine command channel mechanism. """ +from graphon.graph_engine.manager import GraphEngineManager + from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_redis import redis_client -from graphon.graph_engine.manager import GraphEngineManager from models.model import AppMode diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 9e743bf7b1..90e72d5f34 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -5,12 +5,12 @@ from collections.abc import Generator from typing import cast from flask import Response, stream_with_context +from graphon.model_runtime.entities.model_entities import ModelType from werkzeug.datastructures import FileStorage from constants import AUDIO_EXTENSIONS from core.model_manager import ModelManager from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models.enums import MessageStatus from models.model import App, AppMode, Message from services.errors.audio import ( diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index c6b32b373e..1c128524ad 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app +from graphon.model_runtime.utils.encoders import jsonable_encoder from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -13,7 +14,6 @@ from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_storage import storage -from graphon.model_runtime.utils.encoders import jsonable_encoder from models.account import Tenant from models.model import ( App, diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 545c5048d5..ba1e7bb826 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -3,6 +3,7 @@ import logging from collections.abc import Callable, Sequence from typing import Any, Union +from graphon.variables.types import SegmentType from sqlalchemy import asc, desc, func, or_, select from sqlalchemy.orm import Session @@ -12,7 +13,6 @@ from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db from factories import variable_factory -from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index 287d513f48..95a8951951 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ +from graphon.variables.variables import VariableBase from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker -from graphon.variables.variables import VariableBase from models import ConversationVariable diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 3e2342b1a7..83363125c3 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -10,6 +10,9 @@ from collections.abc import Sequence from typing import Any, Literal, cast import sqlalchemy as sa +from graphon.file import helpers as file_helpers +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from redis.exceptions import LockNotOwnedError from sqlalchemy import exists, func, select from sqlalchemy.orm import Session @@ -28,9 +31,6 @@ from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.file import helpers as file_helpers -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 2b7bebb01e..06f83a18f7 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -3,6 +3,7 @@ import time from collections.abc import Mapping from typing import Any +from graphon.model_runtime.entities.provider_entities import FormType from sqlalchemy.orm import Session from configs import dify_config @@ -16,7 +17,6 @@ from core.plugin.impl.oauth import OAuthHandler from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.provider_entities import FormType from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.provider_ids import DatasourceProviderID from services.plugin.plugin_service import PluginService diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 6679c08ebd..a944ef6acd 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,6 +1,15 @@ from collections.abc import Sequence from enum import StrEnum +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderHelpEntity, + SimpleProviderEntity, +) from pydantic import BaseModel, ConfigDict, model_validator from configs import dify_config @@ -15,15 +24,6 @@ from core.entities.provider_entities import ( QuotaConfiguration, UnaddedModelConfiguration, ) -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderHelpEntity, - SimpleProviderEntity, -) from models.provider import ProviderType diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index d2fa98f5e2..64852c222f 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -4,13 +4,13 @@ from typing import Any, Union, cast from urllib.parse import urlparse import httpx +from graphon.nodes.http_request.exc import InvalidHttpMethodError from sqlalchemy import select from constants import HIDDEN_VALUE from core.helper import ssrf_proxy from core.rag.entities.metadata_entities import MetadataCondition from extensions.ext_database import db -from graphon.nodes.http_request.exc import InvalidHttpMethodError from libs.datetime_utils import naive_utc_now from models.dataset import ( Dataset, diff --git a/api/services/file_service.py b/api/services/file_service.py index c11f018f52..50a326d813 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -8,6 +8,7 @@ from tempfile import NamedTemporaryFile from typing import Literal, Union from zipfile import ZIP_DEFLATED, ZipFile +from graphon.file import helpers as file_helpers from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import NotFound @@ -23,7 +24,6 @@ from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType -from graphon.file import helpers as file_helpers from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id from models import Account diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index d490ad1561..82e0b0f8b1 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -3,6 +3,8 @@ import logging import time from typing import Any +from graphon.model_runtime.entities import LLMMode + from core.app.app_config.entities import ModelConfig from core.rag.datasource.retrieval_service import RetrievalService from core.rag.index_processor.constant.query_type import QueryType @@ -10,7 +12,6 @@ from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db -from graphon.model_runtime.entities import LLMMode from models import Account from models.dataset import Dataset, DatasetQuery from models.enums import CreatorUserRole, DatasetQuerySource diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 68ef67dec1..77576fa4c0 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, field from enum import StrEnum from typing import Protocol +from graphon.runtime import VariablePool from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker @@ -17,7 +18,6 @@ from core.workflow.human_input_compat import ( ) from extensions.ext_database import db from extensions.ext_mail import mail -from graphon.runtime import VariablePool from libs.email_template_renderer import render_email_template from models import Account, TenantAccountJoin from services.feature_service import FeatureService diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 76598d31ac..02a6620fc7 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -3,6 +3,12 @@ from collections.abc import Mapping from datetime import datetime, timedelta from typing import Any +from graphon.nodes.human_input.entities import ( + FormDefinition, + HumanInputSubmissionValidationError, + validate_human_input_submission, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -11,12 +17,6 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) -from graphon.nodes.human_input.entities import ( - FormDefinition, - HumanInputSubmissionValidationError, - validate_human_input_submission, -) -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.exception import BaseHTTPException from models.human_input import RecipientType diff --git a/api/services/message_service.py b/api/services/message_service.py index 0c4a334b47..e5389ef659 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -2,6 +2,7 @@ import json from collections.abc import Sequence from typing import Union +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import sessionmaker from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager @@ -13,7 +14,6 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.enums import FeedbackFromSource, FeedbackRating diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 469357d6e0..91cca5cb6d 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -3,6 +3,12 @@ import logging from json import JSONDecodeError from typing import Union +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ModelCredentialSchema, + ProviderCredentialSchema, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from sqlalchemy import or_, select from constants import HIDDEN_VALUE @@ -13,12 +19,6 @@ from core.model_manager import LBModelManager from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.provider_manager import ProviderManager from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ModelCredentialSchema, - ProviderCredentialSchema, -) -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.enums import CredentialSourceType from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index e634f90603..3f37c9b176 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,9 +1,10 @@ import logging +from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule + from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager from core.provider_manager import ProviderManager -from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule from models.provider import ProviderType from services.entities.model_provider_entities import ( CustomConfigurationResponse, diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 46a6221fcc..bcf5973d7b 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -9,6 +9,15 @@ from typing import Any, Union, cast from uuid import uuid4 from flask_login import current_user +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from graphon.errors import WorkflowNodeRunFailedError +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable, VariableBase from sqlalchemy import func, select from sqlalchemy.orm import Session, sessionmaker @@ -48,19 +57,6 @@ from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace from extensions.ext_database import db -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType -from graphon.errors import WorkflowNodeRunFailedError -from graphon.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent -from graphon.graph_events.base import GraphNodeEventBase -from graphon.node_events.base import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from graphon.runtime import VariablePool -from graphon.variables.variables import Variable, VariableBase from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.dataset import ( # type: ignore diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 1b8207cc31..04156713f4 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -14,6 +14,12 @@ import yaml # type: ignore from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad from flask_login import current_user +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from packaging import version from pydantic import BaseModel, Field from sqlalchemy import select @@ -28,12 +34,6 @@ from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from extensions.ext_redis import redis_client from factories import variable_factory -from graphon.enums import BuiltinNodeTypes -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData -from graphon.nodes.tool.entities import ToolNodeData from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline from models.enums import CollectionBindingType, DatasetRuntimeMode diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py index c91f621ffb..2c1f99a3bc 100644 --- a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py +++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -27,13 +27,13 @@ from dataclasses import dataclass, field from typing import Any import click +from graphon.enums import WorkflowType from sqlalchemy import inspect from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db -from graphon.enums import WorkflowType from libs.archive_storage import ( ArchiveStorage, ArchiveStorageNotConfiguredError, diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 4334412c8b..12053377e2 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -6,6 +6,8 @@ import uuid from datetime import UTC, datetime from typing import Any +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.db.session_factory import session_factory @@ -15,8 +17,6 @@ from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.models.document import Document -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 9190a67249..2a56bc0c71 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -2,6 +2,7 @@ import json import logging from typing import Any, cast +from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx import get from sqlalchemy import select from typing_extensions import TypedDict @@ -21,7 +22,6 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db -from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider from services.tools.tools_transform_service import ToolTransformService diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 931ca5021a..fb6b5bea24 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -2,6 +2,7 @@ import json import logging from datetime import datetime +from graphon.model_runtime.utils.encoders import jsonable_encoder from sqlalchemy import or_, select from sqlalchemy.orm import Session @@ -13,7 +14,6 @@ from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurati from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db -from graphon.model_runtime.utils.encoders import jsonable_encoder from models.model import App from models.tools import WorkflowToolProvider from models.workflow import Workflow diff --git a/api/services/trigger/schedule_service.py b/api/services/trigger/schedule_service.py index a827222c1d..25e80770b8 100644 --- a/api/services/trigger/schedule_service.py +++ b/api/services/trigger/schedule_service.py @@ -2,6 +2,7 @@ import json import logging from datetime import datetime +from graphon.entities.graph_config import NodeConfigDict from sqlalchemy import select from sqlalchemy.orm import Session @@ -13,7 +14,6 @@ from core.workflow.nodes.trigger_schedule.entities import ( VisualConfig, ) from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError -from graphon.entities.graph_config import NodeConfigDict from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h from models.account import Account, TenantAccountJoin from models.trigger import WorkflowSchedulePlan diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index dca00a466b..d72c041609 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -5,6 +5,7 @@ from collections.abc import Mapping from typing import Any from flask import Request, Response +from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -20,7 +21,6 @@ from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_ from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.entities.graph_config import NodeConfigDict from models.model import App from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription, WorkflowPluginTrigger diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 5d9be84c06..c03275497d 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -7,6 +7,9 @@ from typing import Any import orjson from flask import request +from graphon.entities.graph_config import NodeConfigDict +from graphon.file import FileTransferMethod +from graphon.variables.types import ArrayValidation, SegmentType from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -28,9 +31,6 @@ from enums.quota_type import QuotaType from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory -from graphon.entities.graph_config import NodeConfigDict -from graphon.file.models import FileTransferMethod -from graphon.variables.types import ArrayValidation, SegmentType from models.enums import AppTriggerStatus, AppTriggerType from models.model import App from models.trigger import AppTrigger, WorkflowWebhookTrigger diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index d0a4317065..62916cc2c9 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -5,8 +5,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from typing import Any, Generic, TypeAlias, TypeVar, overload -from configs import dify_config -from graphon.file.models import File +from graphon.file import File from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable from graphon.variables.segments import ( ArrayFileSegment, @@ -22,6 +21,8 @@ from graphon.variables.segments import ( ) from graphon.variables.utils import dumps_with_segments +from configs import dify_config + _MAX_DEPTH = 100 diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 5fd310b689..3f78b823a6 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,5 +1,7 @@ import logging +from graphon.model_runtime.entities.model_entities import ModelType + from core.model_manager import ModelInstance, ModelManager from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector @@ -9,7 +11,6 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models import UploadFile from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 1f3993505c..31367f72fa 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,6 +1,11 @@ import json from typing import Any +from graphon.file import FileUploadConfig +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.input_entities import VariableEntity from typing_extensions import TypedDict from core.app.app_config.entities import ( @@ -19,11 +24,6 @@ from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.app_event import app_was_created from extensions.ext_database import db -from graphon.file.models import FileUploadConfig -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes import BuiltinNodeTypes -from graphon.variables.input_entities import VariableEntity from models import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig, IconType diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index fa26f507ee..bf178e8a44 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -3,11 +3,11 @@ import uuid from datetime import datetime from typing import Any +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session from typing_extensions import TypedDict -from graphon.enums import WorkflowExecutionStatus from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 0b5c89e574..98e338a2d4 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -6,6 +6,19 @@ from concurrent.futures import ThreadPoolExecutor from enum import StrEnum from typing import Any, ClassVar +from graphon.enums import NodeType +from graphon.file import File +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.variable_assigner.common.helpers import get_updated_variables +from graphon.variable_loader import VariableLoader +from graphon.variables import Segment, StringSegment, VariableBase +from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.segments import ( + ArrayFileSegment, + FileSegment, +) +from graphon.variables.types import SegmentType +from graphon.variables.utils import dumps_with_segments from sqlalchemy import Engine, orm, select from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -26,19 +39,6 @@ from core.workflow.variable_prefixes import ( from extensions.ext_storage import storage from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable -from graphon.enums import NodeType -from graphon.file.models import File -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.variable_assigner.common.helpers import get_updated_variables -from graphon.variable_loader import VariableLoader -from graphon.variables import Segment, StringSegment, VariableBase -from graphon.variables.consts import SELECTORS_LENGTH -from graphon.variables.segments import ( - ArrayFileSegment, - FileSegment, -) -from graphon.variables.types import SegmentType -from graphon.variables.utils import dumps_with_segments from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models import Account, App, Conversation diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py index 5fca444723..601e9261fc 100644 --- a/api/services/workflow_event_snapshot_service.py +++ b/api/services/workflow_event_snapshot_service.py @@ -9,6 +9,10 @@ from collections.abc import Generator, Mapping, Sequence from dataclasses import dataclass from typing import Any +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import desc, select from sqlalchemy.orm import Session, sessionmaker @@ -22,10 +26,6 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from graphon.entities import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models.model import AppMode, Message from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index bef99458be..b555676704 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -5,6 +5,31 @@ import uuid from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, cast +from graphon.entities import GraphInitParams, WorkflowNodeExecution +from graphon.entities.graph_config import NodeConfigDict +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import ( + ErrorStrategy, + NodeType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.nodes.human_input.entities import HumanInputNodeData, validate_human_input_submission +from graphon.nodes.human_input.enums import HumanInputFormKind +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import load_into_variable_pool +from graphon.variables import VariableBase +from graphon.variables.input_entities import VariableEntityType +from graphon.variables.variables import Variable from sqlalchemy import exists, select from sqlalchemy.orm import Session, sessionmaker @@ -33,31 +58,6 @@ from events.app_event import app_draft_workflow_was_synced, app_published_workfl from extensions.ext_database import db from extensions.ext_storage import storage from factories.file_factory import build_from_mapping, build_from_mappings -from graphon.entities import GraphInitParams, WorkflowNodeExecution -from graphon.entities.graph_config import NodeConfigDict -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import ( - ErrorStrategy, - NodeType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.errors import WorkflowNodeRunFailedError -from graphon.file import File -from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent -from graphon.node_events import NodeRunResult -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.base.node import Node -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from graphon.nodes.human_input.entities import HumanInputNodeData, validate_human_input_submission -from graphon.nodes.human_input.enums import HumanInputFormKind -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.nodes.start.entities import StartNodeData -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import load_into_variable_pool -from graphon.variables import VariableBase -from graphon.variables.input_entities import VariableEntityType -from graphon.variables.variables import Variable from libs.datetime_utils import naive_utc_now from models import Account from models.human_input import HumanInputFormRecipient, RecipientType diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index 458099d99e..489467651d 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -7,6 +7,7 @@ from typing import Annotated, Any, TypeAlias, Union from celery import shared_task from flask import current_app, json +from graphon.runtime import GraphRuntimeState from pydantic import BaseModel, Discriminator, Field, Tag from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -22,7 +23,6 @@ from core.app.entities.app_invoke_entities import ( from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.runtime import GraphRuntimeState from libs.flask_utils import set_login_user from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index 6365400dd1..0a73c91279 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -10,6 +10,7 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task +from graphon.runtime import GraphRuntimeState from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -22,7 +23,6 @@ from core.app.layers.trigger_post_layer import TriggerPostLayer from core.db.session_factory import session_factory from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.runtime import GraphRuntimeState from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus from models.model import App, EndUser, Tenant diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index ed8a24b336..20335d9b9f 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -7,6 +7,7 @@ from pathlib import Path import click import pandas as pd from celery import shared_task +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import func from core.db.session_factory import session_factory @@ -14,7 +15,6 @@ from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from extensions.ext_storage import storage -from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py index fd743205a1..ca73b4d374 100644 --- a/api/tasks/human_input_timeout_tasks.py +++ b/api/tasks/human_input_timeout_tasks.py @@ -2,6 +2,8 @@ import logging from datetime import timedelta from celery import shared_task +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import or_, select from sqlalchemy.orm import sessionmaker @@ -9,8 +11,6 @@ from configs import dify_config from core.repositories.human_input_repository import HumanInputFormSubmissionRepository from extensions.ext_database import db from extensions.ext_storage import storage -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from models.human_input import HumanInputForm from models.workflow import WorkflowPause, WorkflowRun diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index f8ae3f4b6e..a316eec7b9 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -6,6 +6,7 @@ from typing import Any import click from celery import shared_task +from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -14,7 +15,6 @@ from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod from extensions.ext_database import db from extensions.ext_mail import mail -from graphon.runtime import GraphRuntimeState, VariablePool from models.human_input import ( DeliveryMethodType, HumanInputDelivery, diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 25ea53dfac..56626e372e 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -12,6 +12,7 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -28,7 +29,6 @@ from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from enums.quota_type import QuotaType, unlimited -from graphon.enums import WorkflowExecutionStatus from models.enums import ( AppTriggerType, CreatorUserRole, diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index ae1c2991c9..0c7f74c180 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -9,11 +9,11 @@ import json import logging from celery import shared_task +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from core.db.session_factory import session_factory -from graphon.entities.workflow_execution import WorkflowExecution -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index b823ce3961..f25ebe3bae 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -9,13 +9,13 @@ import json import logging from celery import shared_task -from sqlalchemy import select - -from core.db.session_factory import session_factory from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, ) from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter +from sqlalchemy import select + +from core.db.session_factory import session_factory from models import CreatorUserRole, WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py index a876b0c4aa..91245e879e 100644 --- a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py +++ b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py @@ -1,8 +1,9 @@ from collections.abc import Generator +from graphon.node_events import StreamCompletedEvent + from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage -from graphon.node_events import StreamCompletedEvent def _gen_var_stream() -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index b2de11b068..3fdea10976 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,7 +1,8 @@ +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamCompletedEvent + from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult, StreamCompletedEvent class _Seg: diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py index 878d9b24df..c1bb8e1245 100644 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -4,13 +4,13 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from graphon.file import File, FileTransferMethod, FileType from sqlalchemy.orm import Session from core.app.file_access import DatabaseFileAccessController from extensions.ext_database import db from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader -from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile from models.enums import CreatorUserRole diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index c4146d5ccd..ce04a158a8 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -4,9 +4,6 @@ from collections.abc import Generator, Sequence from decimal import Decimal from json import dumps -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from core.plugin.impl.model import PluginModelClient - # import monkeypatch from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.llm_entities import ( @@ -26,6 +23,9 @@ from graphon.model_runtime.entities.model_entities import ( ) from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.model import PluginModelClient + class MockModelClass(PluginModelClient): def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]: diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index 0b21ff1d2a..5c6636f31e 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -3,6 +3,10 @@ import unittest import uuid import pytest +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType +from graphon.variables.variables import StringVariable from sqlalchemy import delete from sqlalchemy.orm import Session @@ -11,10 +15,6 @@ from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType from factories.variable_factory import build_segment -from graphon.nodes import BuiltinNodeTypes -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType -from graphon.variables.variables import StringVariable from libs import datetime_utils from models.enums import CreatorUserRole from models.model import UploadFile diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index f6f4cf260b..38dc8bbb28 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,11 +2,11 @@ import uuid from unittest.mock import patch import pytest +from graphon.variables.segments import StringSegment from sqlalchemy import delete from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType -from graphon.variables.segments import StringSegment from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile @@ -193,6 +193,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def setup_offload_test_data(self, app_and_tenant): tenant, app = app_and_tenant from graphon.variables.types import SegmentType + from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: @@ -424,6 +425,7 @@ class TestDeleteDraftVariablesSessionCommit: def setup_offload_test_data(self, app_and_tenant): """Create test data with offload files for session commit tests.""" from graphon.variables.types import SegmentType + from libs.datetime_utils import naive_utc_now tenant, app = app_and_tenant diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py index a9a2617bae..c0143faa85 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -1,11 +1,12 @@ from unittest.mock import MagicMock +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory -from graphon.model_runtime.entities.model_entities import ModelType from models.provider import ProviderType diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 7573e00872..ce0c8bf8ca 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -2,17 +2,17 @@ import time import uuid import pytest - -from configs import dify_config -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.node_events import NodeRunResult from graphon.nodes.code.code_node import CodeNode from graphon.nodes.code.limits import CodeNodeLimits from graphon.runtime import GraphRuntimeState, VariablePool + +from configs import dify_config +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock from tests.workflow_test_utils import build_test_graph_init_params @@ -172,7 +172,7 @@ def test_execute_code_output_validator(setup_code_executor_mock): result = node._run() assert isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.FAILED - assert result.error == "Output result must be a string, got int instead" + assert result.error == "Output result must be a string, got int instead." def test_execute_code_output_validator_depth(): diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 17ea7de881..ce18486faf 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -3,6 +3,11 @@ import uuid from urllib.parse import urlencode import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.graph import Graph +from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig +from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom @@ -11,11 +16,6 @@ from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory from core.workflow.node_runtime import DifyFileReferenceFactory from core.workflow.system_variables import build_system_variables -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file.file_manager import file_manager -from graphon.graph import Graph -from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig -from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock from tests.workflow_test_utils import build_test_graph_init_params @@ -191,7 +191,6 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" - from core.workflow.system_variables import build_system_variables from graphon.enums import BuiltinNodeTypes from graphon.nodes.http_request.entities import ( HttpRequestNodeAuthorization, @@ -202,6 +201,8 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): from graphon.nodes.http_request.executor import Executor from graphon.runtime import VariablePool + from core.workflow.system_variables import build_system_variables + # Create variable pool variable_pool = VariablePool( system_variables=build_system_variables(user_id="test", files=[]), diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index fa5d63cfbf..f0f3fcead1 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -4,11 +4,6 @@ import uuid from collections.abc import Generator from unittest.mock import MagicMock, patch -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.llm_generator.output_parser.structured_output import _parse_structured_output -from core.model_manager import ModelInstance -from core.workflow.system_variables import build_system_variables -from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.node_events import StreamCompletedEvent from graphon.nodes.llm.file_saver import LLMFileSaver @@ -17,6 +12,12 @@ from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from graphon.nodes.protocols import HttpClientProtocol from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.llm_generator.output_parser.structured_output import _parse_structured_output +from core.model_manager import ModelInstance +from core.workflow.system_variables import build_system_variables +from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params """FOR MOCK FIXTURES, DO NOT REMOVE""" diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 367b5bbc11..3bf44df349 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -3,16 +3,17 @@ import time import uuid from unittest.mock import MagicMock +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_runtime import DifyPromptMessageSerializer from core.workflow.system_variables import build_system_variables from extensions.ext_database import db -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage -from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory -from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 9e3e1a47e3..2d728569be 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,14 +1,15 @@ import time import uuid -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode from graphon.runtime import GraphRuntimeState, VariablePool from graphon.template_rendering import TemplateRenderError + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index f9ec51ee10..750ced7075 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -2,17 +2,18 @@ import time import uuid from unittest.mock import MagicMock, patch -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.tools.utils.configuration import ToolParameterConfigurationManager -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.node_runtime import DifyToolNodeRuntime -from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.node_events import StreamCompletedEvent from graphon.nodes.protocols import ToolFileManagerProtocol from graphon.nodes.tool.tool_node import ToolNode from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.tools.utils.configuration import ToolParameterConfigurationManager +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py index 5b51510388..5cc458fe2e 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py @@ -4,11 +4,11 @@ import json import uuid from flask.testing import FlaskClient +from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session from configs import dify_config from constants import HEADER_NAME_CSRF_TOKEN -from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token from models import Account, DifySetup, Tenant, TenantAccountJoin diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py index 290be87697..8ddf867370 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py @@ -3,12 +3,12 @@ import uuid from flask.testing import FlaskClient +from graphon.variables.segments import StringSegment from sqlalchemy import select from sqlalchemy.orm import Session from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from factories.variable_factory import segment_to_variable -from graphon.variables.segments import StringSegment from models import Workflow from models.model import AppMode from models.workflow import WorkflowDraftVariable diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index b8840c4ba8..2b4c1b59ab 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -22,6 +22,13 @@ import uuid from time import time import pytest +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events import GraphRunPausedEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session @@ -33,16 +40,6 @@ from core.app.layers.pause_state_persist_layer import ( ) from core.workflow.system_variables import build_system_variables from extensions.ext_storage import storage -from graphon.entities.pause_reason import SchedulingPause -from graphon.enums import WorkflowExecutionStatus -from graphon.graph_engine.entities.commands import GraphEngineCommand -from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from graphon.graph_events.graph import GraphRunPausedEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime.graph_runtime_state import GraphRuntimeState -from graphon.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from graphon.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper -from graphon.runtime.variable_pool import VariablePool from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel @@ -545,7 +542,7 @@ class TestPauseStatePersistenceLayerTestContainers: layer.initialize(graph_runtime_state, command_channel) # Import other event types - from graphon.graph_events.graph import ( + from graphon.graph_events import ( GraphRunFailedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index e0c58f0f5c..13caad799e 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -4,6 +4,7 @@ from __future__ import annotations from uuid import uuid4 +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction from sqlalchemy import Engine, select from sqlalchemy.orm import Session @@ -17,7 +18,6 @@ from core.workflow.human_input_compat import ( MemberRecipient, WebAppDeliveryMethod, ) -from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.human_input import ( EmailExternalRecipientPayload, diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index ae8c0716a4..0a9b476afc 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -4,6 +4,18 @@ from datetime import timedelta from unittest.mock import MagicMock import pytest +from graphon.enums import WorkflowType +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import delete, select from sqlalchemy.orm import Session @@ -15,18 +27,6 @@ from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchem from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.node_runtime import DifyHumanInputNodeRuntime from core.workflow.system_variables import build_system_variables -from graphon.enums import WorkflowType -from graphon.graph import Graph -from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from graphon.graph_engine.graph_engine import GraphEngine -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.enums import HumanInputFormStatus -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models import Account from models.account import Tenant, TenantAccountJoin, TenantAccountRole diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index 2e207ddc67..cc72dc1cf3 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -4,13 +4,13 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from graphon.file import File, FileTransferMethod, FileType from sqlalchemy.orm import Session from core.app.file_access import DatabaseFileAccessController from extensions.ext_database import db from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader -from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile from models.enums import CreatorUserRole diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py index 2fd289dfbc..b745aed141 100644 --- a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py +++ b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py @@ -6,6 +6,7 @@ from decimal import Decimal from uuid import uuid4 from graphon.nodes.human_input.entities import FormDefinition, UserAction + from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant, TenantAccountJoin from models.enums import ConversationFromSource, InvokeFrom diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py index 641399c7f9..a68b3a08c7 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py @@ -5,10 +5,10 @@ from __future__ import annotations from datetime import timedelta from uuid import uuid4 +from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, delete from sqlalchemy.orm import Session, sessionmaker -from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index cb00752b35..d28cfda159 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -8,15 +8,15 @@ from unittest.mock import Mock from uuid import uuid4 import pytest -from sqlalchemy import Engine, delete, select -from sqlalchemy.orm import Session, sessionmaker - -from extensions.ext_storage import storage from graphon.entities import WorkflowExecution from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType from graphon.enums import WorkflowExecutionStatus from graphon.nodes.human_input.entities import FormDefinition, FormInput, UserAction from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus +from sqlalchemy import Engine, delete, select +from sqlalchemy.orm import Session, sessionmaker + +from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import ( diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py index aaf9a85d60..7f44eb6ca3 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -12,11 +12,11 @@ from decimal import Decimal from uuid import uuid4 import pytest +from graphon.nodes.human_input.entities import FormDefinition, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session, sessionmaker -from graphon.nodes.human_input.entities import FormDefinition, UserAction -from graphon.nodes.human_input.enums import HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import ConversationFromSource, InvokeFrom diff --git a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py index d6f0657380..c5e9201ee3 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py @@ -7,12 +7,12 @@ from datetime import timedelta from uuid import uuid4 import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import Engine, delete from sqlalchemy import exc as sa_exc from sqlalchemy.orm import Session, sessionmaker -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index 00a2f9a59f..4f3c0e4200 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -842,6 +842,7 @@ class TestAgentService: conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) from graphon.file import FileTransferMethod, FileType + from models.enums import CreatorUserRole # Add files to message diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py index 02ab3f8314..fb0adbbcc2 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py @@ -3,10 +3,10 @@ from uuid import uuid4 import pytest +from graphon.variables import StringVariable from sqlalchemy.orm import sessionmaker from extensions.ext_database import db -from graphon.variables import StringVariable from models.workflow import ConversationVariable from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index 0de3c64c4f..f9bfa570cb 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -9,11 +9,11 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from graphon.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index 883c3c3feb..a814466e14 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -2,10 +2,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexTechniqueType -from graphon.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, ExternalKnowledgeBindings from models.enums import DataSourceType diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index fe426ae516..c8f04e9215 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -5,9 +5,9 @@ Testcontainers integration tests for archived workflow run deletion service. from datetime import UTC, datetime, timedelta from uuid import uuid4 +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import select -from graphon.enums import WorkflowExecutionStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowArchiveLog, WorkflowRun from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index 18c5320d0a..c46b8fba0b 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -3,6 +3,8 @@ import uuid from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from core.workflow.human_input_compat import ( EmailDeliveryConfig, @@ -10,8 +12,6 @@ from core.workflow.human_input_compat import ( EmailRecipients, ExternalRecipient, ) -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.human_input.entities import HumanInputNodeData from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.model import App, AppMode from models.workflow import Workflow, WorkflowType diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py index 21a54e909e..0f252515f7 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from graphon.runtime import VariablePool from sqlalchemy.engine import Engine from configs import dify_config @@ -15,7 +16,6 @@ from core.workflow.human_input_compat import ( ExternalRecipient, MemberRecipient, ) -from graphon.runtime import VariablePool from models.account import Account, TenantAccountJoin from services import human_input_delivery_test_service as service_module from services.human_input_delivery_test_service import ( diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index c0c1c25f1e..9528257963 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -6,11 +6,11 @@ from unittest.mock import patch import pytest from faker import Faker +from graphon.file import FileType from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client -from graphon.file.enums import FileType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import ( ConversationFromSource, diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 8955a3b5f2..ba926bf675 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -2,10 +2,10 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from sqlalchemy.orm import Session from core.entities.model_entities import ModelStatus -from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType from services.model_provider_service import ModelProviderService @@ -405,10 +405,11 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock models - from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.provider_entities import ProviderEntity + from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity + # Create real model objects instead of mocks provider_entity_1 = SimpleModelProviderEntity( ProviderEntity( @@ -643,9 +644,10 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock default model response - from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from graphon.model_runtime.entities.common_entities import I18nObject + from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity + mock_default_model = DefaultModelEntity( model="gpt-3.5-turbo", model_type=ModelType.LLM, diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 2a18345c87..749c6fff5b 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -8,9 +8,9 @@ from unittest.mock import patch import pytest from faker import Faker +from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session -from graphon.entities.workflow_execution import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLogCreatedFrom diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 86cf2327c7..0c281c8c33 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -1,9 +1,9 @@ import pytest from faker import Faker +from graphon.variables.segments import StringSegment from sqlalchemy.orm import Session from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from graphon.variables.segments import StringSegment from models import App, Workflow from models.enums import DraftVariableType from models.workflow import WorkflowDraftVariable diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index ce5c2bd162..ce2fd2eeb1 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -5,6 +5,9 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.variables.input_entities import VariableEntity, VariableEntityType from sqlalchemy.orm import Session from core.app.app_config.entities import ( @@ -18,9 +21,6 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.variables.input_entities import VariableEntity, VariableEntityType from models import Account, Tenant from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index 4dab895135..7c43bf676b 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -1,10 +1,10 @@ from datetime import datetime, timedelta from uuid import uuid4 +from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker -from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index d341c5ce99..a16f3ff773 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -3,6 +3,9 @@ from datetime import UTC, datetime from unittest.mock import patch import pytest +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.app_config.entities import WorkflowUIBasedAppConfig @@ -17,9 +20,6 @@ from core.workflow.human_input_compat import ( MemberRecipient, ) from extensions.ext_storage import storage -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.entities import HumanInputNodeData -from graphon.runtime import GraphRuntimeState, VariablePool from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index 9a7507a2f9..96cf9cebf5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,11 +2,11 @@ import uuid from unittest.mock import ANY, call, patch import pytest +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from models import Tenant from models.enums import CreatorUserRole diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index b9f513a6d0..159ab51304 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -24,12 +24,12 @@ from dataclasses import dataclass from datetime import timedelta import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import delete, select from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index 8854ef5e04..7539bae685 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -10,6 +10,7 @@ from typing import Any import pytest from flask import Flask, Response from flask.testing import FlaskClient +from graphon.enums import BuiltinNodeTypes from sqlalchemy.orm import Session from configs import dify_config @@ -23,7 +24,6 @@ from core.trigger.debug import event_selectors from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key -from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus diff --git a/api/tests/unit_tests/controllers/console/app/test_audio.py b/api/tests/unit_tests/controllers/console/app/test_audio.py index 2d218dac7e..c52bc02420 100644 --- a/api/tests/unit_tests/controllers/console/app/test_audio.py +++ b/api/tests/unit_tests/controllers/console/app/test_audio.py @@ -4,6 +4,7 @@ import io from types import SimpleNamespace import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -20,7 +21,6 @@ from controllers.console.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index 341efc05ca..3607636880 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -5,12 +5,11 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest +from graphon.file import File, FileTransferMethod, FileType from werkzeug.exceptions import HTTPException, NotFound from controllers.console.app import workflow as workflow_module from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File def _unwrap(func): diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index c4a8148446..e11102acb1 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -6,14 +6,14 @@ from unittest.mock import Mock import pytest from flask import Flask - -from controllers.console import wraps as console_wraps -from controllers.console.app import workflow_run as workflow_run_module -from controllers.web.error import NotFoundError from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import WorkflowExecutionStatus from graphon.nodes.human_input.entities import FormInput, UserAction from graphon.nodes.human_input.enums import FormInputType + +from controllers.console import wraps as console_wraps +from controllers.console.app import workflow_run as workflow_run_module +from controllers.web.error import NotFoundError from libs import login as login_lib from models.account import Account, AccountStatus, TenantAccountRole from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index 559b5fea09..740da1f1df 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch import pytest from flask_restx import marshal +from graphon.variables.types import SegmentType from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_FIELDS, @@ -15,7 +16,6 @@ from controllers.console.app.workflow_draft_variable import ( ) from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment -from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile @@ -310,8 +310,7 @@ def test_workflow_node_variables_fields(): def test_workflow_file_variable_with_signed_url(): """Test that File type variables include signed URLs in API responses.""" - from graphon.file.enums import FileTransferMethod, FileType - from graphon.file.models import File + from graphon.file import File, FileTransferMethod, FileType # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) test_file = File( @@ -367,8 +366,7 @@ def test_workflow_file_variable_with_signed_url(): def test_workflow_file_variable_remote_url(): """Test that File type variables with REMOTE_URL transfer method return the remote URL.""" - from graphon.file.enums import FileTransferMethod, FileType - from graphon.file.models import File + from graphon.file import File, FileTransferMethod, FileType # Create a File object with REMOTE_URL transfer method test_file = File( diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py index 5136922e88..9c9f8da87c 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from werkzeug.exceptions import Forbidden, NotFound from controllers.console import console_ns @@ -17,7 +18,6 @@ from controllers.console.datasets.rag_pipeline.datasource_auth import ( DatasourceUpdateProviderNameApi, ) from core.plugin.impl.oauth import OAuthHandler -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py index 63950736c5..6ef8ccfdbd 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Response +from graphon.variables.types import SegmentType from controllers.console import console_ns from controllers.console.app.error import DraftWorkflowNotExist @@ -15,7 +16,6 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable impor ) from controllers.web.error import InvalidArgumentError, NotFoundError from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from graphon.variables.types import SegmentType from models.account import Account diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py index e4acd91b76..710c9be684 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services @@ -20,7 +21,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from graphon.model_runtime.errors.invoke import InvokeError from models.account import Account from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService diff --git a/api/tests/unit_tests/controllers/console/explore/test_audio.py b/api/tests/unit_tests/controllers/console/explore/test_audio.py index b4b57022e2..66c9ba48c5 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_audio.py +++ b/api/tests/unit_tests/controllers/console/explore/test_audio.py @@ -2,6 +2,7 @@ from io import BytesIO from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError import controllers.console.explore.audio as audio_module @@ -19,7 +20,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from graphon.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py index 145cc9cdd7..2e4ca4f2a4 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError, NotFound import controllers.console.explore.message as module @@ -21,7 +22,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from graphon.model_runtime.errors.invoke import InvokeError from services.errors.conversation import ConversationNotExistsError from services.errors.message import ( FirstMessageNotExistsError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py index 03eadcdb4e..04beb31389 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_trial.py +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import controllers.console.explore.trial as module @@ -25,7 +26,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from graphon.model_runtime.errors.invoke import InvokeError from models import Account from models.account import TenantStatus from models.model import AppMode diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py index b2f949c6e2..9c42ee9529 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py @@ -11,10 +11,9 @@ from unittest.mock import MagicMock import pytest from flask import Flask from flask.views import MethodView -from werkzeug.exceptions import Forbidden - from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from werkzeug.exceptions import Forbidden if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py index 168479af1e..fb9eec98cb 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic_core import ValidationError from werkzeug.exceptions import Forbidden @@ -13,7 +14,6 @@ from controllers.console.workspace.model_providers import ( ModelProviderValidateApi, PreferredProviderTypeUpdateApi, ) -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError VALID_UUID = "123e4567-e89b-12d3-a456-426614174000" INVALID_UUID = "123" diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py index f0d32f81fb..c829327bc7 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_models.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py @@ -2,6 +2,8 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from controllers.console.workspace.models import ( DefaultModelApi, @@ -14,8 +16,6 @@ from controllers.console.workspace.models import ( ModelProviderModelParameterRuleApi, ModelProviderModelValidateApi, ) -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError def unwrap(func): diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index e81e612803..5a8cb4619f 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -13,6 +13,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -29,7 +30,6 @@ from controllers.service_api.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py index 3364c07e62..57681d8f5b 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py @@ -16,6 +16,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound @@ -34,7 +35,6 @@ from controllers.service_api.app.error import ( NotChatAppError, ) from core.errors.error import QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index 6543c27037..b1f036c6f3 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -19,6 +19,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.enums import WorkflowExecutionStatus from werkzeug.exceptions import BadRequest, NotFound from controllers.service_api.app.error import NotWorkflowAppError @@ -35,7 +36,6 @@ from controllers.service_api.app.workflow import ( WorkflowTaskStopApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError -from graphon.enums import WorkflowExecutionStatus from models.model import App, AppMode from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py index eda270258d..4b8e3a738c 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py @@ -1,8 +1,9 @@ from types import SimpleNamespace -from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField from graphon.enums import WorkflowExecutionStatus +from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField + def test_workflow_run_status_field_with_enum() -> None: field = WorkflowRunStatusField() diff --git a/api/tests/unit_tests/controllers/web/test_audio.py b/api/tests/unit_tests/controllers/web/test_audio.py index a6ca441801..cbfc8fa613 100644 --- a/api/tests/unit_tests/controllers/web/test_audio.py +++ b/api/tests/unit_tests/controllers/web/test_audio.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from graphon.model_runtime.errors.invoke import InvokeError from controllers.web.audio import AudioApi, TextApi from controllers.web.error import ( @@ -21,7 +22,6 @@ from controllers.web.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/web/test_completion.py b/api/tests/unit_tests/controllers/web/test_completion.py index 4f8d848637..49039d03fe 100644 --- a/api/tests/unit_tests/controllers/web/test_completion.py +++ b/api/tests/unit_tests/controllers/web/test_completion.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from graphon.model_runtime.errors.invoke import InvokeError from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi from controllers.web.error import ( @@ -18,7 +19,6 @@ from controllers.web.error import ( ProviderQuotaExceededError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError def _completion_app() -> SimpleNamespace: diff --git a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py index cde8820e00..bc7aea0ef9 100644 --- a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py @@ -2,11 +2,11 @@ import json from unittest.mock import MagicMock import pytest +from graphon.model_runtime.entities.llm_entities import LLMUsage from core.agent.cot_agent_runner import CotAgentRunner from core.agent.entities import AgentScratchpadUnit from core.agent.errors import AgentMaxIterationError -from graphon.model_runtime.entities.llm_entities import LLMUsage class DummyRunner(CotAgentRunner): diff --git a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py index ea8cc8aa86..97206019b9 100644 --- a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py @@ -1,9 +1,9 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.agent.cot_chat_agent_runner import CotChatAgentRunner -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from tests.unit_tests.core.agent.conftest import ( DummyAgentConfig, DummyAppConfig, diff --git a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py index 2f5873d865..defc8b4b64 100644 --- a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py @@ -1,8 +1,6 @@ import json import pytest - -from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -10,6 +8,8 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) +from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner + # ----------------------------- # Fixtures # ----------------------------- diff --git a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py index 17ab5babcb..a44a0650eb 100644 --- a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py @@ -3,11 +3,6 @@ from typing import Any from unittest.mock import MagicMock import pytest - -from core.agent.errors import AgentMaxIterationError -from core.agent.fc_agent_runner import FunctionCallAgentRunner -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueMessageFileEvent from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.message_entities import ( DocumentPromptMessageContent, @@ -16,6 +11,11 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) +from core.agent.errors import AgentMaxIterationError +from core.agent.fc_agent_runner import FunctionCallAgentRunner +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueMessageFileEvent + # ============================== # Dummy Helper Classes # ============================== diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py index 186b4a501d..5ee66da94a 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py @@ -2,6 +2,8 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.entities.model_entities import ModelStatus @@ -10,8 +12,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelPropertyKey class TestModelConfigConverter: diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py index d9fe7004ff..e2f3c16335 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py @@ -1,9 +1,9 @@ import pytest +from graphon.variables.input_entities import VariableEntityType from core.app.app_config.easy_ui_based_app.variables.manager import ( BasicVariablesConfigManager, ) -from graphon.variables.input_entities import VariableEntityType class TestBasicVariablesConfigManagerConvert: diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py index 11fc15c94d..8bde9c1f97 100644 --- a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py +++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py @@ -1,7 +1,8 @@ -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from graphon.file.models import FileTransferMethod, FileUploadConfig, ImageConfig +from graphon.file import FileTransferMethod, FileUploadConfig, ImageConfig from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager + def test_convert_with_vision(): config = { diff --git a/api/tests/unit_tests/core/app/app_config/test_entities.py b/api/tests/unit_tests/core/app/app_config/test_entities.py index f2bc3076da..000f83cd5a 100644 --- a/api/tests/unit_tests/core/app/app_config/test_entities.py +++ b/api/tests/unit_tests/core/app/app_config/test_entities.py @@ -1,10 +1,10 @@ import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.app_config.entities import ( DatasetRetrieveConfigEntity, PromptTemplateEntity, ) -from graphon.variables.input_entities import VariableEntity, VariableEntityType class TestAppConfigEntities: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index ef7df5e1da..061719d15a 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -3,12 +3,12 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 +from graphon.variables import SegmentType from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from factories import variable_factory -from graphon.variables import SegmentType from models import ConversationVariable, Workflow MINIMAL_GRAPH = { diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py index f2df35d7d0..e9fdeefee4 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py @@ -1,5 +1,7 @@ from collections.abc import Generator +from graphon.enums import WorkflowNodeExecutionStatus + from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.entities.task_entities import ( ChatbotAppBlockingResponse, @@ -10,7 +12,6 @@ from core.app.entities.task_entities import ( NodeStartStreamResponse, PingStreamResponse, ) -from graphon.enums import WorkflowNodeExecutionStatus class TestAdvancedChatGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py index 99a386cd45..a6d8598955 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py @@ -6,6 +6,8 @@ from types import SimpleNamespace from unittest import mock import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module from core.app.entities.app_invoke_entities import InvokeFrom @@ -17,8 +19,6 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import StreamEvent -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus from models.enums import MessageStatus from models.execution_extra_content import HumanInputContent from models.model import AppMode, EndUser diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 29fd63c063..82b2e51019 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -4,6 +4,8 @@ from contextlib import contextmanager from types import SimpleNamespace import pytest +from graphon.enums import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig from core.app.apps.advanced_chat.generate_task_pipeline import ( @@ -47,8 +49,6 @@ from core.app.entities.task_entities import ( ) from core.base.tts.app_generator_tts_publisher import AudioTrunk from core.workflow.system_variables import build_system_variables -from graphon.enums import BuiltinNodeTypes -from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models.enums import MessageStatus from models.model import AppMode, EndUser diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py index 80f7f94b1a..7dc4358150 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py @@ -1,12 +1,12 @@ import contextlib import pytest +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError class DummyAccount: diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py index 4567b35480..08250bc3b6 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py @@ -1,10 +1,10 @@ import pytest +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.agent.entities import AgentEntity from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.moderation.base import ModerationError -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py index 8f3c41701b..68bcffb0e8 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py @@ -2,6 +2,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.chat.app_runner import ChatAppRunner @@ -9,7 +10,6 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.moderation.base import ModerationError -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index f56ca8de99..f255d2c7df 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueMessageFileEvent -from graphon.file.enums import FileTransferMethod, FileType -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py index d6f7a05cdc..4a94a2b4f1 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -1,12 +1,11 @@ from types import SimpleNamespace import pytest +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.workflow.system_variables import build_system_variables from core.workflow.variable_pool_initializer import add_variables_to_pool -from graphon.runtime import GraphRuntimeState -from graphon.runtime.variable_pool import VariablePool def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 3ab63aed25..328cd12f12 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -1,9 +1,10 @@ from collections.abc import Mapping, Sequence -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from graphon.variables.segments import ArrayFileSegment, FileSegment +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter + class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Test class for WorkflowResponseConverter._fetch_files_from_variable_value method""" diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py index e8946281ac..bc11bf4174 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -1,12 +1,13 @@ from datetime import UTC, datetime from types import SimpleNamespace +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter(): diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py index 492e11ee0f..c9e146ff12 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -1,10 +1,11 @@ from types import SimpleNamespace +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter() -> WorkflowResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 7ee375d884..0fde7565d2 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -10,6 +10,8 @@ from typing import Any from unittest.mock import Mock import pytest +from graphon.entities import WorkflowStartReason +from graphon.enums import BuiltinNodeTypes from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -25,8 +27,6 @@ from core.app.entities.queue_entities import ( QueueNodeSucceededEvent, ) from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py index aa2085177e..619d66085a 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent import core.app.apps.completion.app_runner as module from core.app.apps.completion.app_runner import CompletionAppRunner from core.moderation.base import ModerationError -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py index f2e35f9900..96af9fbdee 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py @@ -3,13 +3,13 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError import core.app.apps.completion.app_generator as module from core.app.apps.completion.app_generator import CompletionAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py index cfe797aa76..6cdcab29ab 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py @@ -1,5 +1,7 @@ from collections.abc import Generator +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus + from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( AppStreamResponse, @@ -10,7 +12,6 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus def test_convert_blocking_full_and_simple_response(): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py index 9db83f5531..4fe82efcb3 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py @@ -1,4 +1,5 @@ import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult import core.app.apps.pipeline.pipeline_queue_manager as module from core.app.apps.base_app_queue_manager import PublishFrom @@ -13,7 +14,6 @@ from core.app.entities.queue_entities import ( QueueWorkflowPartialSuccessEvent, QueueWorkflowSucceededEvent, ) -from graphon.model_runtime.entities.llm_entities import LLMResult def test_publish_sets_stop_listen_and_raises_on_stopped(mocker): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index fb19d6d761..ab70996f0a 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -22,11 +22,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.graph_events import GraphRunFailedEvent import core.app.apps.pipeline.pipeline_runner as module from core.app.apps.pipeline.pipeline_runner import PipelineRunner from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from graphon.graph_events import GraphRunFailedEvent def _build_app_generate_entity() -> SimpleNamespace: diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index b0f8b423e1..6167be3bbd 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,7 +1,7 @@ import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.apps.base_app_generator import BaseAppGenerator -from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_validate_inputs_with_zero(): @@ -476,8 +476,9 @@ class TestBaseAppGeneratorExtras: assert converted[1] == "event: ping\n\n" def test_get_draft_var_saver_factory_debugger(self): - from core.app.entities.app_invoke_entities import InvokeFrom from graphon.enums import BuiltinNodeTypes + + from core.app.entities.app_invoke_entities import InvokeFrom from models import Account base_app_generator = BaseAppGenerator() diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py index 17de39ca99..1dee7fdab6 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py @@ -4,6 +4,15 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from core.app.app_config.entities import ( AdvancedChatMessageEntity, @@ -14,15 +23,6 @@ from core.app.app_config.entities import ( from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessageRole, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index 3673b7f68e..a126bc85f7 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -4,19 +4,14 @@ from types import ModuleType, SimpleNamespace from typing import Any import graphon.nodes.human_input.entities # noqa: F401 -from core.app.apps.advanced_chat import app_generator as adv_app_gen_module -from core.app.apps.workflow import app_generator as wf_app_gen_module -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason from graphon.entities.base_node_data import BaseNodeData, RetryConfig from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.entities.pause_reason import SchedulingPause -from graphon.entities.workflow_start_reason import WorkflowStartReason from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.graph_engine import GraphEngine -from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from graphon.graph_engine.command_channels import InMemoryChannel from graphon.graph_events import ( GraphEngineEvent, GraphRunPausedEvent, @@ -30,6 +25,12 @@ from graphon.nodes.base.node import Node from graphon.nodes.end.entities import EndNodeData from graphon.nodes.start.entities import StartNodeData from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.apps.advanced_chat import app_generator as adv_app_gen_module +from core.app.apps.workflow import app_generator as wf_app_gen_module +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params if "core.ops.ops_trace_manager" not in sys.modules: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index 58c7bfa4bc..de5bca161c 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -4,23 +4,6 @@ from datetime import UTC, datetime from types import SimpleNamespace import pytest - -from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.app.entities.queue_entities import ( - QueueAgentLogEvent, - QueueIterationCompletedEvent, - QueueLoopCompletedEvent, - QueueNodeExceptionEvent, - QueueNodeFailedEvent, - QueueNodeRetryEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowPausedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.system_variables import default_system_variables from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import BuiltinNodeTypes from graphon.graph_events import ( @@ -41,6 +24,23 @@ from graphon.node_events import NodeRunResult from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variables.variables import StringVariable +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.queue_entities import ( + QueueAgentLogEvent, + QueueIterationCompletedEvent, + QueueLoopCompletedEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeRetryEvent, + QueueNodeSucceededEvent, + QueueTextChunkEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.workflow.system_variables import default_system_variables + class TestWorkflowBasedAppRunner: def test_resolve_user_from(self): diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py index 38a947986f..aa789d9ff3 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -1,11 +1,11 @@ from unittest.mock import MagicMock import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.queue_entities import QueueWorkflowPausedEvent -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph_events.graph import GraphRunPausedEvent class _DummyQueueManager: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 620a153204..9e30faecf2 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -4,14 +4,14 @@ from typing import Any from unittest.mock import MagicMock, patch import pytest +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.workflow.system_variables import default_system_variables -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.runtime import GraphRuntimeState, VariablePool from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index ef0edf4096..8a717e1dcc 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -3,6 +3,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from core.app.apps.common import workflow_response_converter from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -11,11 +16,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueWorkflowPausedEvent from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse from core.workflow.system_variables import build_system_variables -from graphon.entities.pause_reason import HumanInputRequired -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.graph_events.graph import GraphRunPausedEvent -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType from models.account import Account from models.human_input import RecipientType diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py index 7dd7ffd727..b768e813bd 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py @@ -1,5 +1,7 @@ from collections.abc import Generator +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus + from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( ErrorStreamResponse, @@ -9,7 +11,6 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus class TestWorkflowGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py index a0a999cbc5..29df903aa8 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -2,14 +2,15 @@ import time from contextlib import contextmanager from unittest.mock import MagicMock +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState + from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import QueueWorkflowStartedEvent from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.runtime import GraphRuntimeState from models.account import Account from models.model import AppMode from tests.workflow_test_utils import build_test_variable_pool diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index 115e35da8a..dabd2594b4 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -4,6 +4,8 @@ from contextlib import contextmanager from types import SimpleNamespace import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline @@ -44,8 +46,6 @@ from core.app.entities.task_entities import ( ) from core.base.tts.app_generator_tts_publisher import AudioTrunk from core.workflow.system_variables import build_system_variables, system_variables_to_mapping -from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import AppMode, EndUser diff --git a/api/tests/unit_tests/core/app/entities/test_task_entities.py b/api/tests/unit_tests/core/app/entities/test_task_entities.py index 7c79780641..014a0cba72 100644 --- a/api/tests/unit_tests/core/app/entities/test_task_entities.py +++ b/api/tests/unit_tests/core/app/entities/test_task_entities.py @@ -1,10 +1,11 @@ +from graphon.enums import WorkflowNodeExecutionStatus + from core.app.entities.task_entities import ( NodeFinishStreamResponse, NodeRetryStreamResponse, NodeStartStreamResponse, StreamEvent, ) -from graphon.enums import WorkflowNodeExecutionStatus class TestTaskEntities: diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index 279e315946..a78c1b428f 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -1,16 +1,17 @@ from collections.abc import Sequence from unittest.mock import Mock +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.command_channels import CommandChannel +from graphon.graph_events import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent +from graphon.node_events import NodeRunResult +from graphon.runtime import ReadOnlyGraphRuntimeState +from graphon.variables import StringVariable +from graphon.variables.segments import Segment, StringSegment + from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer from core.workflow.system_variables import SystemVariableKey from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.graph_engine.protocols.command_channel import CommandChannel -from graphon.graph_events.node import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent -from graphon.node_events import NodeRunResult -from graphon.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from graphon.variables import StringVariable -from graphon.variables.segments import Segment, StringSegment from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 92a7788f6e..035e64325b 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -4,6 +4,17 @@ from time import time from unittest.mock import Mock import pytest +from graphon.entities.pause_reason import SchedulingPause +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events import ( + GraphRunFailedEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, +) +from graphon.runtime import ReadOnlyVariablePool +from graphon.variables.segments import Segment from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity @@ -14,17 +25,6 @@ from core.app.layers.pause_state_persist_layer import ( _WorkflowGenerateEntityWrapper, ) from core.workflow.system_variables import SystemVariableKey -from graphon.entities.pause_reason import SchedulingPause -from graphon.graph_engine.entities.commands import GraphEngineCommand -from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from graphon.graph_events.graph import ( - GraphRunFailedEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) -from graphon.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool -from graphon.variables.segments import Segment from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py index 56705f1a7e..95931f4f8b 100644 --- a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py @@ -1,5 +1,6 @@ +from graphon.graph_events import GraphRunPausedEvent + from core.app.layers.suspend_layer import SuspendLayer -from graphon.graph_events.graph import GraphRunPausedEvent class TestSuspendLayer: diff --git a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py index 1ac9a4d8c0..7cf6eb4f31 100644 --- a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py @@ -1,7 +1,8 @@ from unittest.mock import Mock, patch -from core.app.layers.timeslice_layer import TimeSliceLayer from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand + +from core.app.layers.timeslice_layer import TimeSliceLayer from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import SchedulerCommand diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py index ecc431936c..aa9285789b 100644 --- a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py @@ -2,10 +2,11 @@ from datetime import UTC, datetime, timedelta from types import SimpleNamespace from unittest.mock import Mock, patch +from graphon.graph_events import GraphRunFailedEvent, GraphRunSucceededEvent +from graphon.runtime import VariablePool + from core.app.layers.trigger_post_layer import TriggerPostLayer from core.workflow.system_variables import build_system_variables -from graphon.graph_events.graph import GraphRunFailedEvent, GraphRunSucceededEvent -from graphon.runtime import VariablePool from models.enums import WorkflowTriggerStatus diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py index c246f7b783..58aa7d7478 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.app.entities.queue_entities import QueueErrorEvent from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.errors.error import QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py index 1c1bf391d3..4aaa10a81a 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -2,6 +2,8 @@ from types import SimpleNamespace from unittest.mock import ANY, Mock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity @@ -26,8 +28,6 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher from core.ops.ops_trace_manager import TraceQueueManager -from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py index ea000f3886..f7e7b7e20e 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py @@ -5,6 +5,9 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest +from graphon.file import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent from core.app.app_config.entities import ( AppAdditionalFeatures, @@ -38,9 +41,6 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AudioTrunk -from graphon.file.enums import FileTransferMethod -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py index abfbcdb941..31b7313066 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py @@ -17,11 +17,11 @@ import uuid from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.file import FileTransferMethod, FileType from sqlalchemy.orm import Session from core.app.entities.task_entities import MessageEndStreamResponse from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline -from graphon.file.enums import FileTransferMethod, FileType from models.model import MessageFile, UploadFile diff --git a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py index 21c761c579..29df7eea86 100644 --- a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py +++ b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py @@ -1,9 +1,10 @@ from types import SimpleNamespace from unittest.mock import patch +from graphon.model_runtime.entities.model_entities import ModelPropertyKey + from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager from core.app.app_config.entities import ModelConfigEntity -from graphon.model_runtime.entities.model_entities import ModelPropertyKey from models.provider_ids import ModelProviderID diff --git a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py index 5c50cb78da..dc2d82ccd6 100644 --- a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py +++ b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py @@ -2,14 +2,14 @@ from datetime import UTC, datetime from unittest.mock import Mock import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType +from graphon.node_events import NodeRunResult from core.app.workflow.layers.persistence import ( PersistenceWorkflowInfo, WorkflowPersistenceLayer, _NodeRuntimeSnapshot, ) -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType -from graphon.node_events import NodeRunResult def _build_layer() -> WorkflowPersistenceLayer: diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py index cddd03f4b0..7be9d6ac1e 100644 --- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py +++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py @@ -8,13 +8,13 @@ from unittest.mock import MagicMock, patch from urllib.parse import parse_qs, urlparse import pytest +from graphon.file import File, FileTransferMethod, FileType from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.app.file_access import DatabaseFileAccessController, FileAccessScope from core.app.workflow import file_runtime from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime from core.workflow.file_reference import build_file_reference -from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py index c4bfb23272..8497261d45 100644 --- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -1,10 +1,10 @@ from types import SimpleNamespace import pytest +from graphon.enums import BuiltinNodeTypes from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.workflow.node_factory import DifyNodeFactory -from graphon.enums import BuiltinNodeTypes class DummyNode: diff --git a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py index 82552470a9..a47d3db6f5 100644 --- a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py +++ b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py @@ -2,9 +2,10 @@ from __future__ import annotations from types import SimpleNamespace -from core.app.workflow.layers.observability import ObservabilityLayer from graphon.enums import BuiltinNodeTypes +from core.app.workflow.layers.observability import ObservabilityLayer + class TestObservabilityLayerExtras: def test_init_tracer_enabled_sets_tracer(self, monkeypatch): diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py index 9863f34aba..d8a68f6d00 100644 --- a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -4,27 +4,21 @@ from datetime import UTC, datetime from types import SimpleNamespace import pytest - -from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity -from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.workflow.system_variables import SystemVariableKey, build_system_variables +from graphon.entities import WorkflowNodeExecution from graphon.entities.pause_reason import SchedulingPause -from graphon.entities.workflow_node_execution import WorkflowNodeExecution from graphon.enums import ( BuiltinNodeTypes, WorkflowExecutionStatus, WorkflowNodeExecutionStatus, WorkflowType, ) -from graphon.graph_events.graph import ( +from graphon.graph_events import ( GraphRunAbortedEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, -) -from graphon.graph_events.node import ( NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunPauseRequestedEvent, @@ -35,6 +29,10 @@ from graphon.graph_events.node import ( from graphon.node_events import NodeRunResult from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool +from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity +from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.workflow.system_variables import SystemVariableKey, build_system_variables + class _RepoRecorder: def __init__(self) -> None: diff --git a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py index 7b433ab57b..5ff9774b52 100644 --- a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py +++ b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py @@ -301,7 +301,6 @@ class TestAppGeneratorTTSPublisher: publisher = AppGeneratorTTSPublisher("tenant", "voice1") publisher.executor = MagicMock() - from core.app.entities.queue_entities import QueueAgentMessageEvent from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -309,6 +308,8 @@ class TestAppGeneratorTTSPublisher: TextPromptMessageContent, ) + from core.app.entities.queue_entities import QueueAgentMessageEvent + chunk = LLMResultChunk( model="model", delta=LLMResultChunkDelta( @@ -336,10 +337,11 @@ class TestAppGeneratorTTSPublisher: publisher = AppGeneratorTTSPublisher("tenant", "voice1") publisher.executor = MagicMock() - from core.app.entities.queue_entities import QueueAgentMessageEvent from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta from graphon.model_runtime.entities.message_entities import AssistantPromptMessage + from core.app.entities.queue_entities import QueueAgentMessageEvent + chunk = LLMResultChunk( model="model", delta=LLMResultChunkDelta( diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index af992e4e9f..b0c72ee42f 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -2,16 +2,15 @@ import types from collections.abc import Generator import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from contexts.wrapper import RecyclableContextVar from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError from core.workflow.file_reference import parse_file_reference -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.file import File -from graphon.file.enums import FileTransferMethod, FileType -from graphon.node_events import StreamChunkEvent, StreamCompletedEvent def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py index 0b91d59953..fbaf6d497d 100644 --- a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py @@ -1,11 +1,10 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.file import File, FileTransferMethod, FileType from core.datasource.entities.datasource_entities import DatasourceMessage from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer -from graphon.file import File -from graphon.file.enums import FileTransferMethod, FileType from models.tools import ToolFile diff --git a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py index ef8f360dbf..ff9fd0d8f3 100644 --- a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py +++ b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py @@ -1,11 +1,12 @@ +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType + from core.entities.execution_extra_content import ( ExecutionExtraContentDomainModel, HumanInputContent, HumanInputFormDefinition, HumanInputFormSubmissionData, ) -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType from models.execution_extra_content import ExecutionContentType diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py index a0b2820157..2acd278a31 100644 --- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py @@ -8,6 +8,9 @@ drive provider mapping behavior. """ import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.entities.model_entities import ( DefaultModelEntity, @@ -16,9 +19,6 @@ from core.entities.model_entities import ( ProviderModelWithStatusEntity, SimpleModelProviderEntity, ) -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity: diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index fe2c226843..8cf0409c4c 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -6,6 +6,17 @@ from typing import Any from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) from constants import HIDDEN_VALUE from core.entities.model_entities import ModelStatus @@ -24,17 +35,6 @@ from core.entities.provider_entities import ( SystemConfiguration, SystemConfigurationStatus, ) -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FieldModelSchema, - FormType, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderEntity, -) from models.enums import CredentialSourceType from models.provider import ProviderType from models.provider_ids import ModelProviderID diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py index a159d3ad4d..8685d16283 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py @@ -1,4 +1,5 @@ import pytest +from graphon.model_runtime.entities.model_entities import ModelType from core.entities.parameter_entities import AppSelectorScope from core.entities.provider_entities import ( @@ -8,7 +9,6 @@ from core.entities.provider_entities import ( ProviderQuotaType, ) from core.tools.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType def test_provider_quota_type_value_of_returns_enum_member() -> None: diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py index 6ed9ddb476..b45f6fd9a7 100644 --- a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py @@ -2,20 +2,6 @@ import json from unittest.mock import MagicMock, patch import pytest - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import ( - ResponseFormat, - _handle_native_json_schema, - _handle_prompt_based_schema, - _parse_structured_output, - _prepare_schema_for_model, - _set_response_format, - convert_boolean_to_string, - invoke_llm_with_structured_output, - remove_additional_properties, -) -from core.model_manager import ModelInstance from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, @@ -31,6 +17,20 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import ( + ResponseFormat, + _handle_native_json_schema, + _handle_prompt_based_schema, + _parse_structured_output, + _prepare_schema_for_model, + _set_response_format, + convert_boolean_to_string, + invoke_llm_with_structured_output, + remove_additional_properties, +) +from core.model_manager import ModelInstance + class TestStructuredOutput: def test_remove_additional_properties(self): diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index b3a5885814..2c0a441125 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -2,12 +2,12 @@ import json from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator -from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError class TestLLMGenerator: diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index bfb1fde502..313d18c695 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -3,6 +3,7 @@ from unittest.mock import Mock, patch import jsonschema import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types @@ -18,7 +19,6 @@ from core.mcp.server.streamable_http import ( prepare_tool_arguments, process_mapping_response, ) -from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py index f459250b8e..9a5fb319d7 100644 --- a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py +++ b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py @@ -4,8 +4,6 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest - -from core.memory.token_buffer_memory import TokenBufferMemory from graphon.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -13,6 +11,8 @@ from graphon.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) + +from core.memory.token_buffer_memory import TokenBufferMemory from models.model import AppMode # --------------------------------------------------------------------------- diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py index 249ecb5006..6a672fdfd5 100644 --- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -1,7 +1,6 @@ from unittest.mock import Mock import pytest - from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from graphon.model_runtime.entities.provider_entities import ( diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py index c2324fdec4..62d631a754 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py @@ -5,6 +5,8 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module @@ -34,8 +36,6 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey class RecordingTraceClient: diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py index e4d8f2d5ea..2d2be12f05 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py @@ -1,6 +1,8 @@ import json from unittest.mock import MagicMock +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from opentelemetry.trace import Link, StatusCode from core.ops.aliyun_trace.entities.semconv import ( @@ -24,8 +26,6 @@ from core.ops.aliyun_trace.utils import ( serialize_json_data, ) from core.rag.models.document import Document -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py index 8ebf441921..97f7a16327 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py @@ -5,6 +5,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( @@ -25,7 +26,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace -from graphon.enums import BuiltinNodeTypes from models import EndUser from models.enums import MessageStatus diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py index 34c64c54a1..bfe916f018 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py @@ -3,6 +3,7 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( @@ -21,7 +22,6 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py index afc5726ede..f4c485a9fc 100644 --- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py +++ b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py @@ -9,6 +9,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( @@ -21,7 +22,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds -from graphon.enums import BuiltinNodeTypes # ── Helpers ────────────────────────────────────────────────────────────────── diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py index c02ac413f2..1cb32f2ee0 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py @@ -5,6 +5,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from core.ops.entities.config_entity import OpikConfig from core.ops.entities.trace_entity import ( @@ -18,7 +19,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser from models.enums import MessageStatus diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py index 6113e5c6c8..696f859b6f 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py @@ -1,6 +1,8 @@ from datetime import datetime from unittest.mock import MagicMock, patch +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from opentelemetry.trace import StatusCode from core.ops.entities.trace_entity import ( @@ -25,8 +27,6 @@ from core.ops.tencent_trace.entities.semconv import ( ) from core.ops.tencent_trace.span_builder import TencentSpanBuilder from core.rag.models.document import Document -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class TestTencentSpanBuilder: diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py index 265652381c..382e5dadc3 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py @@ -2,6 +2,8 @@ import logging from unittest.mock import MagicMock, patch import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes from core.ops.entities.config_entity import TencentConfig from core.ops.entities.trace_entity import ( @@ -14,8 +16,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.tencent_trace.tencent_trace import TencentDataTrace -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes from models import Account, App, TenantAccountJoin logger = logging.getLogger(__name__) diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py index 4b925390d9..6b5cb5b09a 100644 --- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py +++ b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py @@ -1,7 +1,7 @@ +from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes from openinference.semconv.trace import OpenInferenceSpanKindValues from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind -from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes class TestGetNodeSpanKind: diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py index 531c7de05f..5014f40afc 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py @@ -7,6 +7,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from weave.trace_server.trace_server_interface import TraceStatus from core.ops.entities.config_entity import WeaveConfig @@ -22,7 +23,6 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.ops.weave_trace.weave_trace import WeaveDataTrace -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey # ── Helpers ────────────────────────────────────────────────────────────────── diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py index c24d3ac012..543b278715 100644 --- a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py @@ -1,9 +1,10 @@ from types import SimpleNamespace from unittest.mock import patch +from graphon.model_runtime.entities.message_entities import UserPromptMessage + from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation from core.plugin.entities.request import RequestInvokeSummary -from graphon.model_runtime.entities.message_entities import UserPromptMessage def test_system_model_helpers_forward_user_id(): diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py index 68aa130518..f8d0e127b1 100644 --- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -6,15 +6,15 @@ from types import SimpleNamespace from unittest.mock import Mock, sentinel import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl import model_runtime as model_runtime_module from core.plugin.impl.model import PluginModelClient from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_schema() -> AIModelEntity: diff --git a/api/tests/unit_tests/core/plugin/test_plugin_entities.py b/api/tests/unit_tests/core/plugin/test_plugin_entities.py index f1c4c7e700..a812b01c5b 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_entities.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_entities.py @@ -4,6 +4,12 @@ from enum import StrEnum import pytest from flask import Response +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) from pydantic import ValidationError from core.plugin.entities.endpoint import EndpointEntityWithInstance @@ -25,12 +31,6 @@ from core.plugin.entities.request import ( ) from core.plugin.utils.http_parser import serialize_response from core.tools.entities.common_entities import I18nObject -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) class TestEndpointEntity: diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index af86f917b1..3063ca0197 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -17,6 +17,14 @@ from unittest.mock import MagicMock, patch import httpx import pytest +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from core.plugin.entities.plugin_daemon import ( @@ -37,14 +45,6 @@ from core.plugin.impl.exc import ( ) from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.tool import PluginToolManager -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError class TestPluginRuntimeExecution: diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index 4d4313dd84..90730dff5a 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -1,13 +1,12 @@ from collections.abc import Generator import pytest +from graphon.file import File, FileTransferMethod, FileType from core.agent.entities import AgentInvokeMessage from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File class TestChunkMerger: diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 395d392127..2b280dd674 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,13 +2,6 @@ from typing import cast from unittest.mock import MagicMock, patch import pytest - -from configs import dify_config -from core.app.app_config.entities import ModelConfigEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -18,6 +11,13 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) + +from configs import dify_config +from core.app.app_config.entities import ModelConfigEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 803afa54d7..4a54649b28 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -1,11 +1,5 @@ from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import ( - ModelConfigWithCredentialsEntity, -) -from core.entities.provider_configuration import ProviderModelBundle -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, @@ -13,6 +7,13 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + +from core.app.entities.app_invoke_entities import ( + ModelConfigWithCredentialsEntity, +) +from core.entities.provider_configuration import ProviderModelBundle +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py index 5d865d934c..a4b3960b0a 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_message.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py @@ -1,5 +1,3 @@ -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, AudioPromptMessageContent, @@ -9,6 +7,9 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil + def test_build_prompt_message_with_prompt_message_contents(): prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")]) diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 9f9ea33695..e35ce2c48a 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -2,9 +2,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.prompt.prompt_transform import PromptTransform -from graphon.model_runtime.entities.model_entities import ModelPropertyKey # from core.app.app_config.entities import ModelConfigEntity # from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index 0dc74b33df..3f188cfbb4 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -2,6 +2,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -18,12 +24,6 @@ from core.prompt.prompt_templates.advanced_prompt_templates import ( CONTEXT, ) from core.prompt.simple_prompt_transform import SimplePromptTransform -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - TextPromptMessageContent, - UserPromptMessage, -) from models.model import AppMode, Conversation diff --git a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py index 1f3247590c..006b4e7345 100644 --- a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py +++ b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py @@ -1,12 +1,13 @@ from unittest.mock import MagicMock, patch +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError + from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError def _doc(content: str) -> Document: diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py index bfa78fe565..6fd44be4d4 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -12,11 +12,11 @@ from unittest.mock import Mock, patch import numpy as np import pytest +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from sqlalchemy.exc import IntegrityError from core.rag.embedding.cached_embedding import CacheEmbedding -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from models.dataset import Embedding diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 392f0b458b..d7ba944e58 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -49,10 +49,6 @@ from unittest.mock import Mock, patch import numpy as np import pytest -from sqlalchemy.exc import IntegrityError - -from core.entities.embedding_type import EmbeddingInputType -from core.rag.embedding.cached_embedding import CacheEmbedding from graphon.model_runtime.entities.model_entities import ModelPropertyKey from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from graphon.model_runtime.errors.invoke import ( @@ -60,6 +56,10 @@ from graphon.model_runtime.errors.invoke import ( InvokeConnectionError, InvokeRateLimitError, ) +from sqlalchemy.exc import IntegrityError + +from core.entities.embedding_type import EmbeddingInputType +from core.rag.embedding.cached_embedding import CacheEmbedding from models.dataset import Embedding diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index c861871f02..cc2873dd3f 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -2,14 +2,14 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelFeature from core.entities.knowledge_entities import PreviewDetail from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.models.document import AttachmentDocument, Document -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent -from graphon.model_runtime.entities.model_entities import ModelFeature class TestParagraphIndexProcessor: diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index 059876d410..450e716636 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -53,6 +53,7 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm.exc import ObjectDeletedError from core.errors.error import ProviderTokenNotInitError @@ -63,7 +64,6 @@ from core.indexing_runner import ( ) from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document -from graphon.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index 415597f336..2ec7f0498e 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -17,6 +17,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.model_manager import ModelInstance from core.rag.index_processor.constant.doc_type import DocType @@ -28,7 +29,6 @@ from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner -from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult def create_mock_model_instance() -> ModelInstance: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index a7e62e7b0a..c11426163e 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -6,6 +6,8 @@ from uuid import uuid4 import pytest from flask import Flask, current_app +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelFeature from sqlalchemy import column from core.app.app_config.entities import ( @@ -35,8 +37,6 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_retrieval import exc from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.model_runtime.entities.model_entities import ModelFeature from models.dataset import Dataset from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py index 43c521dcfd..5a2ecb8220 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py @@ -1,8 +1,9 @@ from unittest.mock import Mock -from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter from graphon.model_runtime.entities.llm_entities import LLMUsage +from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter + class TestFunctionCallMultiDatasetRouter: def test_invoke_returns_none_when_no_tools(self) -> None: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py index c56528cf55..539ac0f849 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py @@ -1,12 +1,13 @@ from types import SimpleNamespace from unittest.mock import Mock, patch -from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish -from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.message_entities import PromptMessageRole from graphon.model_runtime.entities.model_entities import ModelType +from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish +from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter + class TestReactMultiDatasetRouter: def test_invoke_returns_none_when_no_tools(self) -> None: diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py index 2735ec512f..e229d5fc1a 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -9,9 +9,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowType from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository -from graphon.entities.workflow_execution import WorkflowExecution, WorkflowType from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index 05b4f3a053..7dbf78d0f0 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -9,14 +9,14 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest - -from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from core.repositories.factory import OrderConfig from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) from graphon.enums import BuiltinNodeTypes + +from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 8be1ac318c..0fc82dda53 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -7,6 +7,11 @@ from datetime import datetime from types import SimpleNamespace import pytest +from graphon.nodes.human_input.entities import ( + FormDefinition, + UserAction, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from core.repositories.human_input_repository import ( HumanInputFormRecord, @@ -21,11 +26,6 @@ from core.workflow.human_input_compat import ( ExternalRecipient, MemberRecipient, ) -from graphon.nodes.human_input.entities import ( - FormDefinition, - UserAction, -) -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import ( EmailExternalRecipientPayload, diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py index 1297a95df1..8ff0e40587 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py @@ -9,6 +9,8 @@ from typing import Any from unittest.mock import MagicMock import pytest +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from core.repositories.human_input_repository import ( FormCreateParams, @@ -29,8 +31,6 @@ from core.workflow.human_input_compat import ( MemberRecipient, WebAppDeliveryMethod, ) -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import HumanInputFormRecipient, RecipientType diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py index 6cb3c3c6ac..e5c3e85487 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py @@ -3,11 +3,12 @@ from unittest.mock import MagicMock from uuid import uuid4 import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from graphon.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType from models import Account, CreatorUserRole, EndUser, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py index 6af7b02d4c..5b4d26b780 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -10,6 +10,12 @@ from unittest.mock import MagicMock, Mock import psycopg2.errors import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) from sqlalchemy import Engine, create_engine from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker @@ -23,12 +29,6 @@ from core.repositories.sqlalchemy_workflow_node_execution_repository import ( _find_first, _replace_or_append_offload, ) -from graphon.entities import WorkflowNodeExecution -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) from models import Account, EndUser from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py index abdbc72085..84fe522388 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py @@ -4,17 +4,17 @@ from unittest.mock import MagicMock, Mock import psycopg2.errors import pytest +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py index 5af1376a0a..27729e7f06 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -11,17 +11,17 @@ from datetime import UTC, datetime from typing import Any from unittest.mock import MagicMock +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from sqlalchemy import Engine from configs import dify_config from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionTriggeredFrom from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index f17927f16b..ac65d0c02b 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -1,6 +1,7 @@ import json from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig + from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index afea9144c0..f5efb78b61 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -2,12 +2,12 @@ from unittest.mock import MagicMock, patch import pytest import redis +from graphon.model_runtime.entities.model_entities import ModelType from pytest_mock import MockerFixture from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.model_manager import LBModelManager from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.model_entities import ModelType @pytest.fixture diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py index b19a21d7f4..331166fe63 100644 --- a/api/tests/unit_tests/core/test_provider_configuration.py +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -1,6 +1,15 @@ from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormOption, + FormType, + ProviderEntity, +) from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus from core.entities.provider_entities import ( @@ -12,15 +21,6 @@ from core.entities.provider_entities import ( RestrictModel, SystemConfiguration, ) -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormOption, - FormType, - ProviderEntity, -) from models.provider import Provider, ProviderType diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 7f6a50af99..259cb5fdd0 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -2,12 +2,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, PropertyMock, patch import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType from pytest_mock import MockerFixture from core.entities.provider_entities import ModelSettings from core.provider_manager import ProviderManager -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType from models.provider import LoadBalancingModelConfig, ProviderModelSetting, TenantDefaultModel from models.provider_ids import ModelProviderID diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py index 1ff81f6120..5d744f88c9 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py @@ -6,13 +6,13 @@ from typing import Any from unittest.mock import patch import pytest +from graphon.model_runtime.entities.message_entities import UserPromptMessage from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType -from graphon.model_runtime.entities.message_entities import UserPromptMessage class _BuiltinDummyTool(BuiltinTool): diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py index 9ac280e31a..ee0ce51eec 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py @@ -6,6 +6,8 @@ from datetime import date from types import SimpleNamespace import pytest +from graphon.file import FileType +from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime @@ -27,8 +29,6 @@ from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.errors import ToolInvokeError -from graphon.file.enums import FileType -from graphon.model_runtime.entities.model_entities import ModelPropertyKey def _build_builtin_tool(tool_cls: type[BuiltinTool]) -> BuiltinTool: diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py index b3442636b7..7fcebde3c5 100644 --- a/api/tests/unit_tests/core/tools/test_tool_file_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -12,9 +12,9 @@ from unittest.mock import MagicMock, Mock, patch import httpx import pytest +from graphon.file import FileTransferMethod from core.tools.tool_file_manager import ToolFileManager -from graphon.file import FileTransferMethod def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]: diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py index a4a563a4a1..52f262e1cf 100644 --- a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py +++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py @@ -13,8 +13,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest - -from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils from graphon.model_runtime.entities.model_entities import ModelPropertyKey from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, @@ -24,6 +22,8 @@ from graphon.model_runtime.errors.invoke import ( InvokeServerUnavailableError, ) +from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils + def _mock_model_instance(*, schema: dict | None = None) -> SimpleNamespace: model_type_instance = Mock() diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py index 43f3fbd5c9..0e3a7e623a 100644 --- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py +++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py @@ -1,9 +1,9 @@ import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils -from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_ensure_no_human_input_nodes_passes_for_non_human_input(): diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py index b147d7fcdb..2607861b59 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py @@ -4,6 +4,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( @@ -13,7 +14,6 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.workflow_as_tool.provider import WorkflowToolProviderController -from graphon.variables.input_entities import VariableEntity, VariableEntityType def _controller() -> WorkflowToolProviderController: diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 72a73dd936..c20edd7400 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -11,6 +11,7 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime @@ -24,7 +25,6 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolInvokeError from core.tools.workflow_as_tool.tool import WorkflowTool -from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType class StubScalars: diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py index ee7a3d9c96..78622b78b6 100644 --- a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py +++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py @@ -11,6 +11,7 @@ from datetime import datetime from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes, NodeType from core.plugin.entities.request import TriggerInvokeEventResponse from core.trigger.constants import ( @@ -26,7 +27,6 @@ from core.trigger.debug.event_selectors import ( select_trigger_debug_events, ) from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent -from graphon.enums import BuiltinNodeTypes, NodeType from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 72052c8c05..7406b88270 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -2,11 +2,6 @@ import dataclasses import orjson import pytest -from pydantic import BaseModel - -from core.helper import encrypter -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_variables_to_pool from graphon.file import File, FileTransferMethod, FileType from graphon.runtime import VariablePool from graphon.variables.segment_group import SegmentGroup @@ -47,6 +42,11 @@ from graphon.variables.variables import ( StringVariable, Variable, ) +from pydantic import BaseModel + +from core.helper import encrypter +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool def _build_variable_pool( diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index d4e862220a..37ecd2890b 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,5 +1,4 @@ import pytest - from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import StringSegment from graphon.variables.types import ArrayValidation, SegmentType diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py index 14f9b2991d..09254e17a3 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -9,9 +9,7 @@ from dataclasses import dataclass from typing import Any import pytest - -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File +from graphon.file import File, FileTransferMethod, FileType from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import ( ArrayFileSegment, diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index dae5e1ce98..75b01bf42e 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -1,6 +1,4 @@ import pytest -from pydantic import ValidationError - from graphon.variables import ( ArrayFileVariable, ArrayVariable, @@ -12,6 +10,7 @@ from graphon.variables import ( StringVariable, ) from graphon.variables.variables import VariableBase +from pydantic import ValidationError def test_frozen_variables(): diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py deleted file mode 100644 index ef5500b72f..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ /dev/null @@ -1,307 +0,0 @@ -import json -from time import time -from unittest.mock import MagicMock, patch - -import pytest - -from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool -from graphon.variables.variables import StringVariable - - -class StubCoordinator: - def __init__(self) -> None: - self.state = "initial" - - def dumps(self) -> str: - return json.dumps({"state": self.state}) - - def loads(self, data: str) -> None: - payload = json.loads(data) - self.state = payload["state"] - - -class TestGraphRuntimeState: - def test_execution_context_defaults_to_empty_context(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - - with state.execution_context: - assert state.execution_context is not None - - state.execution_context = None - - with state.execution_context: - assert state.execution_context is not None - - def test_property_getters_and_setters(self): - # FIXME(-LAN-): Mock VariablePool if needed - variable_pool = VariablePool() - start_time = time() - - state = GraphRuntimeState(variable_pool=variable_pool, start_at=start_time) - - # Test variable_pool property (read-only) - assert state.variable_pool == variable_pool - - # Test start_at property - assert state.start_at == start_time - new_time = time() + 100 - state.start_at = new_time - assert state.start_at == new_time - - # Test total_tokens property - assert state.total_tokens == 0 - state.total_tokens = 100 - assert state.total_tokens == 100 - - # Test node_run_steps property - assert state.node_run_steps == 0 - state.node_run_steps = 5 - assert state.node_run_steps == 5 - - def test_outputs_immutability(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test that getting outputs returns a copy - outputs1 = state.outputs - outputs2 = state.outputs - assert outputs1 == outputs2 - assert outputs1 is not outputs2 # Different objects - - # Test that modifying retrieved outputs doesn't affect internal state - outputs = state.outputs - outputs["test"] = "value" - assert "test" not in state.outputs - - # Test set_output method - state.set_output("key1", "value1") - assert state.get_output("key1") == "value1" - - # Test update_outputs method - state.update_outputs({"key2": "value2", "key3": "value3"}) - assert state.get_output("key2") == "value2" - assert state.get_output("key3") == "value3" - - def test_llm_usage_immutability(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test that getting llm_usage returns a copy - usage1 = state.llm_usage - usage2 = state.llm_usage - assert usage1 is not usage2 # Different objects - - def test_type_validation(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test total_tokens validation - with pytest.raises(ValueError): - state.total_tokens = -1 - - # Test node_run_steps validation - with pytest.raises(ValueError): - state.node_run_steps = -1 - - def test_helper_methods(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test increment_node_run_steps - initial_steps = state.node_run_steps - state.increment_node_run_steps() - assert state.node_run_steps == initial_steps + 1 - - # Test add_tokens - initial_tokens = state.total_tokens - state.add_tokens(50) - assert state.total_tokens == initial_tokens + 50 - - # Test add_tokens validation - with pytest.raises(ValueError): - state.add_tokens(-1) - - def test_ready_queue_default_instantiation(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - - queue = state.ready_queue - - from graphon.graph_engine.ready_queue import InMemoryReadyQueue - - assert isinstance(queue, InMemoryReadyQueue) - - def test_graph_execution_lazy_instantiation(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - - execution = state.graph_execution - - from graphon.graph_engine.domain.graph_execution import GraphExecution - - assert isinstance(execution, GraphExecution) - assert execution.workflow_id == "" - assert state.graph_execution is execution - - def test_response_coordinator_configuration(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - with pytest.raises(ValueError): - _ = state.response_coordinator - - mock_graph = MagicMock() - with patch( - "graphon.graph_engine.response_coordinator.ResponseStreamCoordinator", autospec=True - ) as coordinator_cls: - coordinator_instance = coordinator_cls.return_value - state.configure(graph=mock_graph) - - assert state.response_coordinator is coordinator_instance - coordinator_cls.assert_called_once_with(variable_pool=variable_pool, graph=mock_graph) - - # Configure again with same graph should be idempotent - state.configure(graph=mock_graph) - - other_graph = MagicMock() - with pytest.raises(ValueError): - state.attach_graph(other_graph) - - def test_read_only_wrapper_exposes_additional_state(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - state.configure() - - wrapper = ReadOnlyGraphRuntimeStateWrapper(state) - - assert wrapper.ready_queue_size == 0 - assert wrapper.exceptions_count == 0 - - def test_read_only_wrapper_serializes_runtime_state(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - state.total_tokens = 5 - state.set_output("result", {"success": True}) - state.ready_queue.put("node-1") - - wrapper = ReadOnlyGraphRuntimeStateWrapper(state) - - wrapper_snapshot = json.loads(wrapper.dumps()) - state_snapshot = json.loads(state.dumps()) - - assert wrapper_snapshot == state_snapshot - - def test_dumps_and_loads_roundtrip_with_response_coordinator(self): - variable_pool = VariablePool() - variable_pool.add(("node1", "value"), "payload") - - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - state.total_tokens = 10 - state.node_run_steps = 3 - state.set_output("final", {"result": True}) - usage = LLMUsage.from_metadata( - { - "prompt_tokens": 2, - "completion_tokens": 3, - "total_tokens": 5, - "total_price": "1.23", - "currency": "USD", - "latency": 0.5, - } - ) - state.llm_usage = usage - state.ready_queue.put("node-A") - - graph_execution = state.graph_execution - graph_execution.workflow_id = "wf-123" - graph_execution.exceptions_count = 4 - graph_execution.started = True - - mock_graph = MagicMock() - stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub, autospec=True): - state.attach_graph(mock_graph) - - stub.state = "configured" - - snapshot = state.dumps() - - restored = GraphRuntimeState.from_snapshot(snapshot) - - assert restored.total_tokens == 10 - assert restored.node_run_steps == 3 - assert restored.get_output("final") == {"result": True} - assert restored.llm_usage.total_tokens == usage.total_tokens - assert restored.ready_queue.qsize() == 1 - assert restored.ready_queue.get(timeout=0.01) == "node-A" - - restored_segment = restored.variable_pool.get(("node1", "value")) - assert restored_segment is not None - assert restored_segment.value == "payload" - - restored_execution = restored.graph_execution - assert restored_execution.workflow_id == "wf-123" - assert restored_execution.exceptions_count == 4 - assert restored_execution.started is True - - new_stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub, autospec=True): - restored.attach_graph(mock_graph) - - assert new_stub.state == "configured" - - def test_loads_rehydrates_existing_instance(self): - variable_pool = VariablePool() - variable_pool.add(("node", "key"), "value") - - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - state.total_tokens = 7 - state.node_run_steps = 2 - state.set_output("foo", "bar") - state.ready_queue.put("node-1") - - execution = state.graph_execution - execution.workflow_id = "wf-456" - execution.started = True - - mock_graph = MagicMock() - original_stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub, autospec=True): - state.attach_graph(mock_graph) - - original_stub.state = "configured" - snapshot = state.dumps() - - new_stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub, autospec=True): - restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) - restored.attach_graph(mock_graph) - restored.loads(snapshot) - - assert restored.total_tokens == 7 - assert restored.node_run_steps == 2 - assert restored.get_output("foo") == "bar" - assert restored.ready_queue.qsize() == 1 - assert restored.ready_queue.get(timeout=0.01) == "node-1" - - restored_segment = restored.variable_pool.get(("node", "key")) - assert restored_segment is not None - assert restored_segment.value == "value" - - restored_execution = restored.graph_execution - assert restored_execution.workflow_id == "wf-456" - assert restored_execution.started is True - - assert new_stub.state == "configured" - - def test_snapshot_restore_preserves_updated_conversation_variable(self): - variable_pool = VariablePool( - conversation_variables=[StringVariable(name="session_name", value="before")], - ) - variable_pool.add((CONVERSATION_VARIABLE_NODE_ID, "session_name"), "after") - - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - snapshot = state.dumps() - restored = GraphRuntimeState.from_snapshot(snapshot) - - restored_value = restored.variable_pool.get((CONVERSATION_VARIABLE_NODE_ID, "session_name")) - assert restored_value is not None - assert restored_value.value == "after" diff --git a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py deleted file mode 100644 index 856ec959b7..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -Tests for PauseReason discriminated union serialization/deserialization. -""" - -import pytest -from pydantic import BaseModel, ValidationError - -from graphon.entities.pause_reason import ( - HumanInputRequired, - PauseReason, - SchedulingPause, -) - - -class _Holder(BaseModel): - """Helper model that embeds PauseReason for union tests.""" - - reason: PauseReason - - -class TestPauseReasonDiscriminator: - """Test suite for PauseReason union discriminator.""" - - @pytest.mark.parametrize( - ("dict_value", "expected"), - [ - pytest.param( - { - "reason": { - "TYPE": "human_input_required", - "form_id": "form_id", - "form_content": "form_content", - "node_id": "node_id", - "node_title": "node_title", - }, - }, - HumanInputRequired( - form_id="form_id", - form_content="form_content", - node_id="node_id", - node_title="node_title", - ), - id="HumanInputRequired", - ), - pytest.param( - { - "reason": { - "TYPE": "scheduled_pause", - "message": "Hold on", - } - }, - SchedulingPause(message="Hold on"), - id="SchedulingPause", - ), - ], - ) - def test_model_validate(self, dict_value, expected): - """Ensure scheduled pause payloads with lowercase TYPE deserialize.""" - holder = _Holder.model_validate(dict_value) - - assert type(holder.reason) == type(expected) - assert holder.reason == expected - - @pytest.mark.parametrize( - "reason", - [ - HumanInputRequired( - form_id="form_id", - form_content="form_content", - node_id="node_id", - node_title="node_title", - ), - SchedulingPause(message="Hold on"), - ], - ids=lambda x: type(x).__name__, - ) - def test_model_construct(self, reason): - holder = _Holder(reason=reason) - assert holder.reason == reason - - def test_model_construct_with_invalid_type(self): - with pytest.raises(ValidationError): - holder = _Holder(reason=object()) # type: ignore - - def test_unknown_type_fails_validation(self): - """Unknown TYPE values should raise a validation error.""" - with pytest.raises(ValidationError): - _Holder.model_validate({"reason": {"TYPE": "UNKNOWN"}}) diff --git a/api/tests/unit_tests/core/workflow/entities/test_template.py b/api/tests/unit_tests/core/workflow/entities/test_template.py deleted file mode 100644 index e8304b9bcd..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_template.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Tests for template module.""" - -from graphon.nodes.base.template import Template, TextSegment, VariableSegment - - -class TestTemplate: - """Test Template class functionality.""" - - def test_from_answer_template_simple(self): - """Test parsing a simple answer template.""" - template_str = "Hello, {{#node1.name#}}!" - template = Template.from_answer_template(template_str) - - assert len(template.segments) == 3 - assert isinstance(template.segments[0], TextSegment) - assert template.segments[0].text == "Hello, " - assert isinstance(template.segments[1], VariableSegment) - assert template.segments[1].selector == ["node1", "name"] - assert isinstance(template.segments[2], TextSegment) - assert template.segments[2].text == "!" - - def test_from_answer_template_multiple_vars(self): - """Test parsing an answer template with multiple variables.""" - template_str = "Hello {{#node1.name#}}, your age is {{#node2.age#}}." - template = Template.from_answer_template(template_str) - - assert len(template.segments) == 5 - assert isinstance(template.segments[0], TextSegment) - assert template.segments[0].text == "Hello " - assert isinstance(template.segments[1], VariableSegment) - assert template.segments[1].selector == ["node1", "name"] - assert isinstance(template.segments[2], TextSegment) - assert template.segments[2].text == ", your age is " - assert isinstance(template.segments[3], VariableSegment) - assert template.segments[3].selector == ["node2", "age"] - assert isinstance(template.segments[4], TextSegment) - assert template.segments[4].text == "." - - def test_from_answer_template_no_vars(self): - """Test parsing an answer template with no variables.""" - template_str = "Hello, world!" - template = Template.from_answer_template(template_str) - - assert len(template.segments) == 1 - assert isinstance(template.segments[0], TextSegment) - assert template.segments[0].text == "Hello, world!" - - def test_from_end_outputs_single(self): - """Test creating template from End node outputs with single variable.""" - outputs_config = [{"variable": "text", "value_selector": ["node1", "text"]}] - template = Template.from_end_outputs(outputs_config) - - assert len(template.segments) == 1 - assert isinstance(template.segments[0], VariableSegment) - assert template.segments[0].selector == ["node1", "text"] - - def test_from_end_outputs_multiple(self): - """Test creating template from End node outputs with multiple variables.""" - outputs_config = [ - {"variable": "text", "value_selector": ["node1", "text"]}, - {"variable": "result", "value_selector": ["node2", "result"]}, - ] - template = Template.from_end_outputs(outputs_config) - - assert len(template.segments) == 3 - assert isinstance(template.segments[0], VariableSegment) - assert template.segments[0].selector == ["node1", "text"] - assert template.segments[0].variable_name == "text" - assert isinstance(template.segments[1], TextSegment) - assert template.segments[1].text == "\n" - assert isinstance(template.segments[2], VariableSegment) - assert template.segments[2].selector == ["node2", "result"] - assert template.segments[2].variable_name == "result" - - def test_from_end_outputs_empty(self): - """Test creating template from empty End node outputs.""" - outputs_config = [] - template = Template.from_end_outputs(outputs_config) - - assert len(template.segments) == 0 - - def test_template_str_representation(self): - """Test string representation of template.""" - template_str = "Hello, {{#node1.name#}}!" - template = Template.from_answer_template(template_str) - - assert str(template) == template_str diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py deleted file mode 100644 index 7e08751683..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py +++ /dev/null @@ -1,136 +0,0 @@ -from graphon.runtime import VariablePool -from graphon.variables.segments import ( - BooleanSegment, - IntegerSegment, - NoneSegment, - StringSegment, -) - - -class TestVariablePoolGetAndNestedAttribute: - # - # _get_nested_attribute tests - # - def test__get_nested_attribute_existing_key(self): - pool = VariablePool.empty() - obj = {"a": 123} - segment = pool._get_nested_attribute(obj, "a") - assert segment is not None - assert segment.value == 123 - - def test__get_nested_attribute_missing_key(self): - pool = VariablePool.empty() - obj = {"a": 123} - segment = pool._get_nested_attribute(obj, "b") - assert segment is None - - def test__get_nested_attribute_non_dict(self): - pool = VariablePool.empty() - obj = ["not", "a", "dict"] - segment = pool._get_nested_attribute(obj, "a") - assert segment is None - - def test__get_nested_attribute_with_none_value(self): - pool = VariablePool.empty() - obj = {"a": None} - segment = pool._get_nested_attribute(obj, "a") - assert segment is not None - assert isinstance(segment, NoneSegment) - - def test__get_nested_attribute_with_empty_string(self): - pool = VariablePool.empty() - obj = {"a": ""} - segment = pool._get_nested_attribute(obj, "a") - assert segment is not None - assert isinstance(segment, StringSegment) - assert segment.value == "" - - # - # get tests - # - def test_get_simple_variable(self): - pool = VariablePool.empty() - pool.add(("node1", "var1"), "value1") - segment = pool.get(("node1", "var1")) - assert segment is not None - assert segment.value == "value1" - - def test_get_missing_variable(self): - pool = VariablePool.empty() - result = pool.get(("node1", "unknown")) - assert result is None - - def test_get_with_too_short_selector(self): - pool = VariablePool.empty() - result = pool.get(("only_node",)) - assert result is None - - def test_get_nested_object_attribute(self): - pool = VariablePool.empty() - obj_value = {"inner": "hello"} - pool.add(("node1", "obj"), obj_value) - - # simulate selector with nested attr - segment = pool.get(("node1", "obj", "inner")) - assert segment is not None - assert segment.value == "hello" - - def test_get_nested_object_missing_attribute(self): - pool = VariablePool.empty() - obj_value = {"inner": "hello"} - pool.add(("node1", "obj"), obj_value) - - result = pool.get(("node1", "obj", "not_exist")) - assert result is None - - def test_get_nested_object_attribute_with_falsy_values(self): - pool = VariablePool.empty() - obj_value = { - "inner_none": None, - "inner_empty": "", - "inner_zero": 0, - "inner_false": False, - } - pool.add(("node1", "obj"), obj_value) - - segment_none = pool.get(("node1", "obj", "inner_none")) - assert segment_none is not None - assert isinstance(segment_none, NoneSegment) - - segment_empty = pool.get(("node1", "obj", "inner_empty")) - assert segment_empty is not None - assert isinstance(segment_empty, StringSegment) - assert segment_empty.value == "" - - segment_zero = pool.get(("node1", "obj", "inner_zero")) - assert segment_zero is not None - assert isinstance(segment_zero, IntegerSegment) - assert segment_zero.value == 0 - - segment_false = pool.get(("node1", "obj", "inner_false")) - assert segment_false is not None - assert isinstance(segment_false, BooleanSegment) - assert segment_false.value is False - - -class TestVariablePoolGetNotModifyVariableDictionary: - _NODE_ID = "start" - _VAR_NAME = "name" - - def test_convert_to_template_should_not_introduce_extra_keys(self): - pool = VariablePool.empty() - pool.add([self._NODE_ID, self._VAR_NAME], 0) - pool.convert_template("The start.name is {{#start.name#}}") - assert "The start" not in pool.variable_dictionary - - def test_get_should_not_modify_variable_dictionary(self): - pool = VariablePool.empty() - pool.get([self._NODE_ID, self._VAR_NAME]) - assert len(pool.variable_dictionary) == 0 - assert "start" not in pool.variable_dictionary - - pool = VariablePool.empty() - pool.add([self._NODE_ID, self._VAR_NAME], "Joe") - pool.get([self._NODE_ID, "count"]) - start_subdict = pool.variable_dictionary[self._NODE_ID] - assert "count" not in start_subdict diff --git a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py deleted file mode 100644 index 5e697f22f3..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py +++ /dev/null @@ -1,225 +0,0 @@ -""" -Unit tests for WorkflowNodeExecution domain model, focusing on process_data truncation functionality. -""" - -from dataclasses import dataclass -from datetime import datetime -from typing import Any - -import pytest - -from graphon.entities.workflow_node_execution import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes - - -class TestWorkflowNodeExecutionProcessDataTruncation: - """Test process_data truncation functionality in WorkflowNodeExecution domain model.""" - - def create_workflow_node_execution( - self, - process_data: dict[str, Any] | None = None, - ) -> WorkflowNodeExecution: - """Create a WorkflowNodeExecution instance for testing.""" - return WorkflowNodeExecution( - id="test-execution-id", - workflow_id="test-workflow-id", - index=1, - node_id="test-node-id", - node_type=BuiltinNodeTypes.LLM, - title="Test Node", - process_data=process_data, - created_at=datetime.now(), - ) - - def test_initial_process_data_truncated_state(self): - """Test that process_data_truncated returns False initially.""" - execution = self.create_workflow_node_execution() - - assert execution.process_data_truncated is False - assert execution.get_truncated_process_data() is None - - def test_set_and_get_truncated_process_data(self): - """Test setting and getting truncated process_data.""" - execution = self.create_workflow_node_execution() - test_truncated_data = {"truncated": True, "key": "value"} - - execution.set_truncated_process_data(test_truncated_data) - - assert execution.process_data_truncated is True - assert execution.get_truncated_process_data() == test_truncated_data - - def test_set_truncated_process_data_to_none(self): - """Test setting truncated process_data to None.""" - execution = self.create_workflow_node_execution() - - # First set some data - execution.set_truncated_process_data({"key": "value"}) - assert execution.process_data_truncated is True - - # Then set to None - execution.set_truncated_process_data(None) - assert execution.process_data_truncated is False - assert execution.get_truncated_process_data() is None - - def test_get_response_process_data_with_no_truncation(self): - """Test get_response_process_data when no truncation is set.""" - original_data = {"original": True, "data": "value"} - execution = self.create_workflow_node_execution(process_data=original_data) - - response_data = execution.get_response_process_data() - - assert response_data == original_data - assert execution.process_data_truncated is False - - def test_get_response_process_data_with_truncation(self): - """Test get_response_process_data when truncation is set.""" - original_data = {"original": True, "large_data": "x" * 10000} - truncated_data = {"original": True, "large_data": "[TRUNCATED]"} - - execution = self.create_workflow_node_execution(process_data=original_data) - execution.set_truncated_process_data(truncated_data) - - response_data = execution.get_response_process_data() - - # Should return truncated data, not original - assert response_data == truncated_data - assert response_data != original_data - assert execution.process_data_truncated is True - - def test_get_response_process_data_with_none_process_data(self): - """Test get_response_process_data when process_data is None.""" - execution = self.create_workflow_node_execution(process_data=None) - - response_data = execution.get_response_process_data() - - assert response_data is None - assert execution.process_data_truncated is False - - def test_consistency_with_inputs_outputs_pattern(self): - """Test that process_data truncation follows the same pattern as inputs/outputs.""" - execution = self.create_workflow_node_execution() - - # Test that all truncation methods exist and behave consistently - test_data = {"test": "data"} - - # Test inputs truncation - execution.set_truncated_inputs(test_data) - assert execution.inputs_truncated is True - assert execution.get_truncated_inputs() == test_data - - # Test outputs truncation - execution.set_truncated_outputs(test_data) - assert execution.outputs_truncated is True - assert execution.get_truncated_outputs() == test_data - - # Test process_data truncation - execution.set_truncated_process_data(test_data) - assert execution.process_data_truncated is True - assert execution.get_truncated_process_data() == test_data - - @pytest.mark.parametrize( - "test_data", - [ - {"simple": "value"}, - {"nested": {"key": "value"}}, - {"list": [1, 2, 3]}, - {"mixed": {"string": "value", "number": 42, "list": [1, 2]}}, - {}, # empty dict - ], - ) - def test_truncated_process_data_with_various_data_types(self, test_data): - """Test that truncated process_data works with various data types.""" - execution = self.create_workflow_node_execution() - - execution.set_truncated_process_data(test_data) - - assert execution.process_data_truncated is True - assert execution.get_truncated_process_data() == test_data - assert execution.get_response_process_data() == test_data - - -@dataclass -class ProcessDataScenario: - """Test scenario data for process_data functionality.""" - - name: str - original_data: dict[str, Any] | None - truncated_data: dict[str, Any] | None - expected_truncated_flag: bool - expected_response_data: dict[str, Any] | None - - -class TestWorkflowNodeExecutionProcessDataScenarios: - """Test various scenarios for process_data handling.""" - - def get_process_data_scenarios(self) -> list[ProcessDataScenario]: - """Create test scenarios for process_data functionality.""" - return [ - ProcessDataScenario( - name="no_process_data", - original_data=None, - truncated_data=None, - expected_truncated_flag=False, - expected_response_data=None, - ), - ProcessDataScenario( - name="process_data_without_truncation", - original_data={"small": "data"}, - truncated_data=None, - expected_truncated_flag=False, - expected_response_data={"small": "data"}, - ), - ProcessDataScenario( - name="process_data_with_truncation", - original_data={"large": "x" * 10000, "metadata": "info"}, - truncated_data={"large": "[TRUNCATED]", "metadata": "info"}, - expected_truncated_flag=True, - expected_response_data={"large": "[TRUNCATED]", "metadata": "info"}, - ), - ProcessDataScenario( - name="empty_process_data", - original_data={}, - truncated_data=None, - expected_truncated_flag=False, - expected_response_data={}, - ), - ProcessDataScenario( - name="complex_nested_data_with_truncation", - original_data={ - "config": {"setting": "value"}, - "logs": ["log1", "log2"] * 1000, # Large list - "status": "running", - }, - truncated_data={"config": {"setting": "value"}, "logs": "[TRUNCATED: 2000 items]", "status": "running"}, - expected_truncated_flag=True, - expected_response_data={ - "config": {"setting": "value"}, - "logs": "[TRUNCATED: 2000 items]", - "status": "running", - }, - ), - ] - - @pytest.mark.parametrize( - "scenario", - get_process_data_scenarios(None), - ids=[scenario.name for scenario in get_process_data_scenarios(None)], - ) - def test_process_data_scenarios(self, scenario: ProcessDataScenario): - """Test various process_data scenarios.""" - execution = WorkflowNodeExecution( - id="test-execution-id", - workflow_id="test-workflow-id", - index=1, - node_id="test-node-id", - node_type=BuiltinNodeTypes.LLM, - title="Test Node", - process_data=scenario.original_data, - created_at=datetime.now(), - ) - - if scenario.truncated_data is not None: - execution.set_truncated_process_data(scenario.truncated_data) - - assert execution.process_data_truncated == scenario.expected_truncated_flag - assert execution.get_response_process_data() == scenario.expected_response_data diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph.py b/api/tests/unit_tests/core/workflow/graph/test_graph.py deleted file mode 100644 index b138a7dfdc..0000000000 --- a/api/tests/unit_tests/core/workflow/graph/test_graph.py +++ /dev/null @@ -1,281 +0,0 @@ -"""Unit tests for Graph class methods.""" - -from unittest.mock import Mock - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState -from graphon.graph.edge import Edge -from graphon.graph.graph import Graph -from graphon.nodes.base.node import Node - - -def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node: - """Create a mock node for testing.""" - node = Mock(spec=Node) - node.id = node_id - node.execution_type = execution_type - node.state = state - node.node_type = BuiltinNodeTypes.START - return node - - -class TestMarkInactiveRootBranches: - """Test cases for _mark_inactive_root_branches method.""" - - def test_single_root_no_marking(self): - """Test that single root graph doesn't mark anything as skipped.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - } - - in_edges = {"child1": ["edge1"]} - out_edges = {"root1": ["edge1"]} - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["child1"].state == NodeState.UNKNOWN - assert edges["edge1"].state == NodeState.UNKNOWN - - def test_multiple_roots_mark_inactive(self): - """Test marking inactive root branches with multiple root nodes.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), - } - - in_edges = {"child1": ["edge1"], "child2": ["edge2"]} - out_edges = {"root1": ["edge1"], "root2": ["edge2"]} - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["child1"].state == NodeState.UNKNOWN - assert nodes["child2"].state == NodeState.SKIPPED - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - - def test_shared_downstream_node(self): - """Test that shared downstream nodes are not skipped if at least one path is active.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - "shared": create_mock_node("shared", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), - "edge3": Edge(id="edge3", tail="child1", head="shared", source_handle="source"), - "edge4": Edge(id="edge4", tail="child2", head="shared", source_handle="source"), - } - - in_edges = { - "child1": ["edge1"], - "child2": ["edge2"], - "shared": ["edge3", "edge4"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "child1": ["edge3"], - "child2": ["edge4"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["child1"].state == NodeState.UNKNOWN - assert nodes["child2"].state == NodeState.SKIPPED - assert nodes["shared"].state == NodeState.UNKNOWN # Not skipped because edge3 is active - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - assert edges["edge3"].state == NodeState.UNKNOWN - assert edges["edge4"].state == NodeState.SKIPPED - - def test_deep_branch_marking(self): - """Test marking deep branches with multiple levels.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "level1_a": create_mock_node("level1_a", NodeExecutionType.EXECUTABLE), - "level1_b": create_mock_node("level1_b", NodeExecutionType.EXECUTABLE), - "level2_a": create_mock_node("level2_a", NodeExecutionType.EXECUTABLE), - "level2_b": create_mock_node("level2_b", NodeExecutionType.EXECUTABLE), - "level3": create_mock_node("level3", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="level1_a", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="level1_b", source_handle="source"), - "edge3": Edge(id="edge3", tail="level1_a", head="level2_a", source_handle="source"), - "edge4": Edge(id="edge4", tail="level1_b", head="level2_b", source_handle="source"), - "edge5": Edge(id="edge5", tail="level2_b", head="level3", source_handle="source"), - } - - in_edges = { - "level1_a": ["edge1"], - "level1_b": ["edge2"], - "level2_a": ["edge3"], - "level2_b": ["edge4"], - "level3": ["edge5"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "level1_a": ["edge3"], - "level1_b": ["edge4"], - "level2_b": ["edge5"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["level1_a"].state == NodeState.UNKNOWN - assert nodes["level1_b"].state == NodeState.SKIPPED - assert nodes["level2_a"].state == NodeState.UNKNOWN - assert nodes["level2_b"].state == NodeState.SKIPPED - assert nodes["level3"].state == NodeState.SKIPPED - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - assert edges["edge3"].state == NodeState.UNKNOWN - assert edges["edge4"].state == NodeState.SKIPPED - assert edges["edge5"].state == NodeState.SKIPPED - - def test_non_root_execution_type(self): - """Test that nodes with non-ROOT execution type are not treated as root nodes.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "non_root": create_mock_node("non_root", NodeExecutionType.EXECUTABLE), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="non_root", head="child2", source_handle="source"), - } - - in_edges = {"child1": ["edge1"], "child2": ["edge2"]} - out_edges = {"root1": ["edge1"], "non_root": ["edge2"]} - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["non_root"].state == NodeState.UNKNOWN # Not marked as skipped - assert nodes["child1"].state == NodeState.UNKNOWN - assert nodes["child2"].state == NodeState.UNKNOWN - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.UNKNOWN - - def test_empty_graph(self): - """Test handling of empty graph structures.""" - nodes = {} - edges = {} - in_edges = {} - out_edges = {} - - # Should not raise any errors - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "non_existent") - - def test_three_roots_mark_two_inactive(self): - """Test with three root nodes where two should be marked inactive.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "root3": create_mock_node("root3", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - "child3": create_mock_node("child3", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), - "edge3": Edge(id="edge3", tail="root3", head="child3", source_handle="source"), - } - - in_edges = { - "child1": ["edge1"], - "child2": ["edge2"], - "child3": ["edge3"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "root3": ["edge3"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root2") - - assert nodes["root1"].state == NodeState.SKIPPED - assert nodes["root2"].state == NodeState.UNKNOWN # Active root - assert nodes["root3"].state == NodeState.SKIPPED - assert nodes["child1"].state == NodeState.SKIPPED - assert nodes["child2"].state == NodeState.UNKNOWN - assert nodes["child3"].state == NodeState.SKIPPED - assert edges["edge1"].state == NodeState.SKIPPED - assert edges["edge2"].state == NodeState.UNKNOWN - assert edges["edge3"].state == NodeState.SKIPPED - - def test_convergent_paths(self): - """Test convergent paths where multiple inactive branches lead to same node.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "root3": create_mock_node("root3", NodeExecutionType.ROOT), - "mid1": create_mock_node("mid1", NodeExecutionType.EXECUTABLE), - "mid2": create_mock_node("mid2", NodeExecutionType.EXECUTABLE), - "convergent": create_mock_node("convergent", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="mid1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="mid2", source_handle="source"), - "edge3": Edge(id="edge3", tail="root3", head="convergent", source_handle="source"), - "edge4": Edge(id="edge4", tail="mid1", head="convergent", source_handle="source"), - "edge5": Edge(id="edge5", tail="mid2", head="convergent", source_handle="source"), - } - - in_edges = { - "mid1": ["edge1"], - "mid2": ["edge2"], - "convergent": ["edge3", "edge4", "edge5"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "root3": ["edge3"], - "mid1": ["edge4"], - "mid2": ["edge5"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["root3"].state == NodeState.SKIPPED - assert nodes["mid1"].state == NodeState.UNKNOWN - assert nodes["mid2"].state == NodeState.SKIPPED - assert nodes["convergent"].state == NodeState.UNKNOWN # Not skipped due to active path from root1 - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - assert edges["edge3"].state == NodeState.SKIPPED - assert edges["edge4"].state == NodeState.UNKNOWN - assert edges["edge5"].state == NodeState.SKIPPED diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py deleted file mode 100644 index f3eaa1d686..0000000000 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py +++ /dev/null @@ -1,59 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.graph import Graph -from graphon.nodes.base.node import Node - - -def _make_node(node_id: str, node_type: NodeType = BuiltinNodeTypes.START) -> Node: - node = MagicMock(spec=Node) - node.id = node_id - node.node_type = node_type - node.execution_type = None # attribute not used in builder path - return node - - -def test_graph_builder_creates_linear_graph(): - builder = Graph.new() - root = _make_node("root", BuiltinNodeTypes.START) - mid = _make_node("mid", BuiltinNodeTypes.LLM) - end = _make_node("end", BuiltinNodeTypes.END) - - graph = builder.add_root(root).add_node(mid).add_node(end).build() - - assert graph.root_node is root - assert graph.nodes == {"root": root, "mid": mid, "end": end} - assert len(graph.edges) == 2 - first_edge = next(iter(graph.edges.values())) - assert first_edge.tail == "root" - assert first_edge.head == "mid" - assert graph.out_edges["mid"] == [edge_id for edge_id, edge in graph.edges.items() if edge.tail == "mid"] - - -def test_graph_builder_supports_custom_predecessor(): - builder = Graph.new() - root = _make_node("root") - branch = _make_node("branch") - other = _make_node("other") - - graph = builder.add_root(root).add_node(branch).add_node(other, from_node_id="root").build() - - outgoing_root = graph.out_edges["root"] - assert len(outgoing_root) == 2 - edge_targets = {graph.edges[eid].head for eid in outgoing_root} - assert edge_targets == {"branch", "other"} - - -def test_graph_builder_validates_usage(): - builder = Graph.new() - node = _make_node("node") - - with pytest.raises(ValueError, match="Root node"): - builder.add_node(node) - - builder.add_root(node) - duplicate = _make_node("node") - with pytest.raises(ValueError, match="Duplicate"): - builder.add_node(duplicate) diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py deleted file mode 100644 index 3620a20e56..0000000000 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py +++ /dev/null @@ -1,118 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import pytest - -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import default_system_variables -from graphon.graph import Graph -from graphon.graph.validation import GraphValidationError -from graphon.nodes import BuiltinNodeTypes -from graphon.runtime import GraphRuntimeState, VariablePool -from tests.workflow_test_utils import build_test_graph_init_params - - -def _build_iteration_graph(node_id: str) -> dict[str, Any]: - return { - "nodes": [ - { - "id": node_id, - "data": { - "type": "iteration", - "title": "Iteration", - "iterator_selector": ["start", "items"], - "output_selector": [node_id, "output"], - }, - } - ], - "edges": [], - } - - -def _build_loop_graph(node_id: str) -> dict[str, Any]: - return { - "nodes": [ - { - "id": node_id, - "data": { - "type": "loop", - "title": "Loop", - "loop_count": 1, - "break_conditions": [], - "logical_operator": "and", - "loop_variables": [], - "outputs": {}, - }, - } - ], - "edges": [], - } - - -def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory: - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool( - system_variables=default_system_variables(), - user_inputs={}, - environment_variables=[], - ), - start_at=0.0, - ) - return DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) - - -def test_iteration_root_requires_skip_validation(): - node_id = "iteration-node" - graph_config = _build_iteration_graph(node_id) - node_factory = _make_factory(graph_config) - - with pytest.raises(GraphValidationError): - Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=node_id, - ) - - graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=node_id, - skip_validation=True, - ) - - assert graph.root_node.id == node_id - assert graph.root_node.node_type == BuiltinNodeTypes.ITERATION - - -def test_loop_root_requires_skip_validation(): - node_id = "loop-node" - graph_config = _build_loop_graph(node_id) - node_factory = _make_factory(graph_config) - - with pytest.raises(GraphValidationError): - Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=node_id, - ) - - graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=node_id, - skip_validation=True, - ) - - assert graph.root_node.id == node_id - assert graph.root_node.node_type == BuiltinNodeTypes.LOOP diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py deleted file mode 100644 index bfd0b48392..0000000000 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ /dev/null @@ -1,219 +0,0 @@ -from __future__ import annotations - -import time -from collections.abc import Mapping -from dataclasses import dataclass - -import pytest - -from core.workflow.system_variables import build_system_variables -from graphon.entities import GraphInitParams -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeExecutionType, NodeType -from graphon.graph import Graph -from graphon.graph.validation import GraphValidationError -from graphon.nodes.base.node import Node -from graphon.runtime import GraphRuntimeState, VariablePool -from tests.workflow_test_utils import build_test_graph_init_params - - -class _TestNodeData(BaseNodeData): - type: NodeType | None = None - execution_type: NodeExecutionType | str | None = None - - -class _TestNode(Node[_TestNodeData]): - node_type = BuiltinNodeTypes.ANSWER - execution_type = NodeExecutionType.EXECUTABLE - - @classmethod - def version(cls) -> str: - return "1" - - def __init__( - self, - *, - id: str, - config: Mapping[str, object], - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - node_type_value = self.data.get("type") - if isinstance(node_type_value, str): - self.node_type = node_type_value - - def _run(self): - raise NotImplementedError - - def post_init(self) -> None: - super().post_init() - self._maybe_override_execution_type() - self.data = dict(self.node_data.model_dump()) - - def _maybe_override_execution_type(self) -> None: - execution_type_value = self.node_data.execution_type - if execution_type_value is None: - return - if isinstance(execution_type_value, NodeExecutionType): - self.execution_type = execution_type_value - else: - self.execution_type = NodeExecutionType(execution_type_value) - - -@dataclass(slots=True) -class _SimpleNodeFactory: - graph_init_params: GraphInitParams - graph_runtime_state: GraphRuntimeState - - def create_node(self, node_config: Mapping[str, object]) -> _TestNode: - node_id = str(node_config["id"]) - node = _TestNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - ) - return node - - -@pytest.fixture -def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]: - graph_config: dict[str, object] = {"edges": [], "nodes": []} - init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="service-api", - call_depth=0, - ) - variable_pool = VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}) - runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - factory = _SimpleNodeFactory(graph_init_params=init_params, graph_runtime_state=runtime_state) - return factory, graph_config - - -def test_graph_initialization_runs_default_validators( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -): - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - {"id": "answer", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "Answer"}}, - ] - graph_config["edges"] = [ - {"source": "start", "target": "answer", "sourceHandle": "success"}, - ] - - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - assert graph.root_node.id == "start" - assert "answer" in graph.nodes - - -def test_graph_validation_fails_for_unknown_edge_targets( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - ] - graph_config["edges"] = [ - {"source": "start", "target": "missing", "sourceHandle": "success"}, - ] - - with pytest.raises(GraphValidationError) as exc: - Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - assert any(issue.code == "MISSING_NODE" for issue in exc.value.issues) - - -def test_graph_promotes_fail_branch_nodes_to_branch_execution_type( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - { - "id": "branch", - "data": { - "type": BuiltinNodeTypes.IF_ELSE, - "title": "Branch", - "error_strategy": ErrorStrategy.FAIL_BRANCH, - }, - }, - ] - graph_config["edges"] = [ - {"source": "start", "target": "branch", "sourceHandle": "success"}, - ] - - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH - - -def test_graph_init_ignores_custom_note_nodes_before_node_data_validation( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - {"id": "answer", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "Answer"}}, - { - "id": "note", - "type": "custom-note", - "data": { - "type": "", - "title": "", - "desc": "", - "text": "{}", - "theme": "blue", - }, - }, - ] - graph_config["edges"] = [ - {"source": "start", "target": "answer", "sourceHandle": "success"}, - ] - - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - assert graph.root_node.id == "start" - assert "answer" in graph.nodes - assert "note" not in graph.nodes - - -def test_graph_init_fails_for_unknown_root_node_id( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - ] - graph_config["edges"] = [] - - with pytest.raises(ValueError, match="Root node id missing not found in the graph"): - Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="missing") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/README.md b/api/tests/unit_tests/core/workflow/graph_engine/README.md index 960fef7d43..dd419f0810 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/README.md +++ b/api/tests/unit_tests/core/workflow/graph_engine/README.md @@ -1,441 +1,30 @@ -# Graph Engine Testing Framework +# Workflow Graph Engine Smoke Tests -## Overview +This directory now keeps only a small Dify-owned smoke layer around the external +`graphon` package. -This directory contains a comprehensive testing framework for the Graph Engine, including: +Retained coverage focuses on: -1. **TableTestRunner** - Advanced table-driven test framework for workflow testing -1. **Auto-Mock System** - Powerful mocking framework for testing without external dependencies +1. Dify workflow layers: + - `layers/test_llm_quota.py` + - `layers/test_observability.py` +2. Human-input resume integration: + - `test_parallel_human_input_join_resume.py` +3. One mocked tool/chatflow smoke path: + - `test_tool_in_chatflow.py` -## TableTestRunner Framework +The helper modules below remain only because the retained smoke tests use them: -The TableTestRunner (`test_table_runner.py`) provides a robust table-driven testing framework for GraphEngine workflows. +1. `test_mock_config.py` +2. `test_mock_factory.py` +3. `test_mock_nodes.py` +4. `test_table_runner.py` -### Features - -- **Table-driven testing** - Define test cases as structured data -- **Parallel test execution** - Run tests concurrently for faster execution -- **Property-based testing** - Integration with Hypothesis for fuzzing -- **Event sequence validation** - Verify correct event ordering -- **Mock configuration** - Seamless integration with the auto-mock system -- **Performance metrics** - Track execution times and bottlenecks -- **Detailed error reporting** - Comprehensive failure diagnostics - -### Basic Usage - -```python -from test_table_runner import TableTestRunner, WorkflowTestCase - -# Create test runner -runner = TableTestRunner() - -# Define test case -test_case = WorkflowTestCase( - fixture_path="simple_workflow", - inputs={"query": "Hello"}, - expected_outputs={"result": "World"}, - description="Basic workflow test", -) - -# Run single test -result = runner.run_test_case(test_case) -assert result.success -``` - -### Advanced Features - -#### Parallel Execution - -```python -runner = TableTestRunner(max_workers=8) - -test_cases = [ - WorkflowTestCase(...), - WorkflowTestCase(...), - # ... more test cases -] - -# Run tests in parallel -suite_result = runner.run_table_tests( - test_cases, - parallel=True, - fail_fast=False -) - -print(f"Success rate: {suite_result.success_rate:.1f}%") -``` - -#### Event Sequence Validation - -```python -from graphon.graph_events import ( - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, -) - -test_case = WorkflowTestCase( - fixture_path="workflow", - inputs={}, - expected_outputs={}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ] -) -``` - -### Test Suite Reports - -```python -# Run test suite -suite_result = runner.run_table_tests(test_cases) - -# Generate detailed report -report = runner.generate_report(suite_result) -print(report) - -# Access specific results -failed_results = suite_result.get_failed_results() -for result in failed_results: - print(f"Failed: {result.test_case.description}") - print(f" Error: {result.error}") -``` - -### Performance Testing - -```python -# Enable logging for performance insights -runner = TableTestRunner( - enable_logging=True, - log_level="DEBUG" -) - -# Run tests and analyze performance -suite_result = runner.run_table_tests(test_cases) - -# Get slowest tests -sorted_results = sorted( - suite_result.results, - key=lambda r: r.execution_time, - reverse=True -) - -print("Slowest tests:") -for result in sorted_results[:5]: - print(f" {result.test_case.description}: {result.execution_time:.2f}s") -``` - -## Integration: TableTestRunner + Auto-Mock System - -The TableTestRunner seamlessly integrates with the auto-mock system for comprehensive workflow testing: - -```python -from test_table_runner import TableTestRunner, WorkflowTestCase -from test_mock_config import MockConfigBuilder - -# Configure mocks -mock_config = (MockConfigBuilder() - .with_llm_response("Mocked LLM response") - .with_tool_response({"result": "mocked"}) - .with_delays(True) # Simulate realistic delays - .build()) - -# Create test case with mocking -test_case = WorkflowTestCase( - fixture_path="complex_workflow", - inputs={"query": "test"}, - expected_outputs={"answer": "Mocked LLM response"}, - use_auto_mock=True, # Enable auto-mocking - mock_config=mock_config, - description="Test with mocked services", -) - -# Run test -runner = TableTestRunner() -result = runner.run_test_case(test_case) -``` - -## Auto-Mock System - -The auto-mock system provides a powerful framework for testing workflows that contain nodes requiring third-party services (LLM, APIs, tools, etc.) without making actual external calls. This enables: - -- **Fast test execution** - No network latency or API rate limits -- **Deterministic results** - Consistent outputs for reliable testing -- **Cost savings** - No API usage charges during testing -- **Offline testing** - Tests can run without internet connectivity -- **Error simulation** - Test error handling without triggering real failures - -## Architecture - -The auto-mock system consists of three main components: - -### 1. MockNodeFactory (`test_mock_factory.py`) - -- Extends `DifyNodeFactory` to intercept node creation -- Automatically detects nodes requiring third-party services -- Returns mock node implementations instead of real ones -- Supports registration of custom mock implementations - -### 2. Mock Node Implementations (`test_mock_nodes.py`) - -- `MockLLMNode` - Mocks LLM API calls (OpenAI, Anthropic, etc.) -- `MockAgentNode` - Mocks agent execution -- `MockToolNode` - Mocks tool invocations -- `MockKnowledgeRetrievalNode` - Mocks knowledge base queries -- `MockHttpRequestNode` - Mocks HTTP requests -- `MockParameterExtractorNode` - Mocks parameter extraction -- `MockDocumentExtractorNode` - Mocks document processing -- `MockQuestionClassifierNode` - Mocks question classification - -### 3. Mock Configuration (`test_mock_config.py`) - -- `MockConfig` - Global configuration for mock behavior -- `NodeMockConfig` - Node-specific mock configuration -- `MockConfigBuilder` - Fluent interface for building configurations - -## Usage - -### Basic Example - -```python -from test_graph_engine import TableTestRunner, WorkflowTestCase -from test_mock_config import MockConfigBuilder - -# Create test runner -runner = TableTestRunner() - -# Configure mock responses -mock_config = (MockConfigBuilder() - .with_llm_response("Mocked LLM response") - .build()) - -# Define test case -test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Hello"}, - expected_outputs={"answer": "Mocked LLM response"}, - use_auto_mock=True, # Enable auto-mocking - mock_config=mock_config, -) - -# Run test -result = runner.run_test_case(test_case) -assert result.success -``` - -### Custom Node Outputs - -```python -# Configure specific outputs for individual nodes -mock_config = MockConfig() -mock_config.set_node_outputs("llm_node_123", { - "text": "Custom response for this specific node", - "usage": {"total_tokens": 50}, - "finish_reason": "stop", -}) -``` - -### Error Simulation - -```python -# Simulate node failures for error handling tests -mock_config = MockConfig() -mock_config.set_node_error("http_node", "Connection timeout") -``` - -### Simulated Delays - -```python -# Add realistic execution delays -from test_mock_config import NodeMockConfig - -node_config = NodeMockConfig( - node_id="llm_node", - outputs={"text": "Response"}, - delay=1.5, # 1.5 second delay -) -mock_config.set_node_config("llm_node", node_config) -``` - -### Custom Handlers - -```python -# Define custom logic for mock outputs -def custom_handler(node): - # Access node state and return dynamic outputs - return { - "text": f"Processed: {node.graph_runtime_state.variable_pool.get('query')}", - } - -node_config = NodeMockConfig( - node_id="llm_node", - custom_handler=custom_handler, -) -``` - -## Node Types Automatically Mocked - -The following node types are automatically mocked when `use_auto_mock=True`: - -- `LLM` - Language model nodes -- `AGENT` - Agent execution nodes -- `TOOL` - Tool invocation nodes -- `KNOWLEDGE_RETRIEVAL` - Knowledge base query nodes -- `HTTP_REQUEST` - HTTP request nodes -- `PARAMETER_EXTRACTOR` - Parameter extraction nodes -- `DOCUMENT_EXTRACTOR` - Document processing nodes -- `QUESTION_CLASSIFIER` - Question classification nodes - -## Advanced Features - -### Registering Custom Mock Implementations - -```python -from test_mock_factory import MockNodeFactory - -# Create custom mock implementation -class CustomMockNode(BaseNode): - def _run(self): - # Custom mock logic - pass - -# Register for a specific node type -factory = MockNodeFactory(...) -factory.register_mock_node_type(NodeType.CUSTOM, CustomMockNode) -``` - -### Default Configurations by Node Type - -```python -# Set defaults for all nodes of a specific type -mock_config.set_default_config(NodeType.LLM, { - "temperature": 0.7, - "max_tokens": 100, -}) -``` - -### MockConfigBuilder Fluent API - -```python -config = (MockConfigBuilder() - .with_llm_response("LLM response") - .with_agent_response("Agent response") - .with_tool_response({"result": "data"}) - .with_retrieval_response("Retrieved content") - .with_http_response({"status_code": 200, "body": "{}"}) - .with_node_output("node_id", {"output": "value"}) - .with_node_error("error_node", "Error message") - .with_delays(True) - .build()) -``` - -## Testing Workflows - -### 1. Create Workflow Fixture - -Create a YAML fixture file in `api/tests/fixtures/workflow/` directory defining your workflow graph. - -### 2. Configure Mocks - -Set up mock configurations for nodes that need third-party services. - -### 3. Define Test Cases - -Create `WorkflowTestCase` instances with inputs, expected outputs, and mock config. - -### 4. Run Tests - -Use `TableTestRunner` to execute test cases and validate results. - -## Best Practices - -1. **Use descriptive mock responses** - Make it clear in outputs that they are mocked -1. **Test both success and failure paths** - Use error simulation to test error handling -1. **Keep mock configs close to tests** - Define mocks in the same test file for clarity -1. **Use custom handlers sparingly** - Only when dynamic behavior is needed -1. **Document mock behavior** - Comment why specific mock values are chosen -1. **Validate mock accuracy** - Ensure mocks reflect real service behavior - -## Examples - -See `test_mock_example.py` for comprehensive examples including: - -- Basic LLM workflow testing -- Custom node outputs -- HTTP and tool workflow testing -- Error simulation -- Performance testing with delays - -## Running Tests - -### TableTestRunner Tests +Examples: ```bash -# Run graph engine tests (includes property-based tests) -uv run pytest api/tests/unit_tests/graphon/graph_engine/test_graph_engine.py - -# Run with specific test patterns -uv run pytest api/tests/unit_tests/graphon/graph_engine/test_graph_engine.py -k "test_echo" - -# Run with verbose output -uv run pytest api/tests/unit_tests/graphon/graph_engine/test_graph_engine.py -v +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py ``` - -### Mock System Tests - -```bash -# Run auto-mock system tests -uv run pytest api/tests/unit_tests/graphon/graph_engine/test_auto_mock_system.py - -# Run examples -uv run python api/tests/unit_tests/graphon/graph_engine/test_mock_example.py - -# Run simple validation -uv run python api/tests/unit_tests/graphon/graph_engine/test_mock_simple.py -``` - -### All Tests - -```bash -# Run all graph engine tests -uv run pytest api/tests/unit_tests/graphon/graph_engine/ - -# Run with coverage -uv run pytest api/tests/unit_tests/graphon/graph_engine/ --cov=graphon.graph_engine - -# Run in parallel -uv run pytest api/tests/unit_tests/graphon/graph_engine/ -n auto -``` - -## Troubleshooting - -### Issue: Mock not being applied - -- Ensure `use_auto_mock=True` in `WorkflowTestCase` -- Verify node ID matches in mock config -- Check that node type is in the auto-mock list - -### Issue: Unexpected outputs - -- Debug by printing `result.actual_outputs` -- Check if custom handler is overriding expected outputs -- Verify mock config is properly built - -### Issue: Import errors - -- Ensure all mock modules are in the correct path -- Check that required dependencies are installed - -## Future Enhancements - -Potential improvements to the auto-mock system: - -1. **Recording and playback** - Record real API responses for replay in tests -1. **Mock templates** - Pre-defined mock configurations for common scenarios -1. **Async support** - Better support for async node execution -1. **Mock validation** - Validate mock outputs against node schemas -1. **Performance profiling** - Built-in performance metrics for mocked workflows diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py deleted file mode 100644 index 795362b158..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ /dev/null @@ -1,315 +0,0 @@ -"""Tests for Redis command channel implementation.""" - -import json -from unittest.mock import MagicMock - -from graphon.graph_engine.command_channels.redis_channel import RedisChannel -from graphon.graph_engine.entities.commands import ( - AbortCommand, - CommandType, - GraphEngineCommand, - UpdateVariablesCommand, - VariableUpdate, -) -from graphon.variables import IntegerVariable, StringVariable - - -class TestRedisChannel: - """Test suite for RedisChannel functionality.""" - - def test_init(self): - """Test RedisChannel initialization.""" - mock_redis = MagicMock() - channel_key = "test:channel:key" - ttl = 7200 - - channel = RedisChannel(mock_redis, channel_key, ttl) - - assert channel._redis == mock_redis - assert channel._key == channel_key - assert channel._command_ttl == ttl - - def test_init_default_ttl(self): - """Test RedisChannel initialization with default TTL.""" - mock_redis = MagicMock() - channel_key = "test:channel:key" - - channel = RedisChannel(mock_redis, channel_key) - - assert channel._command_ttl == 3600 # Default TTL - - def test_send_command(self): - """Test sending a command to Redis.""" - mock_redis = MagicMock() - mock_pipe = MagicMock() - context = MagicMock() - context.__enter__.return_value = mock_pipe - context.__exit__.return_value = None - mock_redis.pipeline.return_value = context - - channel = RedisChannel(mock_redis, "test:key", 3600) - - pending_key = "test:key:pending" - - # Create a test command - command = GraphEngineCommand(command_type=CommandType.ABORT) - - # Send the command - channel.send_command(command) - - # Verify pipeline was used - mock_redis.pipeline.assert_called_once() - - # Verify rpush was called with correct data - expected_json = json.dumps(command.model_dump()) - mock_pipe.rpush.assert_called_once_with("test:key", expected_json) - - # Verify expire was set - mock_pipe.expire.assert_called_once_with("test:key", 3600) - mock_pipe.set.assert_called_once_with(pending_key, "1", ex=3600) - - # Verify execute was called - mock_pipe.execute.assert_called_once() - - def test_fetch_commands_empty(self): - """Test fetching commands when Redis list is empty.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context] - - # No pending marker - pending_pipe.execute.return_value = [None, 0] - mock_redis.llen.return_value = 0 - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert commands == [] - mock_redis.pipeline.assert_called_once() - fetch_pipe.lrange.assert_not_called() - fetch_pipe.delete.assert_not_called() - - def test_fetch_commands_with_abort_command(self): - """Test fetching abort commands from Redis.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Create abort command data - abort_command = AbortCommand() - command_json = json.dumps(abort_command.model_dump()) - - # Simulate Redis returning one command - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert len(commands) == 1 - assert isinstance(commands[0], AbortCommand) - assert commands[0].command_type == CommandType.ABORT - - def test_fetch_commands_multiple(self): - """Test fetching multiple commands from Redis.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Create multiple commands - command1 = GraphEngineCommand(command_type=CommandType.ABORT) - command2 = AbortCommand() - - command1_json = json.dumps(command1.model_dump()) - command2_json = json.dumps(command2.model_dump()) - - # Simulate Redis returning multiple commands - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command1_json.encode(), command2_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert len(commands) == 2 - assert commands[0].command_type == CommandType.ABORT - assert isinstance(commands[1], AbortCommand) - - def test_fetch_commands_with_update_variables_command(self): - """Test fetching update variables command from Redis.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - update_command = UpdateVariablesCommand( - updates=[ - VariableUpdate( - value=StringVariable(name="foo", value="bar", selector=["node1", "foo"]), - ), - VariableUpdate( - value=IntegerVariable(name="baz", value=123, selector=["node2", "baz"]), - ), - ] - ) - command_json = json.dumps(update_command.model_dump()) - - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert len(commands) == 1 - assert isinstance(commands[0], UpdateVariablesCommand) - assert isinstance(commands[0].updates[0].value, StringVariable) - assert list(commands[0].updates[0].value.selector) == ["node1", "foo"] - assert commands[0].updates[0].value.value == "bar" - - def test_fetch_commands_skips_invalid_json(self): - """Test that invalid JSON commands are skipped.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Mix valid and invalid JSON - valid_command = AbortCommand() - valid_json = json.dumps(valid_command.model_dump()) - invalid_json = b"invalid json {" - - # Simulate Redis returning mixed valid/invalid commands - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[invalid_json, valid_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - # Should only return the valid command - assert len(commands) == 1 - assert isinstance(commands[0], AbortCommand) - - def test_deserialize_command_abort(self): - """Test deserializing an abort command.""" - channel = RedisChannel(MagicMock(), "test:key") - - abort_data = {"command_type": CommandType.ABORT} - command = channel._deserialize_command(abort_data) - - assert isinstance(command, AbortCommand) - assert command.command_type == CommandType.ABORT - - def test_deserialize_command_generic(self): - """Test deserializing a generic command.""" - channel = RedisChannel(MagicMock(), "test:key") - - # For now, only ABORT is supported, but test generic handling - generic_data = {"command_type": CommandType.ABORT} - command = channel._deserialize_command(generic_data) - - assert command is not None - assert command.command_type == CommandType.ABORT - - def test_deserialize_command_invalid(self): - """Test deserializing invalid command data.""" - channel = RedisChannel(MagicMock(), "test:key") - - # Missing command_type - invalid_data = {"some_field": "value"} - command = channel._deserialize_command(invalid_data) - - assert command is None - - def test_deserialize_command_invalid_type(self): - """Test deserializing command with invalid type.""" - channel = RedisChannel(MagicMock(), "test:key") - - # Invalid command type - invalid_data = {"command_type": "INVALID_TYPE"} - command = channel._deserialize_command(invalid_data) - - assert command is None - - def test_atomic_fetch_and_clear(self): - """Test that fetch_commands atomically fetches and clears the list.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - command = AbortCommand() - command_json = json.dumps(command.model_dump()) - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - - # First fetch should return the command - commands = channel.fetch_commands() - assert len(commands) == 1 - - # Verify both lrange and delete were called in the pipeline - assert fetch_pipe.lrange.call_count == 1 - assert fetch_pipe.delete.call_count == 1 - fetch_pipe.lrange.assert_called_with("test:key", 0, -1) - fetch_pipe.delete.assert_called_with("test:key") - - def test_fetch_commands_without_pending_marker_returns_empty(self): - """Ensure we avoid unnecessary list reads when pending flag is missing.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Pending flag absent - pending_pipe.execute.return_value = [None, 0] - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert commands == [] - mock_redis.llen.assert_not_called() - assert mock_redis.pipeline.call_count == 1 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py deleted file mode 100644 index cacbe9ba4e..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Tests for graph engine event handlers.""" - -from __future__ import annotations - -from graphon.entities.base_node_data import RetryConfig -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.graph_engine.domain.graph_execution import GraphExecution -from graphon.graph_engine.event_management.event_handlers import EventHandler -from graphon.graph_engine.event_management.event_manager import EventManager -from graphon.graph_engine.graph_state_manager import GraphStateManager -from graphon.graph_engine.ready_queue.in_memory import InMemoryReadyQueue -from graphon.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator -from graphon.graph_events import NodeRunRetryEvent, NodeRunStartedEvent -from graphon.node_events import NodeRunResult -from graphon.runtime import GraphRuntimeState, VariablePool -from libs.datetime_utils import naive_utc_now - - -class _StubEdgeProcessor: - """Minimal edge processor stub for tests.""" - - -class _StubErrorHandler: - """Minimal error handler stub for tests.""" - - -class _StubNode: - """Simple node stub exposing the attributes needed by the state manager.""" - - def __init__(self, node_id: str) -> None: - self.id = node_id - self.state = NodeState.UNKNOWN - self.title = "Stub Node" - self.execution_type = NodeExecutionType.EXECUTABLE - self.error_strategy = None - self.retry_config = RetryConfig() - self.retry = False - - -def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]: - """Construct an EventHandler with in-memory dependencies for testing.""" - - node = _StubNode(node_id) - graph = Graph(nodes={node_id: node}, edges={}, in_edges={}, out_edges={}, root_node=node) - - variable_pool = VariablePool() - runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) - graph_execution = GraphExecution(workflow_id="test-workflow") - - event_manager = EventManager() - state_manager = GraphStateManager(graph=graph, ready_queue=InMemoryReadyQueue()) - response_coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph) - - handler = EventHandler( - graph=graph, - graph_runtime_state=runtime_state, - graph_execution=graph_execution, - response_coordinator=response_coordinator, - event_collector=event_manager, - edge_processor=_StubEdgeProcessor(), - state_manager=state_manager, - error_handler=_StubErrorHandler(), - ) - - return handler, event_manager, graph_execution - - -def test_retry_does_not_emit_additional_start_event() -> None: - """Ensure retry attempts do not produce duplicate start events.""" - - node_id = "test-node" - handler, event_manager, graph_execution = _build_event_handler(node_id) - - execution_id = "exec-1" - node_type = BuiltinNodeTypes.CODE - start_time = naive_utc_now() - - start_event = NodeRunStartedEvent( - id=execution_id, - node_id=node_id, - node_type=node_type, - node_title="Stub Node", - start_at=start_time, - ) - handler.dispatch(start_event) - - retry_event = NodeRunRetryEvent( - id=execution_id, - node_id=node_id, - node_type=node_type, - node_title="Stub Node", - start_at=start_time, - error="boom", - retry_index=1, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error="boom", - error_type="TestError", - ), - ) - handler.dispatch(retry_event) - - # Simulate the node starting execution again after retry - second_start_event = NodeRunStartedEvent( - id=execution_id, - node_id=node_id, - node_type=node_type, - node_title="Stub Node", - start_at=start_time, - ) - handler.dispatch(second_start_event) - - collected_types = [type(event) for event in event_manager._events] # type: ignore[attr-defined] - - assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent] - - node_execution = graph_execution.get_or_create_node_execution(node_id) - assert node_execution.retry_count == 1 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py deleted file mode 100644 index dc0998caf1..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Tests for the EventManager.""" - -from __future__ import annotations - -import logging - -from graphon.graph_engine.event_management.event_manager import EventManager -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent - - -class _FaultyLayer(GraphEngineLayer): - """Layer that raises from on_event to test error handling.""" - - def on_graph_start(self) -> None: # pragma: no cover - not used in tests - pass - - def on_event(self, event: GraphEngineEvent) -> None: - raise RuntimeError("boom") - - def on_graph_end(self, error: Exception | None) -> None: # pragma: no cover - not used in tests - pass - - -def test_event_manager_logs_layer_errors(caplog) -> None: - """Ensure errors raised by layers are logged when collecting events.""" - - event_manager = EventManager() - event_manager.set_layers([_FaultyLayer()]) - - with caplog.at_level(logging.ERROR): - event_manager.collect(GraphEngineEvent()) - - error_logs = [record for record in caplog.records if "Error in layer on_event" in record.getMessage()] - assert error_logs, "Expected layer errors to be logged" - - log_record = error_logs[0] - assert log_record.exc_info is not None - assert isinstance(log_record.exc_info[1], RuntimeError) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py deleted file mode 100644 index cf8811dc2b..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for graph traversal components.""" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py deleted file mode 100644 index b030496eb1..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py +++ /dev/null @@ -1,307 +0,0 @@ -"""Unit tests for skip propagator.""" - -from unittest.mock import MagicMock, create_autospec - -from graphon.graph import Edge, Graph -from graphon.graph_engine.graph_state_manager import GraphStateManager -from graphon.graph_engine.graph_traversal.skip_propagator import SkipPropagator - - -class TestSkipPropagator: - """Test suite for SkipPropagator.""" - - def test_propagate_skip_from_edge_with_unknown_edges_stops_processing(self) -> None: - """When there are unknown incoming edges, propagation should stop.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create a mock edge - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - # Setup graph edges dict - mock_graph.edges = {"edge_1": mock_edge} - - # Setup incoming edges - incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Setup state manager to return has_unknown=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": True, - "has_taken": False, - "all_skipped": False, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - mock_graph.get_incoming_edges.assert_called_once_with("node_2") - mock_state_manager.analyze_edge_states.assert_called_once_with(incoming_edges) - # Should not call any other state manager methods - mock_state_manager.enqueue_node.assert_not_called() - mock_state_manager.start_execution.assert_not_called() - mock_state_manager.mark_node_skipped.assert_not_called() - - def test_propagate_skip_from_edge_with_taken_edge_enqueues_node(self) -> None: - """When there is at least one taken edge, node should be enqueued.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create a mock edge - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - mock_graph.edges = {"edge_1": mock_edge} - incoming_edges = [MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Setup state manager to return has_taken=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": True, - "all_skipped": False, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - mock_state_manager.enqueue_node.assert_called_once_with("node_2") - mock_state_manager.start_execution.assert_called_once_with("node_2") - mock_state_manager.mark_node_skipped.assert_not_called() - - def test_propagate_skip_from_edge_with_all_skipped_propagates_to_node(self) -> None: - """When all incoming edges are skipped, should propagate skip to node.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create a mock edge - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - mock_graph.edges = {"edge_1": mock_edge} - incoming_edges = [MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Setup state manager to return all_skipped=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": False, - "all_skipped": True, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - mock_state_manager.mark_node_skipped.assert_called_once_with("node_2") - mock_state_manager.enqueue_node.assert_not_called() - mock_state_manager.start_execution.assert_not_called() - - def test_propagate_skip_to_node_marks_node_and_outgoing_edges_skipped(self) -> None: - """_propagate_skip_to_node should mark node and all outgoing edges as skipped.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create outgoing edges - edge1 = MagicMock(spec=Edge) - edge1.id = "edge_2" - edge1.head = "node_downstream_1" # Set head for propagate_skip_from_edge - - edge2 = MagicMock(spec=Edge) - edge2.id = "edge_3" - edge2.head = "node_downstream_2" - - # Setup graph edges dict for propagate_skip_from_edge - mock_graph.edges = {"edge_2": edge1, "edge_3": edge2} - mock_graph.get_outgoing_edges.return_value = [edge1, edge2] - - # Setup get_incoming_edges to return empty list to stop recursion - mock_graph.get_incoming_edges.return_value = [] - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Use mock to call private method - # Act - propagator._propagate_skip_to_node("node_1") - - # Assert - mock_state_manager.mark_node_skipped.assert_called_once_with("node_1") - mock_state_manager.mark_edge_skipped.assert_any_call("edge_2") - mock_state_manager.mark_edge_skipped.assert_any_call("edge_3") - assert mock_state_manager.mark_edge_skipped.call_count == 2 - # Should recursively propagate from each edge - # Since propagate_skip_from_edge is called, we need to verify it was called - # But we can't directly verify due to recursion. We'll trust the logic. - - def test_skip_branch_paths_marks_unselected_edges_and_propagates(self) -> None: - """skip_branch_paths should mark all unselected edges as skipped and propagate.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create unselected edges - edge1 = MagicMock(spec=Edge) - edge1.id = "edge_1" - edge1.head = "node_downstream_1" - - edge2 = MagicMock(spec=Edge) - edge2.id = "edge_2" - edge2.head = "node_downstream_2" - - unselected_edges = [edge1, edge2] - - # Setup graph edges dict - mock_graph.edges = {"edge_1": edge1, "edge_2": edge2} - # Setup get_incoming_edges to return empty list to stop recursion - mock_graph.get_incoming_edges.return_value = [] - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.skip_branch_paths(unselected_edges) - - # Assert - mock_state_manager.mark_edge_skipped.assert_any_call("edge_1") - mock_state_manager.mark_edge_skipped.assert_any_call("edge_2") - assert mock_state_manager.mark_edge_skipped.call_count == 2 - # propagate_skip_from_edge should be called for each edge - # We can't directly verify due to the mock, but the logic is covered - - def test_propagate_skip_from_edge_recursively_propagates_through_graph(self) -> None: - """Skip propagation should recursively propagate through the graph.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create edge chain: edge_1 -> node_2 -> edge_3 -> node_4 - edge1 = MagicMock(spec=Edge) - edge1.id = "edge_1" - edge1.head = "node_2" - - edge3 = MagicMock(spec=Edge) - edge3.id = "edge_3" - edge3.head = "node_4" - - mock_graph.edges = {"edge_1": edge1, "edge_3": edge3} - - # Setup get_incoming_edges to return different values based on node - def get_incoming_edges_side_effect(node_id): - if node_id == "node_2": - return [edge1] - elif node_id == "node_4": - return [edge3] - return [] - - mock_graph.get_incoming_edges.side_effect = get_incoming_edges_side_effect - - # Setup get_outgoing_edges to return different values based on node - def get_outgoing_edges_side_effect(node_id): - if node_id == "node_2": - return [edge3] - elif node_id == "node_4": - return [] # No outgoing edges, stops recursion - return [] - - mock_graph.get_outgoing_edges.side_effect = get_outgoing_edges_side_effect - - # Setup state manager to return all_skipped for both nodes - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": False, - "all_skipped": True, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - # Should mark node_2 as skipped - mock_state_manager.mark_node_skipped.assert_any_call("node_2") - # Should mark edge_3 as skipped - mock_state_manager.mark_edge_skipped.assert_any_call("edge_3") - # Should propagate to node_4 - mock_state_manager.mark_node_skipped.assert_any_call("node_4") - assert mock_state_manager.mark_node_skipped.call_count == 2 - - def test_propagate_skip_from_edge_with_mixed_edge_states_handles_correctly(self) -> None: - """Test with mixed edge states (some unknown, some taken, some skipped).""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - mock_graph.edges = {"edge_1": mock_edge} - incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge), MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Test 1: has_unknown=True, has_taken=False, all_skipped=False - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": True, - "has_taken": False, - "all_skipped": False, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - should stop processing - mock_state_manager.enqueue_node.assert_not_called() - mock_state_manager.mark_node_skipped.assert_not_called() - - # Reset mocks for next test - mock_state_manager.reset_mock() - mock_graph.reset_mock() - - # Test 2: has_unknown=False, has_taken=True, all_skipped=False - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": True, - "all_skipped": False, - } - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - should enqueue node - mock_state_manager.enqueue_node.assert_called_once_with("node_2") - mock_state_manager.start_execution.assert_called_once_with("node_2") - - # Reset mocks for next test - mock_state_manager.reset_mock() - mock_graph.reset_mock() - - # Test 3: has_unknown=False, has_taken=False, all_skipped=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": False, - "all_skipped": True, - } - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - should propagate skip - mock_state_manager.mark_node_skipped.assert_called_once_with("node_2") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py deleted file mode 100644 index 2fead1d719..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Utilities for testing HumanInputNode without database dependencies.""" - -from __future__ import annotations - -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from core.repositories.human_input_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRecipientEntity, - HumanInputFormRepository, -) -from graphon.nodes.human_input.enums import HumanInputFormStatus -from libs.datetime_utils import naive_utc_now - - -class _InMemoryFormRecipient(HumanInputFormRecipientEntity): - """Minimal recipient entity required by the repository interface.""" - - def __init__(self, recipient_id: str, token: str) -> None: - self._id = recipient_id - self._token = token - - @property - def id(self) -> str: - return self._id - - @property - def token(self) -> str: - return self._token - - -@dataclass -class _InMemoryFormEntity(HumanInputFormEntity): - form_id: str - rendered: str - token: str | None = None - action_id: str | None = None - data: Mapping[str, Any] | None = None - is_submitted: bool = False - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() - - @property - def id(self) -> str: - return self.form_id - - @property - def submission_token(self) -> str | None: - return self.token - - @property - def recipients(self) -> list[HumanInputFormRecipientEntity]: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class InMemoryHumanInputFormRepository(HumanInputFormRepository): - """Pure in-memory repository used by workflow graph engine tests.""" - - def __init__(self) -> None: - self._form_counter = 0 - self.created_params: list[FormCreateParams] = [] - self.created_forms: list[_InMemoryFormEntity] = [] - self._forms_by_node_id: dict[str, _InMemoryFormEntity] = {} - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - self.created_params.append(params) - self._form_counter += 1 - form_id = f"form-{self._form_counter}" - token = f"token-{form_id}" - entity = _InMemoryFormEntity( - form_id=form_id, - rendered=params.rendered_content, - token=token, - ) - self.created_forms.append(entity) - self._forms_by_node_id[params.node_id] = entity - return entity - - def get_form(self, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_node_id.get(node_id) - - # Convenience helpers for tests ------------------------------------- - - def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None: - """Simulate a human submission for the next repository lookup.""" - - if not self.created_forms: - raise AssertionError("no form has been created to attach submission data") - entity = self.created_forms[-1] - entity.action_id = action_id - entity.data = form_data or {} - entity.is_submitted = True - entity.status_value = HumanInputFormStatus.SUBMITTED - entity.expiration = naive_utc_now() + timedelta(days=1) - - def clear_submission(self) -> None: - if not self.created_forms: - return - for form in self.created_forms: - form.action_id = None - form.data = None - form.is_submitted = False - form.status_value = HumanInputFormStatus.WAITING diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py index b642dc82fe..41627f5e0b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -5,13 +5,12 @@ Shared fixtures for ObservabilityLayer tests. from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import set_tracer_provider -from graphon.enums import BuiltinNodeTypes - @pytest.fixture def memory_span_exporter(): @@ -62,9 +61,10 @@ def mock_llm_node(): @pytest.fixture def mock_tool_node(): """Create a mock Tool Node with tool-specific attributes.""" - from core.tools.entities.tool_entities import ToolProviderType from graphon.nodes.tool.entities import ToolNodeData + from core.tools.entities.tool_entities import ToolProviderType + node = MagicMock() node.id = "test-tool-node-id" node.title = "Test Tool Node" @@ -117,8 +117,8 @@ def mock_result_event(): """Create a mock result event with NodeRunResult.""" from datetime import datetime - from graphon.graph_events.node import NodeRunSucceededEvent - from graphon.node_events.base import NodeRunResult + from graphon.graph_events import NodeRunSucceededEvent + from graphon.node_events import NodeRunResult node_run_result = NodeRunResult( inputs={"query": "test query"}, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py deleted file mode 100644 index 7ff77c19c1..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import annotations - -import pytest - -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_engine.layers.base import ( - GraphEngineLayer, - GraphEngineLayerNotInitializedError, -) -from graphon.graph_events import GraphEngineEvent - -from ..test_table_runner import WorkflowRunner - - -class LayerForTest(GraphEngineLayer): - def on_graph_start(self) -> None: - pass - - def on_event(self, event: GraphEngineEvent) -> None: - pass - - def on_graph_end(self, error: Exception | None) -> None: - pass - - -def test_layer_runtime_state_raises_when_uninitialized() -> None: - layer = LayerForTest() - - with pytest.raises(GraphEngineLayerNotInitializedError): - _ = layer.graph_runtime_state - - -def test_layer_runtime_state_available_after_engine_layer() -> None: - runner = WorkflowRunner() - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture( - fixture_data, - inputs={"query": "test layer state"}, - ) - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - layer = LayerForTest() - engine.layer(layer) - - outputs = layer.graph_runtime_state.outputs - ready_queue_size = layer.graph_runtime_state.ready_queue_size - - assert outputs == {} - assert isinstance(ready_queue_size, int) - assert ready_queue_size >= 0 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 80874e768a..99d131737e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -3,15 +3,16 @@ from datetime import datetime from types import SimpleNamespace from unittest.mock import MagicMock, patch +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.entities.commands import CommandType +from graphon.graph_events import NodeRunSucceededEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import NodeRunResult + from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.graph_engine.entities.commands import CommandType -from graphon.graph_events.node import NodeRunSucceededEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import NodeRunResult def _build_dify_context() -> DifyRunContext: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py index 14ce55938d..9cf72763ee 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -13,10 +13,10 @@ Test coverage: from unittest.mock import patch import pytest +from graphon.enums import BuiltinNodeTypes from opentelemetry.trace import StatusCode from core.app.workflow.layers.observability import ObservabilityLayer -from graphon.enums import BuiltinNodeTypes class TestObservabilityLayerInitialization: @@ -144,7 +144,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node, mock_result_event ): """Test that LLM parser is used for LLM nodes and extracts LLM-specific attributes.""" - from graphon.node_events.base import NodeRunResult + from graphon.node_events import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={}, @@ -182,7 +182,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_retrieval_node, mock_result_event ): """Test that retrieval parser is used for retrieval nodes and extracts retrieval-specific attributes.""" - from graphon.node_events.base import NodeRunResult + from graphon.node_events import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={"query": "test query"}, @@ -210,7 +210,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node, mock_result_event ): """Test that result_event parameter allows parsers to extract inputs and outputs.""" - from graphon.node_events.base import NodeRunResult + from graphon.node_events import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={"input_key": "input_value"}, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py deleted file mode 100644 index ab3a31f673..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py +++ /dev/null @@ -1,189 +0,0 @@ -"""Tests for dispatcher command checking behavior.""" - -from __future__ import annotations - -import queue -from unittest import mock - -from graphon.entities.pause_reason import SchedulingPause -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.graph_engine.event_management.event_handlers import EventHandler -from graphon.graph_engine.orchestration.dispatcher import Dispatcher -from graphon.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from graphon.graph_events import ( - GraphNodeEventBase, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from graphon.node_events import NodeRunResult -from libs.datetime_utils import naive_utc_now - - -def test_dispatcher_should_consume_remains_events_after_pause(): - event_queue = queue.Queue() - event_queue.put( - GraphNodeEventBase( - id="test", - node_id="test", - node_type=BuiltinNodeTypes.START, - ) - ) - event_handler = mock.Mock(spec=EventHandler) - execution_coordinator = mock.Mock(spec=ExecutionCoordinator) - execution_coordinator.paused.return_value = True - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=event_handler, - execution_coordinator=execution_coordinator, - ) - dispatcher._dispatcher_loop() - assert event_queue.empty() - - -class _StubExecutionCoordinator: - """Stub execution coordinator that tracks command checks.""" - - def __init__(self) -> None: - self.command_checks = 0 - self.scaling_checks = 0 - self.execution_complete = False - self.failed = False - self._paused = False - - def process_commands(self) -> None: - self.command_checks += 1 - - def check_scaling(self) -> None: - self.scaling_checks += 1 - - @property - def paused(self) -> bool: - return self._paused - - @property - def aborted(self) -> bool: - return False - - def mark_complete(self) -> None: - self.execution_complete = True - - def mark_failed(self, error: Exception) -> None: # pragma: no cover - defensive, not triggered in tests - self.failed = True - - -class _StubEventHandler: - """Minimal event handler that marks execution complete after handling an event.""" - - def __init__(self, coordinator: _StubExecutionCoordinator) -> None: - self._coordinator = coordinator - self.events = [] - - def dispatch(self, event) -> None: - self.events.append(event) - self._coordinator.mark_complete() - - -def _run_dispatcher_for_event(event) -> int: - """Run the dispatcher loop for a single event and return command check count.""" - event_queue: queue.Queue = queue.Queue() - event_queue.put(event) - - coordinator = _StubExecutionCoordinator() - event_handler = _StubEventHandler(coordinator) - - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=event_handler, - execution_coordinator=coordinator, - ) - - dispatcher._dispatcher_loop() - - return coordinator.command_checks - - -def _make_started_event() -> NodeRunStartedEvent: - return NodeRunStartedEvent( - id="start-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - node_title="Test Node", - start_at=naive_utc_now(), - ) - - -def _make_succeeded_event() -> NodeRunSucceededEvent: - return NodeRunSucceededEvent( - id="success-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - node_title="Test Node", - start_at=naive_utc_now(), - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), - ) - - -def test_dispatcher_checks_commands_during_idle_and_on_completion() -> None: - """Dispatcher polls commands when idle and after completion events.""" - started_checks = _run_dispatcher_for_event(_make_started_event()) - succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event()) - - assert started_checks == 2 - assert succeeded_checks == 3 - - -class _PauseStubEventHandler: - """Minimal event handler that marks execution complete after handling an event.""" - - def __init__(self, coordinator: _StubExecutionCoordinator) -> None: - self._coordinator = coordinator - self.events = [] - - def dispatch(self, event) -> None: - self.events.append(event) - if isinstance(event, NodeRunPauseRequestedEvent): - self._coordinator.mark_complete() - - -def test_dispatcher_drain_event_queue(): - events = [ - NodeRunStartedEvent( - id="start-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - node_title="Code", - start_at=naive_utc_now(), - ), - NodeRunPauseRequestedEvent( - id="pause-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - reason=SchedulingPause(message="test pause"), - ), - NodeRunSucceededEvent( - id="success-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - start_at=naive_utc_now(), - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), - ), - ] - - event_queue: queue.Queue = queue.Queue() - for e in events: - event_queue.put(e) - - coordinator = _StubExecutionCoordinator() - event_handler = _PauseStubEventHandler(coordinator) - - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=event_handler, - execution_coordinator=coordinator, - ) - - dispatcher._dispatcher_loop() - - # ensure all events are drained. - assert event_queue.empty() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py deleted file mode 100644 index 1510c8e595..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py +++ /dev/null @@ -1,37 +0,0 @@ -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_answer_end_with_text(): - fixture_name = "answer_end_with_text" - case = WorkflowTestCase( - fixture_name, - query="Hello, AI!", - expected_outputs={"answer": "prefixHello, AI!suffix"}, - expected_event_sequence=[ - GraphRunStartedEvent, - # Start - NodeRunStartedEvent, - # The chunks are now emitted as the Answer node processes them - # since sys.query is a special selector that gets attributed to - # the active response node - NodeRunStreamChunkEvent, # prefix - NodeRunStreamChunkEvent, # sys.query - NodeRunStreamChunkEvent, # suffix - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py deleted file mode 100644 index 6569439b56..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py +++ /dev/null @@ -1,28 +0,0 @@ -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - -LLM_NODE_ID = "1759052580454" - - -def test_answer_nodes_emit_in_order() -> None: - mock_config = ( - MockConfigBuilder() - .with_llm_response("unused default") - .with_node_output(LLM_NODE_ID, {"text": "mocked llm text"}) - .build() - ) - - expected_answer = "--- answer 1 ---\n\nfoo\n--- answer 2 ---\n\nmocked llm text\n" - - case = WorkflowTestCase( - fixture_path="test-answer-order", - query="", - expected_outputs={"answer": expected_answer}, - use_auto_mock=True, - mock_config=mock_config, - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - - assert result.success, result.error diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_array_iteration_formatting_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_array_iteration_formatting_workflow.py deleted file mode 100644 index 05ec565def..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_array_iteration_formatting_workflow.py +++ /dev/null @@ -1,24 +0,0 @@ -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_array_iteration_formatting_workflow(): - """ - Validate Iteration node processes [1,2,3] into formatted strings. - - Fixture description expects: - {"output": ["output: 1", "output: 2", "output: 3"]} - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="array_iteration_formatting_workflow", - inputs={}, - expected_outputs={"output": ["output: 1", "output: 2", "output: 3"]}, - description="Iteration formats numbers into strings", - use_auto_mock=True, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Iteration workflow failed: {result.error}" - assert result.actual_outputs == test_case.expected_outputs diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py deleted file mode 100644 index 5d0b37acc5..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py +++ /dev/null @@ -1,392 +0,0 @@ -""" -Tests for the auto-mock system. - -This module contains tests that validate the auto-mock functionality -for workflows containing nodes that require third-party services. -""" - -import pytest - -from graphon.enums import BuiltinNodeTypes -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_simple_llm_workflow_with_auto_mock(): - """Test that a simple LLM workflow runs successfully with auto-mocking.""" - runner = TableTestRunner() - - # Create mock configuration - mock_config = MockConfigBuilder().with_llm_response("This is a test response from mocked LLM").build() - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "Hello, how are you?"}, - expected_outputs={"answer": "This is a test response from mocked LLM"}, - description="Simple LLM workflow with auto-mock", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs is not None - assert "answer" in result.actual_outputs - assert result.actual_outputs["answer"] == "This is a test response from mocked LLM" - - -def test_llm_workflow_with_custom_node_output(): - """Test LLM workflow with custom output for specific node.""" - runner = TableTestRunner() - - # Create mock configuration with custom output for specific node - mock_config = MockConfig() - mock_config.set_node_outputs( - "llm_node", - { - "text": "Custom response for this specific node", - "usage": { - "prompt_tokens": 20, - "completion_tokens": 10, - "total_tokens": 30, - }, - "finish_reason": "stop", - }, - ) - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "Test query"}, - expected_outputs={"answer": "Custom response for this specific node"}, - description="LLM workflow with custom node output", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs is not None - assert result.actual_outputs["answer"] == "Custom response for this specific node" - - -def test_http_tool_workflow_with_auto_mock(): - """Test workflow with HTTP request and tool nodes using auto-mock.""" - runner = TableTestRunner() - - # Create mock configuration - mock_config = MockConfig() - mock_config.set_node_outputs( - "http_node", - { - "status_code": 200, - "body": '{"key": "value", "number": 42}', - "headers": {"content-type": "application/json"}, - }, - ) - mock_config.set_node_outputs( - "tool_node", - { - "result": {"key": "value", "number": 42}, - }, - ) - - test_case = WorkflowTestCase( - fixture_path="http_request_with_json_tool_workflow", - inputs={"url": "https://api.example.com/data"}, - expected_outputs={ - "status_code": 200, - "parsed_data": {"key": "value", "number": 42}, - }, - description="HTTP and Tool workflow with auto-mock", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs is not None - assert result.actual_outputs["status_code"] == 200 - assert result.actual_outputs["parsed_data"] == {"key": "value", "number": 42} - - -def test_workflow_with_simulated_node_error(): - """Test that workflows handle simulated node errors correctly.""" - runner = TableTestRunner() - - # Create mock configuration with error - mock_config = MockConfig() - mock_config.set_node_error("llm_node", "Simulated LLM API error") - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "This should fail"}, - expected_outputs={}, # We expect failure, so no outputs - description="LLM workflow with simulated error", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - # The workflow should fail due to the simulated error - assert not result.success - assert result.error is not None - - -def test_workflow_with_mock_delays(): - """Test that mock delays work correctly.""" - runner = TableTestRunner() - - # Create mock configuration with delays - mock_config = MockConfig(simulate_delays=True) - node_config = NodeMockConfig( - node_id="llm_node", - outputs={"text": "Response after delay"}, - delay=0.1, # 100ms delay - ) - mock_config.set_node_config("llm_node", node_config) - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "Test with delay"}, - expected_outputs={"answer": "Response after delay"}, - description="LLM workflow with simulated delay", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - # Execution time should be at least the delay - assert result.execution_time >= 0.1 - - -def test_mock_config_builder(): - """Test the MockConfigBuilder fluent interface.""" - config = ( - MockConfigBuilder() - .with_llm_response("LLM response") - .with_agent_response("Agent response") - .with_tool_response({"tool": "output"}) - .with_retrieval_response("Retrieval content") - .with_http_response({"status_code": 201, "body": "created"}) - .with_node_output("node1", {"output": "value"}) - .with_node_error("node2", "error message") - .with_delays(True) - .build() - ) - - assert config.default_llm_response == "LLM response" - assert config.default_agent_response == "Agent response" - assert config.default_tool_response == {"tool": "output"} - assert config.default_retrieval_response == "Retrieval content" - assert config.default_http_response == {"status_code": 201, "body": "created"} - assert config.simulate_delays is True - - node1_config = config.get_node_config("node1") - assert node1_config is not None - assert node1_config.outputs == {"output": "value"} - - node2_config = config.get_node_config("node2") - assert node2_config is not None - assert node2_config.error == "error message" - - -def test_mock_factory_node_type_detection(): - """Test that MockNodeFactory correctly identifies nodes to mock.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.runtime import GraphRuntimeState, VariablePool - - from .test_mock_factory import MockNodeFactory - - graph_init_params = build_test_graph_init_params( - workflow_id="test", - graph_config={}, - tenant_id="test", - app_id="test", - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # Test that third-party service nodes are identified for mocking - assert factory.should_mock_node(BuiltinNodeTypes.LLM) - assert factory.should_mock_node(BuiltinNodeTypes.AGENT) - assert factory.should_mock_node(BuiltinNodeTypes.TOOL) - assert factory.should_mock_node(BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL) - assert factory.should_mock_node(BuiltinNodeTypes.HTTP_REQUEST) - assert factory.should_mock_node(BuiltinNodeTypes.PARAMETER_EXTRACTOR) - assert factory.should_mock_node(BuiltinNodeTypes.DOCUMENT_EXTRACTOR) - - # Test that CODE and TEMPLATE_TRANSFORM are mocked (they require SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Test that non-service nodes are not mocked - assert not factory.should_mock_node(BuiltinNodeTypes.START) - assert not factory.should_mock_node(BuiltinNodeTypes.END) - assert not factory.should_mock_node(BuiltinNodeTypes.IF_ELSE) - assert not factory.should_mock_node(BuiltinNodeTypes.VARIABLE_AGGREGATOR) - - -def test_custom_mock_handler(): - """Test using a custom handler function for mock outputs.""" - runner = TableTestRunner() - - # Custom handler that modifies output based on input - def custom_llm_handler(node) -> dict: - # In a real scenario, we could access node.graph_runtime_state.variable_pool - # to get the actual inputs - return { - "text": "Custom handler response", - "usage": { - "prompt_tokens": 5, - "completion_tokens": 3, - "total_tokens": 8, - }, - "finish_reason": "stop", - } - - mock_config = MockConfig() - node_config = NodeMockConfig( - node_id="llm_node", - custom_handler=custom_llm_handler, - ) - mock_config.set_node_config("llm_node", node_config) - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "Test custom handler"}, - expected_outputs={"answer": "Custom handler response"}, - description="LLM workflow with custom handler", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs["answer"] == "Custom handler response" - - -def test_workflow_without_auto_mock(): - """Test that workflows work normally without auto-mock enabled.""" - runner = TableTestRunner() - - # This test uses the echo workflow which doesn't need external services - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "Test without mock"}, - expected_outputs={"query": "Test without mock"}, - description="Echo workflow without auto-mock", - use_auto_mock=False, # Auto-mock disabled - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs["query"] == "Test without mock" - - -def test_register_custom_mock_node(): - """Test registering a custom mock implementation for a node type.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.nodes.template_transform import TemplateTransformNode - from graphon.runtime import GraphRuntimeState, VariablePool - - from .test_mock_factory import MockNodeFactory - - # Create a custom mock for TemplateTransformNode - class MockTemplateTransformNode(TemplateTransformNode): - def _run(self): - # Custom mock implementation - pass - - graph_init_params = build_test_graph_init_params( - workflow_id="test", - graph_config={}, - tenant_id="test", - app_id="test", - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # TEMPLATE_TRANSFORM is mocked by default (requires SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Unregister mock - factory.unregister_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - assert not factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Re-register custom mock - factory.register_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM, MockTemplateTransformNode) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - -def test_default_config_by_node_type(): - """Test setting default configurations by node type.""" - mock_config = MockConfig() - - # Set default config for all LLM nodes - mock_config.set_default_config( - BuiltinNodeTypes.LLM, - { - "default_response": "Default LLM response for all nodes", - "temperature": 0.7, - }, - ) - - # Set default config for all HTTP nodes - mock_config.set_default_config( - BuiltinNodeTypes.HTTP_REQUEST, - { - "default_status": 200, - "default_timeout": 30, - }, - ) - - llm_config = mock_config.get_default_config(BuiltinNodeTypes.LLM) - assert llm_config["default_response"] == "Default LLM response for all nodes" - assert llm_config["temperature"] == 0.7 - - http_config = mock_config.get_default_config(BuiltinNodeTypes.HTTP_REQUEST) - assert http_config["default_status"] == 200 - assert http_config["default_timeout"] == 30 - - # Non-configured node type should return empty dict - tool_config = mock_config.get_default_config(BuiltinNodeTypes.TOOL) - assert tool_config == {} - - -if __name__ == "__main__": - # Run all tests - pytest.main([__file__, "-v"]) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py deleted file mode 100644 index cefe3b8ac8..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py +++ /dev/null @@ -1,41 +0,0 @@ -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_basic_chatflow(): - fixture_name = "basic_chatflow" - mock_config = MockConfigBuilder().with_llm_response("mocked llm response").build() - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - expected_outputs={"answer": "mocked llm response"}, - expected_event_sequence=[ - GraphRunStartedEvent, - # START - NodeRunStartedEvent, - NodeRunSucceededEvent, - # LLM - NodeRunStartedEvent, - ] - + [NodeRunStreamChunkEvent] * ("mocked llm response".count(" ") + 2) - + [ - NodeRunSucceededEvent, - # ANSWER - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py deleted file mode 100644 index 01ac2d7a96..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ /dev/null @@ -1,266 +0,0 @@ -"""Test the command system for GraphEngine control.""" - -import time -from unittest.mock import MagicMock - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from graphon.entities.graph_init_params import GraphInitParams -from graphon.entities.pause_reason import SchedulingPause -from graphon.graph import Graph -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_engine.entities.commands import ( - AbortCommand, - CommandType, - PauseCommand, - UpdateVariablesCommand, - VariableUpdate, -) -from graphon.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import IntegerVariable, StringVariable - - -def test_abort_command(): - """Test that GraphEngine properly handles abort commands.""" - - # Create shared GraphRuntimeState - shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) - - # Create a minimal mock graph - mock_graph = MagicMock(spec=Graph) - mock_graph.nodes = {} - mock_graph.edges = {} - mock_graph.root_node = MagicMock() - mock_graph.root_node.id = "start" - - # Create mock nodes with required attributes - using shared runtime state - start_node = StartNode( - id="start", - config={"id": "start", "data": {"title": "start", "variables": []}}, - graph_init_params=GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ), - graph_runtime_state=shared_runtime_state, - ) - mock_graph.nodes["start"] = start_node - - # Mock graph methods - mock_graph.get_outgoing_edges = MagicMock(return_value=[]) - mock_graph.get_incoming_edges = MagicMock(return_value=[]) - - # Create command channel - command_channel = InMemoryChannel() - - # Create GraphEngine with same shared runtime state - engine = GraphEngine( - workflow_id="test_workflow", - graph=mock_graph, - graph_runtime_state=shared_runtime_state, # Use shared instance - command_channel=command_channel, - config=GraphEngineConfig(), - ) - - # Queue an abort request before starting. - engine.request_abort("Test abort") - - # Run engine and collect events - events = list(engine.run()) - - # Verify we get start and abort events - assert any(isinstance(e, GraphRunStartedEvent) for e in events) - assert any(isinstance(e, GraphRunAbortedEvent) for e in events) - - # Find the abort event and check its reason - abort_events = [e for e in events if isinstance(e, GraphRunAbortedEvent)] - assert len(abort_events) == 1 - assert abort_events[0].reason is not None - assert "aborted: test abort" in abort_events[0].reason.lower() - - -def test_redis_channel_serialization(): - """Test that Redis channel properly serializes and deserializes commands.""" - import json - from unittest.mock import MagicMock - - # Mock redis client - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None) - - from graphon.graph_engine.command_channels.redis_channel import RedisChannel - - # Create channel with a specific key - channel = RedisChannel(mock_redis, channel_key="workflow:123:commands") - - # Test sending a command - abort_command = AbortCommand(reason="Test abort") - channel.send_command(abort_command) - - # Verify redis methods were called - mock_pipeline.rpush.assert_called_once() - mock_pipeline.expire.assert_called_once() - - # Verify the serialized data - call_args = mock_pipeline.rpush.call_args - key = call_args[0][0] - command_json = call_args[0][1] - - assert key == "workflow:123:commands" - - # Verify JSON structure - command_data = json.loads(command_json) - assert command_data["command_type"] == "abort" - assert command_data["reason"] == "Test abort" - - # Test pause command serialization - pause_command = PauseCommand(reason="User requested pause") - channel.send_command(pause_command) - - assert len(mock_pipeline.rpush.call_args_list) == 2 - second_call_args = mock_pipeline.rpush.call_args_list[1] - pause_command_json = second_call_args[0][1] - pause_command_data = json.loads(pause_command_json) - assert pause_command_data["command_type"] == CommandType.PAUSE.value - assert pause_command_data["reason"] == "User requested pause" - - -def test_pause_command(): - """Test that GraphEngine properly handles pause commands.""" - - shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) - - mock_graph = MagicMock(spec=Graph) - mock_graph.nodes = {} - mock_graph.edges = {} - mock_graph.root_node = MagicMock() - mock_graph.root_node.id = "start" - - start_node = StartNode( - id="start", - config={"id": "start", "data": {"title": "start", "variables": []}}, - graph_init_params=GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ), - graph_runtime_state=shared_runtime_state, - ) - mock_graph.nodes["start"] = start_node - - mock_graph.get_outgoing_edges = MagicMock(return_value=[]) - mock_graph.get_incoming_edges = MagicMock(return_value=[]) - - command_channel = InMemoryChannel() - - engine = GraphEngine( - workflow_id="test_workflow", - graph=mock_graph, - graph_runtime_state=shared_runtime_state, - command_channel=command_channel, - config=GraphEngineConfig(), - ) - - pause_command = PauseCommand(reason="User requested pause") - command_channel.send_command(pause_command) - - events = list(engine.run()) - - assert any(isinstance(e, GraphRunStartedEvent) for e in events) - pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)] - assert len(pause_events) == 1 - assert pause_events[0].reasons == [SchedulingPause(message="User requested pause")] - - graph_execution = engine.graph_runtime_state.graph_execution - assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")] - - -def test_update_variables_command_updates_pool(): - """Test that GraphEngine updates variable pool via update variables command.""" - - shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) - shared_runtime_state.variable_pool.add(("node1", "foo"), "old value") - - mock_graph = MagicMock(spec=Graph) - mock_graph.nodes = {} - mock_graph.edges = {} - mock_graph.root_node = MagicMock() - mock_graph.root_node.id = "start" - - start_node = StartNode( - id="start", - config={"id": "start", "data": {"title": "start", "variables": []}}, - graph_init_params=GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ), - graph_runtime_state=shared_runtime_state, - ) - mock_graph.nodes["start"] = start_node - - mock_graph.get_outgoing_edges = MagicMock(return_value=[]) - mock_graph.get_incoming_edges = MagicMock(return_value=[]) - - command_channel = InMemoryChannel() - - engine = GraphEngine( - workflow_id="test_workflow", - graph=mock_graph, - graph_runtime_state=shared_runtime_state, - command_channel=command_channel, - config=GraphEngineConfig(), - ) - - update_command = UpdateVariablesCommand( - updates=[ - VariableUpdate( - value=StringVariable(name="foo", value="new value", selector=["node1", "foo"]), - ), - VariableUpdate( - value=IntegerVariable(name="bar", value=123, selector=["node2", "bar"]), - ), - ] - ) - command_channel.send_command(update_command) - - list(engine.run()) - - updated_existing = shared_runtime_state.variable_pool.get(["node1", "foo"]) - added_new = shared_runtime_state.variable_pool.get(["node2", "bar"]) - - assert updated_existing is not None - assert updated_existing.value == "new value" - assert added_new is not None - assert added_new.value == 123 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py deleted file mode 100644 index ba9c502452..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Test suite for complex branch workflow with parallel execution and conditional routing. - -This test suite validates the behavior of a workflow that: -1. Executes nodes in parallel (IF/ELSE and LLM branches) -2. Routes based on conditional logic (query containing 'hello') -3. Handles multiple answer nodes with different outputs -""" - -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -class TestComplexBranchWorkflow: - """Test suite for complex branch workflow with parallel execution.""" - - def setup_method(self): - """Set up test environment before each test method.""" - self.runner = TableTestRunner() - self.fixture_path = "test_complex_branch" - - def test_hello_branch_with_llm(self): - """ - Test when query contains 'hello' - should trigger true branch. - Both IF/ELSE and LLM should execute in parallel. - """ - mock_text_1 = "This is a mocked LLM response for hello world" - test_cases = [ - WorkflowTestCase( - fixture_path=self.fixture_path, - query="hello world", - expected_outputs={ - "answer": f"contains 'hello'{mock_text_1}", - }, - description="Basic hello case with parallel LLM execution", - use_auto_mock=True, - mock_config=(MockConfigBuilder().with_node_output("1755502777322", {"text": mock_text_1}).build()), - ), - WorkflowTestCase( - fixture_path=self.fixture_path, - query="say hello to everyone", - expected_outputs={ - "answer": "contains 'hello'Mocked response for greeting", - }, - description="Hello in middle of sentence", - use_auto_mock=True, - mock_config=( - MockConfigBuilder() - .with_node_output("1755502777322", {"text": "Mocked response for greeting"}) - .build() - ), - ), - ] - - suite_result = self.runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test '{result.test_case.description}' failed: {result.error}" - assert result.actual_outputs - assert any(isinstance(event, GraphRunStartedEvent) for event in result.events) - assert any(isinstance(event, GraphRunSucceededEvent) for event in result.events) - - start_index = next( - idx for idx, event in enumerate(result.events) if isinstance(event, GraphRunStartedEvent) - ) - success_index = max( - idx for idx, event in enumerate(result.events) if isinstance(event, GraphRunSucceededEvent) - ) - assert start_index < success_index - - started_node_ids = {event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)} - assert {"1755502773326", "1755502777322"}.issubset(started_node_ids), ( - f"Branch or LLM nodes missing in events: {started_node_ids}" - ) - - assert any(isinstance(event, NodeRunStreamChunkEvent) for event in result.events), ( - "Expected streaming chunks from LLM execution" - ) - - llm_start_index = next( - idx - for idx, event in enumerate(result.events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == "1755502777322" - ) - assert any( - idx > llm_start_index and isinstance(event, NodeRunStreamChunkEvent) - for idx, event in enumerate(result.events) - ), "Streaming chunks should follow LLM node start" - - def test_non_hello_branch_with_llm(self): - """ - Test when query doesn't contain 'hello' - should trigger false branch. - LLM output should be used as the final answer. - """ - test_cases = [ - WorkflowTestCase( - fixture_path=self.fixture_path, - query="goodbye world", - expected_outputs={ - "answer": "Mocked LLM response for goodbye", - }, - description="Goodbye case - false branch with LLM output", - use_auto_mock=True, - mock_config=( - MockConfigBuilder() - .with_node_output("1755502777322", {"text": "Mocked LLM response for goodbye"}) - .build() - ), - ), - WorkflowTestCase( - fixture_path=self.fixture_path, - query="test message", - expected_outputs={ - "answer": "Mocked response for test", - }, - description="Regular message - false branch", - use_auto_mock=True, - mock_config=( - MockConfigBuilder().with_node_output("1755502777322", {"text": "Mocked response for test"}).build() - ), - ), - ] - - suite_result = self.runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test '{result.test_case.description}' failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py deleted file mode 100644 index 3851480731..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py +++ /dev/null @@ -1,220 +0,0 @@ -""" -Test for streaming output workflow behavior. - -This test validates that: -- When blocking == 1: No NodeRunStreamChunkEvent (flow through Template node) -- When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output) -""" - -from graphon.enums import BuiltinNodeTypes -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_events import ( - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_table_runner import TableTestRunner - - -def test_streaming_output_with_blocking_equals_one(): - """ - Test workflow when blocking == 1 (LLM → Template → End). - - Template node doesn't produce streaming output, so no NodeRunStreamChunkEvent should be present. - This test should FAIL according to requirements. - """ - runner = TableTestRunner() - - # Load the workflow configuration - fixture_data = runner.workflow_runner.load_fixture("conditional_streaming_vs_template_workflow") - - # Create graph from fixture with auto-mock enabled - graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - inputs={"query": "Hello, how are you?", "blocking": 1}, - use_mock_factory=True, - ) - - # Create and run the engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Execute the workflow - events = list(engine.run()) - - # Check for successful completion - success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] - assert len(success_events) > 0, "Workflow should complete successfully" - - # Check for streaming events - stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] - stream_chunk_count = len(stream_chunk_events) - - # According to requirements, we expect exactly 3 streaming events from the End node - # 1. User query - # 2. Newline - # 3. Template output (which contains the LLM response) - assert stream_chunk_count == 3, f"Expected 3 streaming events when blocking=1, but got {stream_chunk_count}" - - first_chunk, second_chunk, third_chunk = stream_chunk_events[0], stream_chunk_events[1], stream_chunk_events[2] - assert first_chunk.chunk == "Hello, how are you?", ( - f"Expected first chunk to be user input, but got {first_chunk.chunk}" - ) - assert second_chunk.chunk == "\n", f"Expected second chunk to be newline, but got {second_chunk.chunk}" - # Third chunk will be the template output with the mock LLM response - assert isinstance(third_chunk.chunk, str), f"Expected third chunk to be string, but got {type(third_chunk.chunk)}" - - # Find indices of first LLM success event and first stream chunk event - llm2_start_index = next( - ( - i - for i, e in enumerate(events) - if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM - ), - -1, - ) - first_chunk_index = next( - (i for i, e in enumerate(events) if isinstance(e, NodeRunStreamChunkEvent)), - -1, - ) - - assert first_chunk_index < llm2_start_index, ( - f"Expected first chunk before LLM2 start, but got {first_chunk_index} and {llm2_start_index}" - ) - - # Check that NodeRunStreamChunkEvent contains 'query' should has same id with Start NodeRunStartedEvent - start_node_id = graph.root_node.id - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_id == start_node_id] - assert len(start_events) == 1, f"Expected 1 start event for node {start_node_id}, but got {len(start_events)}" - start_event = start_events[0] - query_chunk_events = [e for e in stream_chunk_events if e.chunk == "Hello, how are you?"] - assert all(e.id == start_event.id for e in query_chunk_events), "Expected all query chunk events to have same id" - - # Check all Template's NodeRunStreamChunkEvent should has same id with Template's NodeRunStartedEvent - start_events = [ - e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM - ] - template_chunk_events = [e for e in stream_chunk_events if e.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM] - assert len(template_chunk_events) == 1, f"Expected 1 template chunk event, but got {len(template_chunk_events)}" - assert all(e.id in [se.id for se in start_events] for e in template_chunk_events), ( - "Expected all Template chunk events to have same id with Template's NodeRunStartedEvent" - ) - - # Check that NodeRunStreamChunkEvent contains '\n' is from the End node - end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.END] - assert len(end_events) == 1, f"Expected 1 end event, but got {len(end_events)}" - newline_chunk_events = [e for e in stream_chunk_events if e.chunk == "\n"] - assert len(newline_chunk_events) == 1, f"Expected 1 newline chunk event, but got {len(newline_chunk_events)}" - # The newline chunk should be from the End node (check node_id, not execution id) - assert all(e.node_id == end_events[0].node_id for e in newline_chunk_events), ( - "Expected all newline chunk events to be from End node" - ) - - -def test_streaming_output_with_blocking_not_equals_one(): - """ - Test workflow when blocking != 1 (LLM → End directly). - - End node should produce streaming output with NodeRunStreamChunkEvent. - This test should PASS according to requirements. - """ - runner = TableTestRunner() - - # Load the workflow configuration - fixture_data = runner.workflow_runner.load_fixture("conditional_streaming_vs_template_workflow") - - # Create graph from fixture with auto-mock enabled - graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - inputs={"query": "Hello, how are you?", "blocking": 2}, - use_mock_factory=True, - ) - - # Create and run the engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Execute the workflow - events = list(engine.run()) - - # Check for successful completion - success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] - assert len(success_events) > 0, "Workflow should complete successfully" - - # Check for streaming events - expecting streaming events - stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] - stream_chunk_count = len(stream_chunk_events) - - # This assertion should PASS according to requirements - assert stream_chunk_count > 0, f"Expected streaming events when blocking!=1, but got {stream_chunk_count}" - - # We should have at least 2 chunks (query and newline) - assert stream_chunk_count >= 2, f"Expected at least 2 streaming events, but got {stream_chunk_count}" - - first_chunk, second_chunk = stream_chunk_events[0], stream_chunk_events[1] - assert first_chunk.chunk == "Hello, how are you?", ( - f"Expected first chunk to be user input, but got {first_chunk.chunk}" - ) - assert second_chunk.chunk == "\n", f"Expected second chunk to be newline, but got {second_chunk.chunk}" - - # Find indices of first LLM success event and first stream chunk event - llm2_start_index = next( - ( - i - for i, e in enumerate(events) - if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM - ), - -1, - ) - first_chunk_index = next( - (i for i, e in enumerate(events) if isinstance(e, NodeRunStreamChunkEvent)), - -1, - ) - - assert first_chunk_index < llm2_start_index, ( - f"Expected first chunk before LLM2 start, but got {first_chunk_index} and {llm2_start_index}" - ) - - # With auto-mock, the LLM will produce mock responses - just verify we have streaming chunks - # and they are strings - for chunk_event in stream_chunk_events[2:]: - assert isinstance(chunk_event.chunk, str), f"Expected chunk to be string, but got {type(chunk_event.chunk)}" - - # Check that NodeRunStreamChunkEvent contains 'query' should has same id with Start NodeRunStartedEvent - start_node_id = graph.root_node.id - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_id == start_node_id] - assert len(start_events) == 1, f"Expected 1 start event for node {start_node_id}, but got {len(start_events)}" - start_event = start_events[0] - query_chunk_events = [e for e in stream_chunk_events if e.chunk == "Hello, how are you?"] - assert all(e.id == start_event.id for e in query_chunk_events), "Expected all query chunk events to have same id" - - # Check all LLM's NodeRunStreamChunkEvent should be from LLM nodes - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.LLM] - llm_chunk_events = [e for e in stream_chunk_events if e.node_type == BuiltinNodeTypes.LLM] - llm_node_ids = {se.node_id for se in start_events} - assert all(e.node_id in llm_node_ids for e in llm_chunk_events), ( - "Expected all LLM chunk events to be from LLM nodes" - ) - - # Check that NodeRunStreamChunkEvent contains '\n' is from the End node - end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.END] - assert len(end_events) == 1, f"Expected 1 end event, but got {len(end_events)}" - newline_chunk_events = [e for e in stream_chunk_events if e.chunk == "\n"] - assert len(newline_chunk_events) == 1, f"Expected 1 newline chunk event, but got {len(newline_chunk_events)}" - # The newline chunk should be from the End node (check node_id, not execution id) - assert all(e.node_id == end_events[0].node_id for e in newline_chunk_events), ( - "Expected all newline chunk events to be from End node" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_database_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/test_database_utils.py deleted file mode 100644 index ae7dd48bb1..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_database_utils.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Utilities for detecting if database service is available for workflow tests. -""" - -import psycopg2 -import pytest - -from configs import dify_config - - -def is_database_available() -> bool: - """ - Check if the database service is available by attempting to connect to it. - - Returns: - True if database is available, False otherwise. - """ - try: - # Try to establish a database connection using a context manager - with psycopg2.connect( - host=dify_config.DB_HOST, - port=dify_config.DB_PORT, - database=dify_config.DB_DATABASE, - user=dify_config.DB_USERNAME, - password=dify_config.DB_PASSWORD, - connect_timeout=2, # 2 second timeout - ) as conn: - pass # Connection established and will be closed automatically - return True - except (psycopg2.OperationalError, psycopg2.Error): - return False - - -def skip_if_database_unavailable(): - """ - Pytest skip decorator that skips tests when database service is unavailable. - - Usage: - @skip_if_database_unavailable() - def test_my_workflow(): - ... - """ - return pytest.mark.skipif( - not is_database_available(), - reason="Database service is not available (connection refused or authentication failed)", - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py deleted file mode 100644 index 3264ad1168..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py +++ /dev/null @@ -1,72 +0,0 @@ -import queue - -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.graph_engine.orchestration.dispatcher import Dispatcher -from graphon.graph_events import NodeRunSucceededEvent -from graphon.node_events import NodeRunResult -from libs.datetime_utils import naive_utc_now - - -class StubExecutionCoordinator: - def __init__(self, paused: bool) -> None: - self._paused = paused - self.mark_complete_called = False - self.failed_error: Exception | None = None - - @property - def aborted(self) -> bool: - return False - - @property - def paused(self) -> bool: - return self._paused - - @property - def execution_complete(self) -> bool: - return False - - def check_scaling(self) -> None: - return None - - def process_commands(self) -> None: - return None - - def mark_complete(self) -> None: - self.mark_complete_called = True - - def mark_failed(self, error: Exception) -> None: - self.failed_error = error - - -class StubEventHandler: - def __init__(self) -> None: - self.events: list[object] = [] - - def dispatch(self, event: object) -> None: - self.events.append(event) - - -def test_dispatcher_drains_events_when_paused() -> None: - event_queue: queue.Queue = queue.Queue() - event = NodeRunSucceededEvent( - id="exec-1", - node_id="node-1", - node_type=BuiltinNodeTypes.START, - start_at=naive_utc_now(), - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), - ) - event_queue.put(event) - - handler = StubEventHandler() - coordinator = StubExecutionCoordinator(paused=True) - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=handler, - execution_coordinator=coordinator, - event_emitter=None, - ) - - dispatcher._dispatcher_loop() - - assert handler.events == [event] - assert coordinator.mark_complete_called is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py deleted file mode 100644 index ada55f3dc5..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Test case for end node without value_type field (backward compatibility). - -This test validates that end nodes work correctly even when the value_type -field is missing from the output configuration, ensuring backward compatibility -with older workflow definitions. -""" - -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_end_node_without_value_type_field(): - """ - Test that end node works without explicit value_type field. - - The fixture implements a simple workflow that: - 1. Takes a query input from start node - 2. Passes it directly to end node - 3. End node outputs the value without specifying value_type - 4. Should correctly infer the type and output the value - - This ensures backward compatibility with workflow definitions - created before value_type became a required field. - """ - fixture_name = "end_node_without_value_type_field_workflow" - - case = WorkflowTestCase( - fixture_path=fixture_name, - inputs={"query": "test query"}, - expected_outputs={"query": "test query"}, - expected_event_sequence=[ - # Graph start - GraphRunStartedEvent, - # Start node - NodeRunStartedEvent, - NodeRunStreamChunkEvent, # Start node streams the input value - NodeRunSucceededEvent, - # End node - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Graph end - GraphRunSucceededEvent, - ], - description="End node without value_type field should work correctly", - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs == {"query": "test query"}, ( - f"Expected output to be {{'query': 'test query'}}, got {result.actual_outputs}" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py deleted file mode 100644 index 95a94110d2..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Unit tests for the execution coordinator orchestration logic.""" - -from unittest.mock import MagicMock - -import pytest - -from graphon.graph_engine.command_processing.command_processor import CommandProcessor -from graphon.graph_engine.domain.graph_execution import GraphExecution -from graphon.graph_engine.graph_state_manager import GraphStateManager -from graphon.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from graphon.graph_engine.worker_management.worker_pool import WorkerPool - - -def _build_coordinator(graph_execution: GraphExecution) -> tuple[ExecutionCoordinator, MagicMock, MagicMock]: - command_processor = MagicMock(spec=CommandProcessor) - state_manager = MagicMock(spec=GraphStateManager) - worker_pool = MagicMock(spec=WorkerPool) - - coordinator = ExecutionCoordinator( - graph_execution=graph_execution, - state_manager=state_manager, - command_processor=command_processor, - worker_pool=worker_pool, - ) - return coordinator, state_manager, worker_pool - - -def test_handle_pause_stops_workers_and_clears_state() -> None: - """Paused execution should stop workers and clear executing state.""" - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - graph_execution.pause("Awaiting human input") - - coordinator, state_manager, worker_pool = _build_coordinator(graph_execution) - - coordinator.handle_pause_if_needed() - - worker_pool.stop.assert_called_once_with() - state_manager.clear_executing.assert_called_once_with() - - -def test_handle_pause_noop_when_execution_running() -> None: - """Running execution should not trigger pause handling.""" - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - - coordinator, state_manager, worker_pool = _build_coordinator(graph_execution) - - coordinator.handle_pause_if_needed() - - worker_pool.stop.assert_not_called() - state_manager.clear_executing.assert_not_called() - - -def test_has_executing_nodes_requires_pause() -> None: - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - - coordinator, _, _ = _build_coordinator(graph_execution) - - with pytest.raises(AssertionError): - coordinator.has_executing_nodes() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py deleted file mode 100644 index 51ece26d49..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ /dev/null @@ -1,770 +0,0 @@ -""" -Table-driven test framework for GraphEngine workflows. - -This file contains property-based tests and specific workflow tests. -The core test framework is in test_table_runner.py. -""" - -import time - -from hypothesis import HealthCheck, given, settings -from hypothesis import strategies as st - -from graphon.entities.base_node_data import DefaultValue, DefaultValueType -from graphon.enums import ErrorStrategy -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_events import ( - GraphRunPartialSucceededEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) - -# Import the test framework from the new module -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase - - -# Property-based fuzzing tests for the start-end workflow -@given(query_input=st.text()) -@settings(max_examples=50, deadline=30000, suppress_health_check=[HealthCheck.too_slow]) -def test_echo_workflow_property_basic_strings(query_input): - """ - Property-based test: Echo workflow should return exactly what was input. - - This tests the fundamental property that for any string input, - the start-end workflow should echo it back unchanged. - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Fuzzing test with input: {repr(query_input)[:50]}...", - ) - - result = runner.run_test_case(test_case) - - # Property: The workflow should complete successfully - assert result.success, f"Workflow failed with input {repr(query_input)}: {result.error}" - - # Property: Output should equal input (echo behavior) - assert result.actual_outputs - assert result.actual_outputs == {"query": query_input}, ( - f"Echo property violated. Input: {repr(query_input)}, " - f"Expected: {repr(query_input)}, Got: {repr(result.actual_outputs.get('query'))}" - ) - - -@given(query_input=st.text(min_size=0, max_size=1000)) -@settings(max_examples=30, deadline=20000) -def test_echo_workflow_property_bounded_strings(query_input): - """ - Property-based test with size bounds to test edge cases more efficiently. - - Tests strings up to 1000 characters to balance thoroughness with performance. - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Bounded fuzzing test (len={len(query_input)})", - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed with bounded input: {result.error}" - assert result.actual_outputs == {"query": query_input} - - -@given( - query_input=st.one_of( - st.text(alphabet=st.characters(whitelist_categories=["Lu", "Ll", "Nd", "Po"])), # Letters, digits, punctuation - st.text(alphabet="🎉🌟💫⭐🔥💯🚀🎯"), # Emojis - st.text(alphabet="αβγδεζηθικλμνξοπρστυφχψω"), # Greek letters - st.text(alphabet="中文测试한국어日本語العربية"), # International characters - st.just(""), # Empty string - st.just(" " * 100), # Whitespace only - st.just("\n\t\r\f\v"), # Special whitespace chars - st.just('{"json": "like", "data": [1, 2, 3]}'), # JSON-like string - st.just("SELECT * FROM users; DROP TABLE users;--"), # SQL injection attempt - st.just(""), # XSS attempt - st.just("../../etc/passwd"), # Path traversal attempt - ) -) -@settings(max_examples=40, deadline=25000) -def test_echo_workflow_property_diverse_inputs(query_input): - """ - Property-based test with diverse input types including edge cases and security payloads. - - Tests various categories of potentially problematic inputs: - - Unicode characters from different languages - - Emojis and special symbols - - Whitespace variations - - Malicious payloads (SQL injection, XSS, path traversal) - - JSON-like structures - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Diverse input fuzzing: {type(query_input).__name__}", - ) - - result = runner.run_test_case(test_case) - - # Property: System should handle all inputs gracefully (no crashes) - assert result.success, f"Workflow failed with diverse input {repr(query_input)}: {result.error}" - - # Property: Echo behavior must be preserved regardless of input type - assert result.actual_outputs == {"query": query_input} - - -@given(query_input=st.text(min_size=1000, max_size=5000)) -@settings(max_examples=10, deadline=60000) -def test_echo_workflow_property_large_inputs(query_input): - """ - Property-based test for large inputs to test memory and performance boundaries. - - Tests the system's ability to handle larger payloads efficiently. - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Large input test (size: {len(query_input)} chars)", - timeout=45.0, # Longer timeout for large inputs - ) - - start_time = time.perf_counter() - result = runner.run_test_case(test_case) - execution_time = time.perf_counter() - start_time - - # Property: Large inputs should still work - assert result.success, f"Large input workflow failed: {result.error}" - - # Property: Echo behavior preserved for large inputs - assert result.actual_outputs == {"query": query_input} - - # Property: Performance should be reasonable even for large inputs - assert execution_time < 30.0, f"Large input took too long: {execution_time:.2f}s" - - -def test_echo_workflow_robustness_smoke_test(): - """ - Smoke test to ensure the basic workflow functionality works before fuzzing. - - This test uses a simple, known-good input to verify the test infrastructure - is working correctly before running the fuzzing tests. - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "smoke test"}, - expected_outputs={"query": "smoke test"}, - description="Smoke test for basic functionality", - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Smoke test failed: {result.error}" - assert result.actual_outputs == {"query": "smoke test"} - assert result.execution_time > 0 - - -def test_if_else_workflow_true_branch(): - """ - Test if-else workflow when input contains 'hello' (true branch). - - Should output {"true": input_query} when query contains "hello". - """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello world"}, - expected_outputs={"true": "hello world"}, - description="Basic hello case", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "say hello to everyone"}, - expected_outputs={"true": "say hello to everyone"}, - description="Hello in middle of sentence", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello"}, - expected_outputs={"true": "hello"}, - description="Just hello", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hellohello"}, - expected_outputs={"true": "hellohello"}, - description="Multiple hello occurrences", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test case '{result.test_case.description}' failed: {result.error}" - # Check that outputs contain ONLY the expected key (true branch) - assert result.actual_outputs == result.test_case.expected_outputs, ( - f"Expected only 'true' key in outputs for {result.test_case.description}. " - f"Expected: {result.test_case.expected_outputs}, Got: {result.actual_outputs}" - ) - - -def test_if_else_workflow_false_branch(): - """ - Test if-else workflow when input does not contain 'hello' (false branch). - - Should output {"false": input_query} when query does not contain "hello". - """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "goodbye world"}, - expected_outputs={"false": "goodbye world"}, - description="Basic goodbye case", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hi there"}, - expected_outputs={"false": "hi there"}, - description="Simple greeting without hello", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": ""}, - expected_outputs={"false": ""}, - description="Empty string", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "test message"}, - expected_outputs={"false": "test message"}, - description="Regular message", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test case '{result.test_case.description}' failed: {result.error}" - # Check that outputs contain ONLY the expected key (false branch) - assert result.actual_outputs == result.test_case.expected_outputs, ( - f"Expected only 'false' key in outputs for {result.test_case.description}. " - f"Expected: {result.test_case.expected_outputs}, Got: {result.actual_outputs}" - ) - - -def test_if_else_workflow_edge_cases(): - """ - Test if-else workflow edge cases and case sensitivity. - - Tests various edge cases including case sensitivity, similar words, etc. - """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "Hello world"}, - expected_outputs={"false": "Hello world"}, - description="Capitalized Hello (case sensitive test)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "HELLO"}, - expected_outputs={"false": "HELLO"}, - description="All caps HELLO (case sensitive test)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "helllo"}, - expected_outputs={"false": "helllo"}, - description="Typo: helllo (with extra l)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "helo"}, - expected_outputs={"false": "helo"}, - description="Typo: helo (missing l)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello123"}, - expected_outputs={"true": "hello123"}, - description="Hello with numbers", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello!@#"}, - expected_outputs={"true": "hello!@#"}, - description="Hello with special characters", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": " hello "}, - expected_outputs={"true": " hello "}, - description="Hello with surrounding spaces", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test case '{result.test_case.description}' failed: {result.error}" - # Check that outputs contain ONLY the expected key - assert result.actual_outputs == result.test_case.expected_outputs, ( - f"Expected exact match for {result.test_case.description}. " - f"Expected: {result.test_case.expected_outputs}, Got: {result.actual_outputs}" - ) - - -@given(query_input=st.text()) -@settings(max_examples=50, deadline=30000, suppress_health_check=[HealthCheck.too_slow]) -def test_if_else_workflow_property_basic_strings(query_input): - """ - Property-based test: If-else workflow should output correct branch based on 'hello' content. - - This tests the fundamental property that for any string input: - - If input contains "hello", output should be {"true": input} - - If input doesn't contain "hello", output should be {"false": input} - """ - runner = TableTestRunner() - - # Determine expected output based on whether input contains "hello" - contains_hello = "hello" in query_input - expected_key = "true" if contains_hello else "false" - expected_outputs = {expected_key: query_input} - - test_case = WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": query_input}, - expected_outputs=expected_outputs, - description=f"Property test with input: {repr(query_input)[:50]}...", - ) - - result = runner.run_test_case(test_case) - - # Property: The workflow should complete successfully - assert result.success, f"Workflow failed with input {repr(query_input)}: {result.error}" - - # Property: Output should contain ONLY the expected key with correct value - assert result.actual_outputs == expected_outputs, ( - f"If-else property violated. Input: {repr(query_input)}, " - f"Expected: {expected_outputs}, Got: {result.actual_outputs}" - ) - - -@given(query_input=st.text(min_size=0, max_size=1000)) -@settings(max_examples=30, deadline=20000) -def test_if_else_workflow_property_bounded_strings(query_input): - """ - Property-based test with size bounds for if-else workflow. - - Tests strings up to 1000 characters to balance thoroughness with performance. - """ - runner = TableTestRunner() - - contains_hello = "hello" in query_input - expected_key = "true" if contains_hello else "false" - expected_outputs = {expected_key: query_input} - - test_case = WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": query_input}, - expected_outputs=expected_outputs, - description=f"Bounded if-else test (len={len(query_input)}, contains_hello={contains_hello})", - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed with bounded input: {result.error}" - assert result.actual_outputs == expected_outputs - - -@given( - query_input=st.one_of( - st.text(alphabet=st.characters(whitelist_categories=["Lu", "Ll", "Nd", "Po"])), # Letters, digits, punctuation - st.text(alphabet="hello"), # Strings that definitely contain hello - st.text(alphabet="xyz"), # Strings that definitely don't contain hello - st.just("hello world"), # Known true case - st.just("goodbye world"), # Known false case - st.just(""), # Empty string - st.just("Hello"), # Case sensitivity test - st.just("HELLO"), # Case sensitivity test - st.just("hello" * 10), # Multiple hello occurrences - st.just("say hello to everyone"), # Hello in middle - st.text(alphabet="🎉🌟💫⭐🔥💯🚀🎯"), # Emojis - st.text(alphabet="中文测试한국어日本語العربية"), # International characters - ) -) -@settings(max_examples=40, deadline=25000) -def test_if_else_workflow_property_diverse_inputs(query_input): - """ - Property-based test with diverse input types for if-else workflow. - - Tests various categories including: - - Known true/false cases - - Case sensitivity scenarios - - Unicode characters from different languages - - Emojis and special symbols - - Multiple hello occurrences - """ - runner = TableTestRunner() - - contains_hello = "hello" in query_input - expected_key = "true" if contains_hello else "false" - expected_outputs = {expected_key: query_input} - - test_case = WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": query_input}, - expected_outputs=expected_outputs, - description=f"Diverse if-else test: {type(query_input).__name__} (contains_hello={contains_hello})", - ) - - result = runner.run_test_case(test_case) - - # Property: System should handle all inputs gracefully (no crashes) - assert result.success, f"Workflow failed with diverse input {repr(query_input)}: {result.error}" - - # Property: Correct branch logic must be preserved regardless of input type - assert result.actual_outputs == expected_outputs, ( - f"Branch logic violated. Input: {repr(query_input)}, " - f"Contains 'hello': {contains_hello}, Expected: {expected_outputs}, Got: {result.actual_outputs}" - ) - - -# Tests for the Layer system -def test_layer_system_basic(): - """Test basic layer functionality with DebugLoggingLayer.""" - from graphon.graph_engine.layers import DebugLoggingLayer - - runner = WorkflowRunner() - - # Load a simple echo workflow - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture(fixture_data, inputs={"query": "test layer system"}) - - # Create engine with layer - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Add debug logging layer - debug_layer = DebugLoggingLayer(level="DEBUG", include_inputs=True, include_outputs=True) - engine.layer(debug_layer) - - # Run workflow - events = list(engine.run()) - - # Verify events were generated - assert len(events) > 0 - assert isinstance(events[0], GraphRunStartedEvent) - assert isinstance(events[-1], GraphRunSucceededEvent) - - # Verify layer received context - assert debug_layer.graph_runtime_state is not None - assert debug_layer.command_channel is not None - - # Verify layer tracked execution stats - assert debug_layer.node_count > 0 - assert debug_layer.success_count > 0 - - -def test_layer_chaining(): - """Test chaining multiple layers.""" - from graphon.graph_engine.layers import DebugLoggingLayer, GraphEngineLayer - - # Create a custom test layer - class TestLayer(GraphEngineLayer): - def __init__(self): - super().__init__() - self.events_received = [] - self.graph_started = False - self.graph_ended = False - - def on_graph_start(self): - self.graph_started = True - - def on_event(self, event): - self.events_received.append(event.__class__.__name__) - - def on_graph_end(self, error): - self.graph_ended = True - - runner = WorkflowRunner() - - # Load workflow - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture(fixture_data, inputs={"query": "test chaining"}) - - # Create engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Chain multiple layers - test_layer = TestLayer() - debug_layer = DebugLoggingLayer(level="INFO") - - engine.layer(test_layer).layer(debug_layer) - - # Run workflow - events = list(engine.run()) - - # Verify both layers received events - assert test_layer.graph_started - assert test_layer.graph_ended - assert len(test_layer.events_received) > 0 - - # Verify debug layer also worked - assert debug_layer.node_count > 0 - - -def test_layer_error_handling(): - """Test that layer errors don't crash the engine.""" - from graphon.graph_engine.layers import GraphEngineLayer - - # Create a layer that throws errors - class FaultyLayer(GraphEngineLayer): - def on_graph_start(self): - raise RuntimeError("Intentional error in on_graph_start") - - def on_event(self, event): - raise RuntimeError("Intentional error in on_event") - - def on_graph_end(self, error): - raise RuntimeError("Intentional error in on_graph_end") - - runner = WorkflowRunner() - - # Load workflow - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture(fixture_data, inputs={"query": "test error handling"}) - - # Create engine with faulty layer - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Add faulty layer - engine.layer(FaultyLayer()) - - # Run workflow - should not crash despite layer errors - events = list(engine.run()) - - # Verify workflow still completed successfully - assert len(events) > 0 - assert isinstance(events[-1], GraphRunSucceededEvent) - assert events[-1].outputs == {"query": "test error handling"} - - -def test_event_sequence_validation(): - """Test the new event sequence validation feature.""" - from graphon.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent - - runner = TableTestRunner() - - # Test 1: Successful event sequence validation - test_case_success = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test event sequence"}, - expected_outputs={"query": "test event sequence"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, # Start node begins - NodeRunStreamChunkEvent, # Start node streaming - NodeRunSucceededEvent, # Start node completes - NodeRunStartedEvent, # End node begins - NodeRunSucceededEvent, # End node completes - GraphRunSucceededEvent, # Graph completes - ], - description="Test with correct event sequence", - ) - - result = runner.run_test_case(test_case_success) - assert result.success, f"Test should pass with correct event sequence. Error: {result.event_mismatch_details}" - assert result.event_sequence_match is True - assert result.event_mismatch_details is None - - # Test 2: Failed event sequence validation - wrong order - test_case_wrong_order = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test wrong order"}, - expected_outputs={"query": "test wrong order"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunSucceededEvent, # Wrong: expecting success before start - NodeRunStreamChunkEvent, - NodeRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - description="Test with incorrect event order", - ) - - result = runner.run_test_case(test_case_wrong_order) - assert not result.success, "Test should fail with incorrect event sequence" - assert result.event_sequence_match is False - assert result.event_mismatch_details is not None - assert "Event mismatch at position" in result.event_mismatch_details - - # Test 3: Failed event sequence validation - wrong count - test_case_wrong_count = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test wrong count"}, - expected_outputs={"query": "test wrong count"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Missing the second node's events - GraphRunSucceededEvent, - ], - description="Test with incorrect event count", - ) - - result = runner.run_test_case(test_case_wrong_count) - assert not result.success, "Test should fail with incorrect event count" - assert result.event_sequence_match is False - assert result.event_mismatch_details is not None - assert "Event count mismatch" in result.event_mismatch_details - - # Test 4: No event sequence validation (backward compatibility) - test_case_no_validation = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test no validation"}, - expected_outputs={"query": "test no validation"}, - # No expected_event_sequence provided - description="Test without event sequence validation", - ) - - result = runner.run_test_case(test_case_no_validation) - assert result.success, "Test should pass when no event sequence is provided" - assert result.event_sequence_match is None - assert result.event_mismatch_details is None - - -def test_event_sequence_validation_with_table_tests(): - """Test event sequence validation with table-driven tests.""" - from graphon.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent - - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test1"}, - expected_outputs={"query": "test1"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - description="Table test 1: Valid sequence", - ), - WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test2"}, - expected_outputs={"query": "test2"}, - # No event sequence validation for this test - description="Table test 2: No sequence validation", - ), - WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test3"}, - expected_outputs={"query": "test3"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - description="Table test 3: Valid sequence", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - # Check all tests passed - for i, result in enumerate(suite_result.results): - if i == 1: # Test 2 has no event sequence validation - assert result.event_sequence_match is None - else: - assert result.event_sequence_match is True - assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}" - - -def test_graph_run_emits_partial_success_when_node_failure_recovered(): - runner = TableTestRunner() - - fixture_data = runner.workflow_runner.load_fixture("basic_chatflow") - mock_config = MockConfigBuilder().with_node_error("llm", "mock llm failure").build() - - graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - query="hello", - use_mock_factory=True, - mock_config=mock_config, - ) - - llm_node = graph.nodes["llm"] - base_node_data = llm_node.node_data - base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE - base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)] - - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - events = list(engine.run()) - - assert isinstance(events[-1], GraphRunPartialSucceededEvent) - - partial_event = next(event for event in events if isinstance(event, GraphRunPartialSucceededEvent)) - assert partial_event.exceptions_count == 1 - assert partial_event.outputs.get("answer") == "fallback response" - - assert not any(isinstance(event, GraphRunSucceededEvent) for event in events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py deleted file mode 100644 index 348ceb6788..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Unit tests for GraphExecution serialization helpers.""" - -from __future__ import annotations - -import json -from collections import deque -from unittest.mock import MagicMock - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState -from graphon.graph_engine.domain import GraphExecution -from graphon.graph_engine.response_coordinator import ResponseStreamCoordinator -from graphon.graph_engine.response_coordinator.path import Path -from graphon.graph_engine.response_coordinator.session import ResponseSession -from graphon.graph_events import NodeRunStreamChunkEvent -from graphon.nodes.base.template import Template, TextSegment, VariableSegment - - -class CustomGraphExecutionError(Exception): - """Custom exception used to verify error serialization.""" - - -def test_graph_execution_serialization_round_trip() -> None: - """GraphExecution serialization restores full aggregate state.""" - # Arrange - execution = GraphExecution(workflow_id="wf-1") - execution.start() - node_a = execution.get_or_create_node_execution("node-a") - node_a.mark_started(execution_id="exec-1") - node_a.increment_retry() - node_a.mark_failed("boom") - node_b = execution.get_or_create_node_execution("node-b") - node_b.mark_skipped() - execution.fail(CustomGraphExecutionError("serialization failure")) - - # Act - serialized = execution.dumps() - payload = json.loads(serialized) - restored = GraphExecution(workflow_id="wf-1") - restored.loads(serialized) - - # Assert - assert payload["type"] == "GraphExecution" - assert payload["version"] == "1.0" - assert restored.workflow_id == "wf-1" - assert restored.started is True - assert restored.completed is True - assert restored.aborted is False - assert isinstance(restored.error, CustomGraphExecutionError) - assert str(restored.error) == "serialization failure" - assert set(restored.node_executions) == {"node-a", "node-b"} - restored_node_a = restored.node_executions["node-a"] - assert restored_node_a.state is NodeState.TAKEN - assert restored_node_a.retry_count == 1 - assert restored_node_a.execution_id == "exec-1" - assert restored_node_a.error == "boom" - restored_node_b = restored.node_executions["node-b"] - assert restored_node_b.state is NodeState.SKIPPED - assert restored_node_b.retry_count == 0 - assert restored_node_b.execution_id is None - assert restored_node_b.error is None - - -def test_graph_execution_loads_replaces_existing_state() -> None: - """loads replaces existing runtime data with serialized snapshot.""" - # Arrange - source = GraphExecution(workflow_id="wf-2") - source.start() - source_node = source.get_or_create_node_execution("node-source") - source_node.mark_taken() - serialized = source.dumps() - - target = GraphExecution(workflow_id="wf-2") - target.start() - target.abort("pre-existing abort") - temp_node = target.get_or_create_node_execution("node-temp") - temp_node.increment_retry() - temp_node.mark_failed("temp error") - - # Act - target.loads(serialized) - - # Assert - assert target.aborted is False - assert target.error is None - assert target.started is True - assert target.completed is False - assert set(target.node_executions) == {"node-source"} - restored_node = target.node_executions["node-source"] - assert restored_node.state is NodeState.TAKEN - assert restored_node.retry_count == 0 - assert restored_node.execution_id is None - assert restored_node.error is None - - -def test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> None: - """ResponseStreamCoordinator serialization restores coordinator internals.""" - - template_main = Template(segments=[TextSegment(text="Hi "), VariableSegment(selector=["node-source", "text"])]) - template_secondary = Template(segments=[TextSegment(text="secondary")]) - - class DummyNode: - def __init__(self, node_id: str, template: Template, execution_type: NodeExecutionType) -> None: - self.id = node_id - self.node_type = ( - BuiltinNodeTypes.ANSWER if execution_type == NodeExecutionType.RESPONSE else BuiltinNodeTypes.LLM - ) - self.execution_type = execution_type - self.state = NodeState.UNKNOWN - self.title = node_id - self.template = template - - def blocks_variable_output(self, *_args) -> bool: - return False - - response_node1 = DummyNode("response-1", template_main, NodeExecutionType.RESPONSE) - response_node2 = DummyNode("response-2", template_main, NodeExecutionType.RESPONSE) - response_node3 = DummyNode("response-3", template_main, NodeExecutionType.RESPONSE) - source_node = DummyNode("node-source", template_secondary, NodeExecutionType.EXECUTABLE) - - class DummyGraph: - def __init__(self) -> None: - self.nodes = { - response_node1.id: response_node1, - response_node2.id: response_node2, - response_node3.id: response_node3, - source_node.id: source_node, - } - self.edges: dict[str, object] = {} - self.root_node = response_node1 - - def get_outgoing_edges(self, _node_id: str): # pragma: no cover - not exercised - return [] - - def get_incoming_edges(self, _node_id: str): # pragma: no cover - not exercised - return [] - - graph = DummyGraph() - - def fake_from_node(cls, node: DummyNode) -> ResponseSession: - return ResponseSession(node_id=node.id, template=node.template) - - monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node)) - - coordinator = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type] - coordinator._response_nodes = {"response-1", "response-2", "response-3"} - coordinator._paths_maps = { - "response-1": [Path(edges=["edge-1"])], - "response-2": [Path(edges=[])], - "response-3": [Path(edges=["edge-2", "edge-3"])], - } - - active_session = ResponseSession(node_id="response-1", template=response_node1.template) - active_session.index = 1 - coordinator._active_session = active_session - waiting_session = ResponseSession(node_id="response-2", template=response_node2.template) - coordinator._waiting_sessions = deque([waiting_session]) - pending_session = ResponseSession(node_id="response-3", template=response_node3.template) - pending_session.index = 2 - coordinator._response_sessions = {"response-3": pending_session} - - coordinator._node_execution_ids = {"response-1": "exec-1"} - event = NodeRunStreamChunkEvent( - id="exec-1", - node_id="response-1", - node_type=BuiltinNodeTypes.ANSWER, - selector=["node-source", "text"], - chunk="chunk-1", - is_final=False, - ) - coordinator._stream_buffers = {("node-source", "text"): [event]} - coordinator._stream_positions = {("node-source", "text"): 1} - coordinator._closed_streams = {("node-source", "text")} - - serialized = coordinator.dumps() - - restored = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type] - monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node)) - restored.loads(serialized) - - assert restored._response_nodes == {"response-1", "response-2", "response-3"} - assert restored._paths_maps["response-1"][0].edges == ["edge-1"] - assert restored._active_session is not None - assert restored._active_session.node_id == "response-1" - assert restored._active_session.index == 1 - waiting_restored = list(restored._waiting_sessions) - assert len(waiting_restored) == 1 - assert waiting_restored[0].node_id == "response-2" - assert waiting_restored[0].index == 0 - assert set(restored._response_sessions) == {"response-3"} - assert restored._response_sessions["response-3"].index == 2 - assert restored._node_execution_ids == {"response-1": "exec-1"} - assert ("node-source", "text") in restored._stream_buffers - restored_event = restored._stream_buffers[("node-source", "text")][0] - assert restored_event.chunk == "chunk-1" - assert restored._stream_positions[("node-source", "text")] == 1 - assert ("node-source", "text") in restored._closed_streams diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py deleted file mode 100644 index a6417822d2..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py +++ /dev/null @@ -1,190 +0,0 @@ -import time -from collections.abc import Mapping - -from core.workflow.system_variables import build_system_variables -from graphon.entities import GraphInitParams -from graphon.enums import NodeState -from graphon.graph import Graph -from graphon.graph_engine.graph_state_manager import GraphStateManager -from graphon.graph_engine.ready_queue import InMemoryReadyQueue -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=build_system_variables( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_llm_node( - *, - node_id: str, - runtime_state: GraphRuntimeState, - graph_init_params: GraphInitParams, - mock_config: MockConfig, -) -> MockLLMNode: - llm_data = LLMNodeData( - title=f"LLM {node_id}", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=f"Prompt {node_id}", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - return MockLLMNode( - id=llm_config["id"], - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - -def _build_graph(runtime_state: GraphRuntimeState) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - mock_config = MockConfig() - llm_a = _build_llm_node( - node_id="llm_a", - runtime_state=runtime_state, - graph_init_params=graph_init_params, - mock_config=mock_config, - ) - llm_b = _build_llm_node( - node_id="llm_b", - runtime_state=runtime_state, - graph_init_params=graph_init_params, - mock_config=mock_config, - ) - - end_data = EndNodeData(title="End", outputs=[], desc=None) - end_config = {"id": "end", "data": end_data.model_dump()} - end_node = EndNode( - id=end_config["id"], - config=end_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - builder = ( - Graph.new() - .add_root(start_node) - .add_node(llm_a, from_node_id="start") - .add_node(llm_b, from_node_id="start") - .add_node(end_node, from_node_id="llm_a") - ) - return builder.connect(tail="llm_b", head="end").build() - - -def _edge_state_map(graph: Graph) -> Mapping[tuple[str, str, str], NodeState]: - return {(edge.tail, edge.head, edge.source_handle): edge.state for edge in graph.edges.values()} - - -def test_runtime_state_snapshot_restores_graph_states() -> None: - runtime_state = _build_runtime_state() - graph = _build_graph(runtime_state) - runtime_state.attach_graph(graph) - - graph.nodes["llm_a"].state = NodeState.TAKEN - graph.nodes["llm_b"].state = NodeState.SKIPPED - - for edge in graph.edges.values(): - if edge.tail == "start" and edge.head == "llm_a": - edge.state = NodeState.TAKEN - elif edge.tail == "start" and edge.head == "llm_b": - edge.state = NodeState.SKIPPED - elif edge.head == "end" and edge.tail == "llm_a": - edge.state = NodeState.TAKEN - elif edge.head == "end" and edge.tail == "llm_b": - edge.state = NodeState.SKIPPED - - snapshot = runtime_state.dumps() - - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resumed_graph = _build_graph(resumed_state) - resumed_state.attach_graph(resumed_graph) - - assert resumed_graph.nodes["llm_a"].state == NodeState.TAKEN - assert resumed_graph.nodes["llm_b"].state == NodeState.SKIPPED - assert _edge_state_map(resumed_graph) == _edge_state_map(graph) - - -def test_join_readiness_uses_restored_edge_states() -> None: - runtime_state = _build_runtime_state() - graph = _build_graph(runtime_state) - runtime_state.attach_graph(graph) - - ready_queue = InMemoryReadyQueue() - state_manager = GraphStateManager(graph, ready_queue) - - for edge in graph.get_incoming_edges("end"): - if edge.tail == "llm_a": - edge.state = NodeState.TAKEN - if edge.tail == "llm_b": - edge.state = NodeState.UNKNOWN - - assert state_manager.is_node_ready("end") is False - - for edge in graph.get_incoming_edges("end"): - if edge.tail == "llm_b": - edge.state = NodeState.TAKEN - - assert state_manager.is_node_ready("end") is True - - snapshot = runtime_state.dumps() - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resumed_graph = _build_graph(resumed_state) - resumed_state.attach_graph(resumed_graph) - - resumed_state_manager = GraphStateManager(resumed_graph, InMemoryReadyQueue()) - assert resumed_state_manager.is_node_ready("end") is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py deleted file mode 100644 index ca9a929591..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ /dev/null @@ -1,389 +0,0 @@ -import datetime -import time -from collections.abc import Iterable -from unittest import mock -from unittest.mock import MagicMock - -from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables -from graphon.graph import Graph -from graphon.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from graphon.graph_events.node import NodeRunHumanInputFormFilledEvent -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.nodes.base.entities import OutputVariableEntity, OutputVariableType -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _build_branching_graph( - mock_config: MockConfig, - form_repository: HumanInputFormRepository, - graph_runtime_state: GraphRuntimeState | None = None, -) -> tuple[Graph, GraphRuntimeState]: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - if graph_runtime_state is None: - variable_pool = VariablePool( - system_variables=build_system_variables( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="test-execution-id", - ), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: - llm_data = LLMNodeData( - title=title, - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=prompt_text, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - llm_node = MockLLMNode( - id=node_id, - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - credentials_provider=mock.Mock(), - model_factory=mock.Mock(), - ) - return llm_node - - llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream") - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[ - UserAction(id="primary", title="Primary"), - UserAction(id="secondary", title="Secondary"), - ], - ) - - human_config = {"id": "human", "data": human_data.model_dump()} - human_node = HumanInputNode( - id=human_config["id"], - config=human_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - form_repository=form_repository, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), - ) - - llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") - llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary") - - end_primary_data = EndNodeData( - title="End Primary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="primary_text", value_type=OutputVariableType.STRING, value_selector=["llm_primary", "text"] - ), - ], - desc=None, - ) - end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()} - end_primary = EndNode( - id=end_primary_config["id"], - config=end_primary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - end_secondary_data = EndNodeData( - title="End Secondary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="secondary_text", - value_type=OutputVariableType.STRING, - value_selector=["llm_secondary", "text"], - ), - ], - desc=None, - ) - end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()} - end_secondary = EndNode( - id=end_secondary_config["id"], - config=end_secondary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = ( - Graph.new() - .add_root(start_node) - .add_node(llm_initial) - .add_node(human_node) - .add_node(llm_primary, from_node_id="human", source_handle="primary") - .add_node(end_primary, from_node_id="llm_primary") - .add_node(llm_secondary, from_node_id="human", source_handle="secondary") - .add_node(end_secondary, from_node_id="llm_secondary") - .build() - ) - return graph, graph_runtime_state - - -def _expected_mock_llm_chunks(text: str) -> list[str]: - chunks: list[str] = [] - for index, word in enumerate(text.split(" ")): - chunk = word if index == 0 else f" {word}" - chunks.append(chunk) - chunks.append("") - return chunks - - -def _assert_stream_chunk_sequence( - chunk_events: Iterable[NodeRunStreamChunkEvent], - expected_nodes: list[str], - expected_chunks: list[str], -) -> None: - actual_nodes = [event.node_id for event in chunk_events] - actual_chunks = [event.chunk for event in chunk_events] - assert actual_nodes == expected_nodes - assert actual_chunks == expected_chunks - - -def test_human_input_llm_streaming_across_multiple_branches() -> None: - mock_config = MockConfig() - mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"}) - mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"}) - mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"}) - - branch_scenarios = [ - { - "handle": "primary", - "resume_llm": "llm_primary", - "end_node": "end_primary", - "expected_pre_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), # cached output before branch completes - ("end_primary", ["\n"]), # literal segment emitted when end_primary session activates - ], - "expected_post_chunks": [ - ("llm_primary", _expected_mock_llm_chunks("Primary stream output")), # live stream from chosen branch - ], - }, - { - "handle": "secondary", - "resume_llm": "llm_secondary", - "end_node": "end_secondary", - "expected_pre_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), # cached output before branch completes - ("end_secondary", ["\n"]), # literal segment emitted when end_secondary session activates - ], - "expected_post_chunks": [ - ("llm_secondary", _expected_mock_llm_chunks("Secondary")), # live stream from chosen branch - ], - }, - ] - - for scenario in branch_scenarios: - runner = TableTestRunner() - - mock_create_repo = MagicMock(spec=HumanInputFormRepository) - mock_create_repo.get_form.return_value = None - mock_form_entity = MagicMock(spec=HumanInputFormEntity) - mock_form_entity.id = "test_form_id" - mock_form_entity.submission_token = "test_web_app_token" - mock_form_entity.recipients = [] - mock_form_entity.rendered_content = "rendered" - mock_form_entity.submitted = False - mock_create_repo.create_form.return_value = mock_form_entity - - def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]: - return _build_branching_graph(mock_config, mock_create_repo) - - initial_case = WorkflowTestCase( - description="HumanInput pause before branching decision", - graph_factory=initial_graph_factory, - expected_event_sequence=[ - GraphRunStartedEvent, # initial run: graph execution starts - NodeRunStartedEvent, # start node begins execution - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial starts streaming - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # human node begins and issues pause - NodeRunPauseRequestedEvent, # human node requests pause awaiting input - GraphRunPausedEvent, # graph run pauses awaiting resume - ], - ) - - initial_result = runner.run_test_case(initial_case) - - assert initial_result.success, initial_result.event_mismatch_details - assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events) - - pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"]) - post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"]) - expected_pre_chunk_events_in_resumption = [ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunHumanInputFormFilledEvent, - ] - - expected_resume_sequence: list[type] = ( - expected_pre_chunk_events_in_resumption - + [NodeRunStreamChunkEvent] * pre_chunk_count - + [ - NodeRunSucceededEvent, - NodeRunStartedEvent, - ] - + [NodeRunStreamChunkEvent] * post_chunk_count - + [ - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ] - ) - - mock_get_repo = MagicMock(spec=HumanInputFormRepository) - submitted_form = MagicMock(spec=HumanInputFormEntity) - submitted_form.id = mock_form_entity.id - submitted_form.submission_token = mock_form_entity.submission_token - submitted_form.recipients = [] - submitted_form.rendered_content = mock_form_entity.rendered_content - submitted_form.submitted = True - submitted_form.selected_action_id = scenario["handle"] - submitted_form.submitted_data = {} - submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - mock_get_repo.get_form.return_value = submitted_form - - def resume_graph_factory( - initial_result=initial_result, mock_get_repo=mock_get_repo - ) -> tuple[Graph, GraphRuntimeState]: - assert initial_result.graph_runtime_state is not None - serialized_runtime_state = initial_result.graph_runtime_state.dumps() - resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) - return _build_branching_graph(mock_config, mock_get_repo, resume_runtime_state) - - resume_case = WorkflowTestCase( - description=f"HumanInput resumes via {scenario['handle']} branch", - graph_factory=resume_graph_factory, - expected_event_sequence=expected_resume_sequence, - ) - - resume_result = runner.run_test_case(resume_case) - - assert resume_result.success, resume_result.event_mismatch_details - - resume_events = resume_result.events - - chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)] - assert len(chunk_events) == pre_chunk_count + post_chunk_count - - pre_chunk_events = chunk_events[:pre_chunk_count] - post_chunk_events = chunk_events[pre_chunk_count:] - - expected_pre_nodes: list[str] = [] - expected_pre_chunks: list[str] = [] - for node_id, chunks in scenario["expected_pre_chunks"]: - expected_pre_nodes.extend([node_id] * len(chunks)) - expected_pre_chunks.extend(chunks) - _assert_stream_chunk_sequence(pre_chunk_events, expected_pre_nodes, expected_pre_chunks) - - expected_post_nodes: list[str] = [] - expected_post_chunks: list[str] = [] - for node_id, chunks in scenario["expected_post_chunks"]: - expected_post_nodes.extend([node_id] * len(chunks)) - expected_post_chunks.extend(chunks) - _assert_stream_chunk_sequence(post_chunk_events, expected_post_nodes, expected_post_chunks) - - human_success_index = next( - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human" - ) - pre_indices = [ - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index - ] - expected_pre_chunk_events_count_in_resumption = len(expected_pre_chunk_events_in_resumption) - assert pre_indices == list(range(expected_pre_chunk_events_count_in_resumption, human_success_index)) - - resume_chunk_indices = [ - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"] - ] - assert resume_chunk_indices, "Expected streaming output from the selected branch" - resume_start_index = next( - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"] - ) - resume_success_index = next( - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"] - ) - assert resume_start_index < min(resume_chunk_indices) - assert max(resume_chunk_indices) < resume_success_index - - started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)] - assert started_nodes == ["human", scenario["resume_llm"], scenario["end_node"]] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py deleted file mode 100644 index c50aaafe2c..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ /dev/null @@ -1,346 +0,0 @@ -import datetime -import time -from unittest import mock -from unittest.mock import MagicMock - -from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables -from graphon.graph import Graph -from graphon.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from graphon.graph_events.node import NodeRunHumanInputFormFilledEvent -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.nodes.base.entities import OutputVariableEntity, OutputVariableType -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _build_llm_human_llm_graph( - mock_config: MockConfig, - form_repository: HumanInputFormRepository, - graph_runtime_state: GraphRuntimeState | None = None, -) -> tuple[Graph, GraphRuntimeState]: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - if graph_runtime_state is None: - variable_pool = VariablePool( - system_variables=build_system_variables( - user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id," - ), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: - llm_data = LLMNodeData( - title=title, - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=prompt_text, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - llm_node = MockLLMNode( - id=node_id, - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - credentials_provider=mock.Mock(), - model_factory=mock.Mock(), - ) - return llm_node - - llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt") - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[ - UserAction(id="accept", title="Accept"), - UserAction(id="reject", title="Reject"), - ], - ) - - human_config = {"id": "human", "data": human_data.model_dump()} - human_node = HumanInputNode( - id=human_config["id"], - config=human_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - form_repository=form_repository, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), - ) - - llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt") - - end_data = EndNodeData( - title="End", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="resume_text", value_type=OutputVariableType.STRING, value_selector=["llm_resume", "text"] - ), - ], - desc=None, - ) - end_config = {"id": "end", "data": end_data.model_dump()} - end_node = EndNode( - id=end_config["id"], - config=end_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = ( - Graph.new() - .add_root(start_node) - .add_node(llm_first) - .add_node(human_node) - .add_node(llm_second, source_handle="accept") - .add_node(end_node) - .build() - ) - return graph, graph_runtime_state - - -def _expected_mock_llm_chunks(text: str) -> list[str]: - chunks: list[str] = [] - for index, word in enumerate(text.split(" ")): - chunk = word if index == 0 else f" {word}" - chunks.append(chunk) - chunks.append("") - return chunks - - -def test_human_input_llm_streaming_order_across_pause() -> None: - runner = TableTestRunner() - - initial_text = "Hello, pause" - resume_text = "Welcome back!" - - mock_config = MockConfig() - mock_config.set_node_outputs("llm_initial", {"text": initial_text}) - mock_config.set_node_outputs("llm_resume", {"text": resume_text}) - - expected_initial_sequence: list[type] = [ - GraphRunStartedEvent, # graph run begins - NodeRunStartedEvent, # start node begins - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial begins streaming - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # human node begins and requests pause - NodeRunPauseRequestedEvent, # human node pause requested - GraphRunPausedEvent, # graph run pauses awaiting resume - ] - - mock_create_repo = MagicMock(spec=HumanInputFormRepository) - mock_create_repo.get_form.return_value = None - mock_form_entity = MagicMock(spec=HumanInputFormEntity) - mock_form_entity.id = "test_form_id" - mock_form_entity.submission_token = "test_web_app_token" - mock_form_entity.recipients = [] - mock_form_entity.rendered_content = "rendered" - mock_form_entity.submitted = False - mock_create_repo.create_form.return_value = mock_form_entity - - def graph_factory() -> tuple[Graph, GraphRuntimeState]: - return _build_llm_human_llm_graph(mock_config, mock_create_repo) - - initial_case = WorkflowTestCase( - description="HumanInput pause preserves LLM streaming order", - graph_factory=graph_factory, - expected_event_sequence=expected_initial_sequence, - ) - - initial_result = runner.run_test_case(initial_case) - - assert initial_result.success, initial_result.event_mismatch_details - - initial_events = initial_result.events - initial_chunks = _expected_mock_llm_chunks(initial_text) - - initial_stream_chunk_events = [event for event in initial_events if isinstance(event, NodeRunStreamChunkEvent)] - assert initial_stream_chunk_events == [] - - pause_index = next(i for i, event in enumerate(initial_events) if isinstance(event, GraphRunPausedEvent)) - llm_succeeded_index = next( - i - for i, event in enumerate(initial_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_initial" - ) - assert llm_succeeded_index < pause_index - - graph_runtime_state = initial_result.graph_runtime_state - graph = initial_result.graph - assert graph_runtime_state is not None - assert graph is not None - - coordinator = graph_runtime_state.response_coordinator - stream_buffers = coordinator._stream_buffers # Tests may access internals for assertions - assert ("llm_initial", "text") in stream_buffers - initial_stream_chunks = [event.chunk for event in stream_buffers[("llm_initial", "text")]] - assert initial_stream_chunks == initial_chunks - assert ("llm_resume", "text") not in stream_buffers - - resume_chunks = _expected_mock_llm_chunks(resume_text) - expected_resume_sequence: list[type] = [ - GraphRunStartedEvent, # resumed graph run begins - NodeRunStartedEvent, # human node restarts - # Form Filled should be generated first, then the node execution ends and stream chunk is generated. - NodeRunHumanInputFormFilledEvent, - NodeRunStreamChunkEvent, # cached llm_initial chunk 1 - NodeRunStreamChunkEvent, # cached llm_initial chunk 2 - NodeRunStreamChunkEvent, # cached llm_initial final chunk - NodeRunStreamChunkEvent, # end node emits combined template separator - NodeRunSucceededEvent, # human node finishes instantly after input - NodeRunStartedEvent, # llm_resume begins streaming - NodeRunStreamChunkEvent, # llm_resume chunk 1 - NodeRunStreamChunkEvent, # llm_resume chunk 2 - NodeRunStreamChunkEvent, # llm_resume final chunk - NodeRunSucceededEvent, # llm_resume completes streaming - NodeRunStartedEvent, # end node starts - NodeRunSucceededEvent, # end node finishes - GraphRunSucceededEvent, # graph run succeeds after resume - ] - - mock_get_repo = MagicMock(spec=HumanInputFormRepository) - submitted_form = MagicMock(spec=HumanInputFormEntity) - submitted_form.id = mock_form_entity.id - submitted_form.submission_token = mock_form_entity.submission_token - submitted_form.recipients = [] - submitted_form.rendered_content = mock_form_entity.rendered_content - submitted_form.submitted = True - submitted_form.selected_action_id = "accept" - submitted_form.submitted_data = {} - submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - mock_get_repo.get_form.return_value = submitted_form - - def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]: - # restruct the graph runtime state - serialized_runtime_state = initial_result.graph_runtime_state.dumps() - resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) - return _build_llm_human_llm_graph( - mock_config, - mock_get_repo, - resume_runtime_state, - ) - - resume_case = WorkflowTestCase( - description="HumanInput resume continues LLM streaming order", - graph_factory=resume_graph_factory, - expected_event_sequence=expected_resume_sequence, - ) - - resume_result = runner.run_test_case(resume_case) - - assert resume_result.success, resume_result.event_mismatch_details - - resume_events = resume_result.events - - success_index = next(i for i, event in enumerate(resume_events) if isinstance(event, GraphRunSucceededEvent)) - llm_resume_succeeded_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume" - ) - assert llm_resume_succeeded_index < success_index - - resume_chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)] - assert [event.node_id for event in resume_chunk_events[:3]] == ["llm_initial"] * 3 - assert [event.chunk for event in resume_chunk_events[:3]] == initial_chunks - assert resume_chunk_events[3].node_id == "end" - assert resume_chunk_events[3].chunk == "\n" - assert [event.node_id for event in resume_chunk_events[4:]] == ["llm_resume"] * 3 - assert [event.chunk for event in resume_chunk_events[4:]] == resume_chunks - - human_success_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human" - ) - cached_chunk_indices = [ - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id in {"llm_initial", "end"} - ] - assert all(index < human_success_index for index in cached_chunk_indices) - - llm_resume_start_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == "llm_resume" - ) - llm_resume_success_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume" - ) - llm_resume_chunk_indices = [ - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == "llm_resume" - ] - assert llm_resume_chunk_indices - first_resume_chunk_index = min(llm_resume_chunk_indices) - last_resume_chunk_index = max(llm_resume_chunk_indices) - assert llm_resume_start_index < first_resume_chunk_index - assert last_resume_chunk_index < llm_resume_success_index - - started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)] - assert started_nodes == ["human", "llm_resume", "end"] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py deleted file mode 100644 index 246df45d5f..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py +++ /dev/null @@ -1,324 +0,0 @@ -import time -from unittest import mock - -from core.workflow.system_variables import build_system_variables -from graphon.graph import Graph -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.nodes.base.entities import OutputVariableEntity, OutputVariableType -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.if_else.entities import IfElseNodeData -from graphon.nodes.if_else.if_else_node import IfElseNode -from graphon.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.utils.condition.entities import Condition -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - graph_config=graph_config, - user_from="account", - invoke_from="debugger", - ) - - variable_pool = VariablePool( - system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"), - user_inputs={}, - conversation_variables=[], - ) - variable_pool.add(("branch", "value"), branch_value) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: - llm_data = LLMNodeData( - title=title, - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=prompt_text, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - llm_node = MockLLMNode( - id=node_id, - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - credentials_provider=mock.Mock(), - model_factory=mock.Mock(), - ) - return llm_node - - llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream") - - if_else_data = IfElseNodeData( - title="IfElse", - cases=[ - IfElseNodeData.Case( - case_id="primary", - logical_operator="and", - conditions=[ - Condition(variable_selector=["branch", "value"], comparison_operator="is", value="primary") - ], - ), - IfElseNodeData.Case( - case_id="secondary", - logical_operator="and", - conditions=[ - Condition(variable_selector=["branch", "value"], comparison_operator="is", value="secondary") - ], - ), - ], - ) - if_else_config = {"id": "if_else", "data": if_else_data.model_dump()} - if_else_node = IfElseNode( - id=if_else_config["id"], - config=if_else_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") - llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary") - - end_primary_data = EndNodeData( - title="End Primary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="primary_text", value_type=OutputVariableType.STRING, value_selector=["llm_primary", "text"] - ), - ], - desc=None, - ) - end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()} - end_primary = EndNode( - id=end_primary_config["id"], - config=end_primary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - end_secondary_data = EndNodeData( - title="End Secondary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="secondary_text", - value_type=OutputVariableType.STRING, - value_selector=["llm_secondary", "text"], - ), - ], - desc=None, - ) - end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()} - end_secondary = EndNode( - id=end_secondary_config["id"], - config=end_secondary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = ( - Graph.new() - .add_root(start_node) - .add_node(llm_initial) - .add_node(if_else_node) - .add_node(llm_primary, from_node_id="if_else", source_handle="primary") - .add_node(end_primary, from_node_id="llm_primary") - .add_node(llm_secondary, from_node_id="if_else", source_handle="secondary") - .add_node(end_secondary, from_node_id="llm_secondary") - .build() - ) - return graph, graph_runtime_state - - -def _expected_mock_llm_chunks(text: str) -> list[str]: - chunks: list[str] = [] - for index, word in enumerate(text.split(" ")): - chunk = word if index == 0 else f" {word}" - chunks.append(chunk) - chunks.append("") - return chunks - - -def test_if_else_llm_streaming_order() -> None: - mock_config = MockConfig() - mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"}) - mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"}) - mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"}) - - scenarios = [ - { - "branch": "primary", - "resume_llm": "llm_primary", - "end_node": "end_primary", - "expected_sequence": [ - GraphRunStartedEvent, # graph run begins - NodeRunStartedEvent, # start node begins execution - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial starts and streams - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # if_else evaluates conditions - NodeRunStreamChunkEvent, # cached llm_initial chunk 1 flushed - NodeRunStreamChunkEvent, # cached llm_initial chunk 2 flushed - NodeRunStreamChunkEvent, # cached llm_initial final chunk flushed - NodeRunStreamChunkEvent, # template literal newline emitted - NodeRunSucceededEvent, # if_else completes branch selection - NodeRunStartedEvent, # llm_primary begins streaming - NodeRunStreamChunkEvent, # llm_primary chunk 1 - NodeRunStreamChunkEvent, # llm_primary chunk 2 - NodeRunStreamChunkEvent, # llm_primary chunk 3 - NodeRunStreamChunkEvent, # llm_primary final chunk - NodeRunSucceededEvent, # llm_primary completes streaming - NodeRunStartedEvent, # end_primary node starts - NodeRunSucceededEvent, # end_primary finishes aggregation - GraphRunSucceededEvent, # graph run succeeds - ], - "expected_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), - ("end_primary", ["\n"]), - ("llm_primary", _expected_mock_llm_chunks("Primary stream output")), - ], - }, - { - "branch": "secondary", - "resume_llm": "llm_secondary", - "end_node": "end_secondary", - "expected_sequence": [ - GraphRunStartedEvent, # graph run begins - NodeRunStartedEvent, # start node begins execution - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial starts and streams - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # if_else evaluates conditions - NodeRunStreamChunkEvent, # cached llm_initial chunk 1 flushed - NodeRunStreamChunkEvent, # cached llm_initial chunk 2 flushed - NodeRunStreamChunkEvent, # cached llm_initial final chunk flushed - NodeRunStreamChunkEvent, # template literal newline emitted - NodeRunSucceededEvent, # if_else completes branch selection - NodeRunStartedEvent, # llm_secondary begins streaming - NodeRunStreamChunkEvent, # llm_secondary chunk 1 - NodeRunStreamChunkEvent, # llm_secondary final chunk - NodeRunSucceededEvent, # llm_secondary completes - NodeRunStartedEvent, # end_secondary node starts - NodeRunSucceededEvent, # end_secondary finishes aggregation - GraphRunSucceededEvent, # graph run succeeds - ], - "expected_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), - ("end_secondary", ["\n"]), - ("llm_secondary", _expected_mock_llm_chunks("Secondary")), - ], - }, - ] - - for scenario in scenarios: - runner = TableTestRunner() - - def graph_factory( - branch_value: str = scenario["branch"], - cfg: MockConfig = mock_config, - ) -> tuple[Graph, GraphRuntimeState]: - return _build_if_else_graph(branch_value, cfg) - - test_case = WorkflowTestCase( - description=f"IfElse streaming via {scenario['branch']} branch", - graph_factory=graph_factory, - expected_event_sequence=scenario["expected_sequence"], - ) - - result = runner.run_test_case(test_case) - - assert result.success, result.event_mismatch_details - - chunk_events = [event for event in result.events if isinstance(event, NodeRunStreamChunkEvent)] - expected_nodes: list[str] = [] - expected_chunks: list[str] = [] - for node_id, chunks in scenario["expected_chunks"]: - expected_nodes.extend([node_id] * len(chunks)) - expected_chunks.extend(chunks) - assert [event.node_id for event in chunk_events] == expected_nodes - assert [event.chunk for event in chunk_events] == expected_chunks - - branch_node_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == "if_else" - ) - branch_success_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "if_else" - ) - pre_branch_chunk_indices = [ - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStreamChunkEvent) and index < branch_success_index - ] - assert len(pre_branch_chunk_indices) == len(_expected_mock_llm_chunks("Initial stream")) + 1 - assert min(pre_branch_chunk_indices) == branch_node_index + 1 - assert max(pre_branch_chunk_indices) < branch_success_index - - resume_chunk_indices = [ - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"] - ] - assert resume_chunk_indices - resume_start_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"] - ) - resume_success_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"] - ) - assert resume_start_index < min(resume_chunk_indices) - assert max(resume_chunk_indices) < resume_success_index - - started_nodes = [event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)] - assert started_nodes == ["start", "llm_initial", "if_else", scenario["resume_llm"], scenario["end_node"]] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py b/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py deleted file mode 100644 index b9bf4be13a..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py +++ /dev/null @@ -1,126 +0,0 @@ -""" -Test cases for the Iteration node's flatten_output functionality. - -This module tests the iteration node's ability to: -1. Flatten array outputs when flatten_output=True (default) -2. Preserve nested array structure when flatten_output=False -""" - -from .test_database_utils import skip_if_database_unavailable -from .test_mock_config import MockConfigBuilder, NodeMockConfig -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _create_iteration_mock_config(): - """Helper to create a mock config for iteration tests.""" - - def code_inner_handler(node): - pool = node.graph_runtime_state.variable_pool - item_seg = pool.get(["iteration_node", "item"]) - if item_seg is not None: - item = item_seg.to_object() - return {"result": [item, item * 2]} - # This fallback is likely unreachable, but if it is, - # it doesn't simulate iteration with different values as the comment suggests. - return {"result": [1, 2]} - - return ( - MockConfigBuilder() - .with_node_output("code_node", {"result": [1, 2, 3]}) - .with_node_config(NodeMockConfig(node_id="code_inner_node", custom_handler=code_inner_handler)) - .build() - ) - - -@skip_if_database_unavailable() -def test_iteration_with_flatten_output_enabled(): - """ - Test iteration node with flatten_output=True (default behavior). - - The fixture implements an iteration that: - 1. Iterates over [1, 2, 3] - 2. For each item, outputs [item, item*2] - 3. With flatten_output=True, should output [1, 2, 2, 4, 3, 6] - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="iteration_flatten_output_enabled_workflow", - inputs={}, - expected_outputs={"output": [1, 2, 2, 4, 3, 6]}, - description="Iteration with flatten_output=True flattens nested arrays", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs is not None, "Should have outputs" - assert result.actual_outputs == {"output": [1, 2, 2, 4, 3, 6]}, ( - f"Expected flattened output [1, 2, 2, 4, 3, 6], got {result.actual_outputs}" - ) - - -@skip_if_database_unavailable() -def test_iteration_with_flatten_output_disabled(): - """ - Test iteration node with flatten_output=False. - - The fixture implements an iteration that: - 1. Iterates over [1, 2, 3] - 2. For each item, outputs [item, item*2] - 3. With flatten_output=False, should output [[1, 2], [2, 4], [3, 6]] - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="iteration_flatten_output_disabled_workflow", - inputs={}, - expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]}, - description="Iteration with flatten_output=False preserves nested structure", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs is not None, "Should have outputs" - assert result.actual_outputs == {"output": [[1, 2], [2, 4], [3, 6]]}, ( - f"Expected nested output [[1, 2], [2, 4], [3, 6]], got {result.actual_outputs}" - ) - - -@skip_if_database_unavailable() -def test_iteration_flatten_output_comparison(): - """ - Run both flatten_output configurations in parallel to verify the difference. - """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="iteration_flatten_output_enabled_workflow", - inputs={}, - expected_outputs={"output": [1, 2, 2, 4, 3, 6]}, - description="flatten_output=True: Flattened output", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ), - WorkflowTestCase( - fixture_path="iteration_flatten_output_disabled_workflow", - inputs={}, - expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]}, - description="flatten_output=False: Nested output", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ), - ] - - suite_result = runner.run_table_tests(test_cases, parallel=True) - - # Assert all tests passed - assert suite_result.passed_tests == 2, f"Expected 2 passed tests, got {suite_result.passed_tests}" - assert suite_result.failed_tests == 0, f"Expected 0 failed tests, got {suite_result.failed_tests}" - assert suite_result.success_rate == 100.0, f"Expected 100% success rate, got {suite_result.success_rate}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py deleted file mode 100644 index 821da46b76..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -Test case for loop with inner answer output error scenario. - -This test validates the behavior of a loop containing an answer node -inside the loop that may produce output errors. -""" - -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_loop_contains_answer(): - """ - Test loop with inner answer node that may have output errors. - - The fixture implements a loop that: - 1. Iterates 4 times (index 0-3) - 2. Contains an inner answer node that outputs index and item values - 3. Has a break condition when index equals 4 - 4. Tests error handling for answer nodes within loops - """ - fixture_name = "loop_contains_answer" - mock_config = MockConfigBuilder().build() - - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - query="1", - expected_outputs={"answer": "1\n2\n1 + 2"}, - expected_event_sequence=[ - # Graph start - GraphRunStartedEvent, - # Start - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Loop start - NodeRunStartedEvent, - NodeRunLoopStartedEvent, - # Variable assigner - NodeRunStartedEvent, - NodeRunVariableUpdatedEvent, - NodeRunStreamChunkEvent, # 1 - NodeRunStreamChunkEvent, # \n - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Loop next - NodeRunLoopNextEvent, - # Variable assigner - NodeRunStartedEvent, - NodeRunVariableUpdatedEvent, - NodeRunStreamChunkEvent, # 2 - NodeRunStreamChunkEvent, # \n - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Loop end - NodeRunLoopSucceededEvent, - NodeRunStreamChunkEvent, # 1 - NodeRunStreamChunkEvent, # + - NodeRunStreamChunkEvent, # 2 - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Graph end - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_node.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_node.py deleted file mode 100644 index ad8d777ea6..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_node.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -Test cases for the Loop node functionality using TableTestRunner. - -This module tests the loop node's ability to: -1. Execute iterations with loop variables -2. Handle break conditions correctly -3. Update and propagate loop variables between iterations -4. Output the final loop variable value -""" - -from tests.unit_tests.core.workflow.graph_engine.test_table_runner import ( - TableTestRunner, - WorkflowTestCase, -) - - -def test_loop_with_break_condition(): - """ - Test loop node with break condition. - - The increment_loop_with_break_condition_workflow.yml fixture implements a loop that: - 1. Starts with num=1 - 2. Increments num by 1 each iteration - 3. Breaks when num >= 5 - 4. Should output {"num": 5} - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="increment_loop_with_break_condition_workflow", - inputs={}, # No inputs needed for this test - expected_outputs={"num": 5}, - description="Loop with break condition when num >= 5", - ) - - result = runner.run_test_case(test_case) - - # Assert the test passed - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs is not None, "Should have outputs" - assert result.actual_outputs == {"num": 5}, f"Expected {{'num': 5}}, got {result.actual_outputs}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py deleted file mode 100644 index 4a60c7769c..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py +++ /dev/null @@ -1,72 +0,0 @@ -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_loop_with_tool(): - fixture_name = "search_dify_from_2023_to_2025" - mock_config = ( - MockConfigBuilder() - .with_tool_response( - { - "text": "mocked search result", - } - ) - .build() - ) - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - expected_outputs={ - "answer": """- mocked search result -- mocked search result""" - }, - expected_event_sequence=[ - GraphRunStartedEvent, - # START - NodeRunStartedEvent, - NodeRunSucceededEvent, - # LOOP START - NodeRunStartedEvent, - NodeRunLoopStartedEvent, - # 2023 - NodeRunStartedEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunVariableUpdatedEvent, - NodeRunVariableUpdatedEvent, - NodeRunSucceededEvent, - NodeRunLoopNextEvent, - # 2024 - NodeRunStartedEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunVariableUpdatedEvent, - NodeRunVariableUpdatedEvent, - NodeRunSucceededEvent, - # LOOP END - NodeRunLoopSucceededEvent, - NodeRunStreamChunkEvent, # loop.res - NodeRunSucceededEvent, - # ANSWER - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py deleted file mode 100644 index c511548749..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py +++ /dev/null @@ -1,281 +0,0 @@ -""" -Example demonstrating the auto-mock system for testing workflows. - -This example shows how to test workflows with third-party service nodes -without making actual API calls. -""" - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def example_test_llm_workflow(): - """ - Example: Testing a workflow with an LLM node. - - This demonstrates how to test a workflow that uses an LLM service - without making actual API calls to OpenAI, Anthropic, etc. - """ - print("\n=== Example: Testing LLM Workflow ===\n") - - # Initialize the test runner - runner = TableTestRunner() - - # Configure mock responses - mock_config = MockConfigBuilder().with_llm_response("I'm a helpful AI assistant. How can I help you today?").build() - - # Define the test case - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Hello, AI!"}, - expected_outputs={"answer": "I'm a helpful AI assistant. How can I help you today?"}, - description="Testing LLM workflow with mocked response", - use_auto_mock=True, # Enable auto-mocking - mock_config=mock_config, - ) - - # Run the test - result = runner.run_test_case(test_case) - - if result.success: - print("✅ Test passed!") - print(f" Input: {test_case.inputs['query']}") - print(f" Output: {result.actual_outputs['answer']}") - print(f" Execution time: {result.execution_time:.2f}s") - else: - print(f"❌ Test failed: {result.error}") - - return result.success - - -def example_test_with_custom_outputs(): - """ - Example: Testing with custom outputs for specific nodes. - - This shows how to provide different mock outputs for specific node IDs, - useful when testing complex workflows with multiple LLM/tool nodes. - """ - print("\n=== Example: Custom Node Outputs ===\n") - - runner = TableTestRunner() - - # Configure mock with specific outputs for different nodes - mock_config = MockConfigBuilder().build() - - # Set custom output for a specific LLM node - mock_config.set_node_outputs( - "llm_node", - { - "text": "This is a custom response for the specific LLM node", - "usage": { - "prompt_tokens": 50, - "completion_tokens": 20, - "total_tokens": 70, - }, - "finish_reason": "stop", - }, - ) - - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Tell me about custom outputs"}, - expected_outputs={"answer": "This is a custom response for the specific LLM node"}, - description="Testing with custom node outputs", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if result.success: - print("✅ Test with custom outputs passed!") - print(f" Custom output: {result.actual_outputs['answer']}") - else: - print(f"❌ Test failed: {result.error}") - - return result.success - - -def example_test_http_and_tool_workflow(): - """ - Example: Testing a workflow with HTTP request and tool nodes. - - This demonstrates mocking external HTTP calls and tool executions. - """ - print("\n=== Example: HTTP and Tool Workflow ===\n") - - runner = TableTestRunner() - - # Configure mocks for HTTP and Tool nodes - mock_config = MockConfigBuilder().build() - - # Mock HTTP response - mock_config.set_node_outputs( - "http_node", - { - "status_code": 200, - "body": '{"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}', - "headers": {"content-type": "application/json"}, - }, - ) - - # Mock tool response (e.g., JSON parser) - mock_config.set_node_outputs( - "tool_node", - { - "result": {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}, - }, - ) - - test_case = WorkflowTestCase( - fixture_path="http-tool-workflow", - inputs={"url": "https://api.example.com/users"}, - expected_outputs={ - "status_code": 200, - "parsed_data": {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}, - }, - description="Testing HTTP and Tool workflow", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if result.success: - print("✅ HTTP and Tool workflow test passed!") - print(f" HTTP Status: {result.actual_outputs['status_code']}") - print(f" Parsed Data: {result.actual_outputs['parsed_data']}") - else: - print(f"❌ Test failed: {result.error}") - - return result.success - - -def example_test_error_simulation(): - """ - Example: Simulating errors in specific nodes. - - This shows how to test error handling in workflows by simulating - failures in specific nodes. - """ - print("\n=== Example: Error Simulation ===\n") - - runner = TableTestRunner() - - # Configure mock to simulate an error - mock_config = MockConfigBuilder().build() - mock_config.set_node_error("llm_node", "API rate limit exceeded") - - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "This will fail"}, - expected_outputs={}, # We expect failure - description="Testing error handling", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if not result.success: - print("✅ Error simulation worked as expected!") - print(f" Simulated error: {result.error}") - else: - print("❌ Expected failure but test succeeded") - - return not result.success # Success means we got the expected error - - -def example_test_with_delays(): - """ - Example: Testing with simulated execution delays. - - This demonstrates how to simulate realistic execution times - for performance testing. - """ - print("\n=== Example: Simulated Delays ===\n") - - runner = TableTestRunner() - - # Configure mock with delays - mock_config = ( - MockConfigBuilder() - .with_delays(True) # Enable delay simulation - .with_llm_response("Response after delay") - .build() - ) - - # Add specific delay for the LLM node - from .test_mock_config import NodeMockConfig - - node_config = NodeMockConfig( - node_id="llm_node", - outputs={"text": "Response after delay"}, - delay=0.5, # 500ms delay - ) - mock_config.set_node_config("llm_node", node_config) - - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Test with delay"}, - expected_outputs={"answer": "Response after delay"}, - description="Testing with simulated delays", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if result.success: - print("✅ Delay simulation test passed!") - print(f" Execution time: {result.execution_time:.2f}s") - print(" (Should be >= 0.5s due to simulated delay)") - else: - print(f"❌ Test failed: {result.error}") - - return result.success and result.execution_time >= 0.5 - - -def run_all_examples(): - """Run all example tests.""" - print("\n" + "=" * 50) - print("AUTO-MOCK SYSTEM EXAMPLES") - print("=" * 50) - - examples = [ - example_test_llm_workflow, - example_test_with_custom_outputs, - example_test_http_and_tool_workflow, - example_test_error_simulation, - example_test_with_delays, - ] - - results = [] - for example in examples: - try: - results.append(example()) - except Exception as e: - print(f"\n❌ Example failed with exception: {e}") - results.append(False) - - print("\n" + "=" * 50) - print("SUMMARY") - print("=" * 50) - - passed = sum(results) - total = len(results) - print(f"\n✅ Passed: {passed}/{total}") - - if passed == total: - print("\n🎉 All examples passed successfully!") - else: - print(f"\n⚠️ {total - passed} example(s) failed") - - return passed == total - - -if __name__ == "__main__": - import sys - - success = run_all_examples() - sys.exit(0 if success else 1) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 76b2984a4b..88989db856 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -7,11 +7,12 @@ requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request from typing import TYPE_CHECKING, Any -from core.workflow.node_factory import DifyNodeFactory from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node +from core.workflow.node_factory import DifyNodeFactory + from .test_mock_nodes import ( MockAgentNode, MockCodeNode, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py deleted file mode 100644 index aff479104f..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ /dev/null @@ -1,199 +0,0 @@ -""" -Simple test to verify MockNodeFactory works with iteration nodes. -""" - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY -from graphon.enums import BuiltinNodeTypes -from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder -from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory - - -def test_mock_factory_registers_iteration_node(): - """Test that MockNodeFactory has iteration node registered.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create a MockNodeFactory instance - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={"nodes": [], "edges": []}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # Check that iteration node is registered - assert BuiltinNodeTypes.ITERATION in factory._mock_node_types - print("✓ Iteration node is registered in MockNodeFactory") - - # Check that loop node is registered - assert BuiltinNodeTypes.LOOP in factory._mock_node_types - print("✓ Loop node is registered in MockNodeFactory") - - # Check the class types - from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode, MockLoopNode - - assert factory._mock_node_types[BuiltinNodeTypes.ITERATION] == MockIterationNode - print("✓ Iteration node maps to MockIterationNode class") - - assert factory._mock_node_types[BuiltinNodeTypes.LOOP] == MockLoopNode - print("✓ Loop node maps to MockLoopNode class") - - -def test_mock_iteration_node_preserves_config(): - """Test that MockIterationNode preserves mock configuration.""" - - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode - - # Create mock config - mock_config = MockConfigBuilder().with_llm_response("Test response").build() - - # Create minimal graph init params - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={"nodes": [], "edges": []}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - - # Create minimal runtime state - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - - # Create mock iteration node - node_config = { - "id": "iter1", - "data": { - "type": "iteration", - "title": "Test", - "iterator_selector": ["start", "items"], - "output_selector": ["node", "text"], - "start_node_id": "node1", - }, - } - - mock_node = MockIterationNode( - id="iter1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Verify the mock config is preserved - assert mock_node.mock_config == mock_config - print("✓ MockIterationNode preserves mock configuration") - - # Check that _create_graph_engine method exists and is overridden - assert hasattr(mock_node, "_create_graph_engine") - assert MockIterationNode._create_graph_engine != MockIterationNode.__bases__[1]._create_graph_engine - print("✓ MockIterationNode overrides _create_graph_engine method") - - -def test_mock_loop_node_preserves_config(): - """Test that MockLoopNode preserves mock configuration.""" - - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode - - # Create mock config - mock_config = MockConfigBuilder().with_http_response({"status": 200}).build() - - # Create minimal graph init params - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={"nodes": [], "edges": []}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - - # Create minimal runtime state - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - - # Create mock loop node - node_config = { - "id": "loop1", - "data": { - "type": "loop", - "title": "Test", - "loop_count": 3, - "start_node_id": "node1", - "loop_variables": [], - "outputs": {}, - "break_conditions": [], - "logical_operator": "and", - }, - } - - mock_node = MockLoopNode( - id="loop1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Verify the mock config is preserved - assert mock_node.mock_config == mock_config - print("✓ MockLoopNode preserves mock configuration") - - # Check that _create_graph_engine method exists and is overridden - assert hasattr(mock_node, "_create_graph_engine") - assert MockLoopNode._create_graph_engine != MockLoopNode.__bases__[1]._create_graph_engine - print("✓ MockLoopNode overrides _create_graph_engine method") - - -if __name__ == "__main__": - test_mock_factory_registers_iteration_node() - test_mock_iteration_node_preserves_config() - test_mock_loop_node_preserves_config() - print("\n✅ All tests passed! MockNodeFactory now supports iteration and loop nodes.") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 971b9b2bbf..8b7fbd1b30 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -10,10 +10,6 @@ from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock -from core.model_manager import ModelInstance -from core.workflow.node_runtime import DifyToolNodeRuntime -from core.workflow.nodes.agent import AgentNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent @@ -31,6 +27,11 @@ from graphon.nodes.template_transform import TemplateTransformNode from graphon.nodes.tool import ToolNode from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError +from core.model_manager import ModelInstance +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.nodes.agent import AgentNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode + if TYPE_CHECKING: from graphon.entities import GraphInitParams from graphon.runtime import GraphRuntimeState diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py deleted file mode 100644 index 15f6f51398..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ /dev/null @@ -1,670 +0,0 @@ -""" -Test cases for Mock Template Transform and Code nodes. - -This module tests the functionality of MockTemplateTransformNode and MockCodeNode -to ensure they work correctly with the TableTestRunner. -""" - -from configs import dify_config -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.nodes.code.limits import CodeNodeLimits -from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig -from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory -from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockCodeNode, MockTemplateTransformNode - -DEFAULT_CODE_LIMITS = CodeNodeLimits( - max_string_length=dify_config.CODE_MAX_STRING_LENGTH, - max_number=dify_config.CODE_MAX_NUMBER, - min_number=dify_config.CODE_MIN_NUMBER, - max_precision=dify_config.CODE_MAX_PRECISION, - max_depth=dify_config.CODE_MAX_DEPTH, - max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, - max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, - max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, -) - - -class _NoopCodeExecutor: - def execute(self, *, language: object, code: str, inputs: dict[str, object]) -> dict[str, object]: - _ = (language, code, inputs) - return {} - - def is_execution_error(self, error: Exception) -> bool: - _ = error - return False - - -class TestMockTemplateTransformNode: - """Test cases for MockTemplateTransformNode.""" - - def test_mock_template_transform_node_default_output(self): - """Test that MockTemplateTransformNode processes templates with Jinja2.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "output" in result.outputs - # The template "Hello {{ name }}" with no name variable renders as "Hello " - assert result.outputs["output"] == "Hello " - - def test_mock_template_transform_node_custom_output(self): - """Test that MockTemplateTransformNode returns custom configured output.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config with custom output - mock_config = ( - MockConfigBuilder().with_node_output("template_node_1", {"output": "Custom template output"}).build() - ) - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "output" in result.outputs - assert result.outputs["output"] == "Custom template output" - - def test_mock_template_transform_node_error_simulation(self): - """Test that MockTemplateTransformNode can simulate errors.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config with error - mock_config = MockConfigBuilder().with_node_error("template_node_1", "Simulated template error").build() - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.FAILED - assert result.error == "Simulated template error" - - def test_mock_template_transform_node_with_variables(self): - """Test that MockTemplateTransformNode processes templates with variables.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - from graphon.variables import StringVariable - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - # Add a variable to the pool - variable_pool.add(["test", "name"], StringVariable(name="name", value="World", selector=["test", "name"])) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config with a variable - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [{"variable": "name", "value_selector": ["test", "name"]}], - "template": "Hello {{ name }}!", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "output" in result.outputs - assert result.outputs["output"] == "Hello World!" - - -class TestMockCodeNode: - """Test cases for MockCodeNode.""" - - def test_mock_code_node_default_output(self): - """Test that MockCodeNode returns default output.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "result = 'test'", - "outputs": {}, # Empty outputs for default case - }, - } - - # Create mock node - mock_node = MockCodeNode( - id="code_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - code_executor=_NoopCodeExecutor(), - code_limits=DEFAULT_CODE_LIMITS, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "result" in result.outputs - assert result.outputs["result"] == "mocked code execution result" - - def test_mock_code_node_with_output_schema(self): - """Test that MockCodeNode generates outputs based on schema.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config with output schema - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "name = 'test'\ncount = 42\nitems = ['a', 'b']", - "outputs": { - "name": {"type": "string"}, - "count": {"type": "number"}, - "items": {"type": "array[string]"}, - }, - }, - } - - # Create mock node - mock_node = MockCodeNode( - id="code_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - code_executor=_NoopCodeExecutor(), - code_limits=DEFAULT_CODE_LIMITS, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "name" in result.outputs - assert result.outputs["name"] == "mocked_name" - assert "count" in result.outputs - assert result.outputs["count"] == 42 - assert "items" in result.outputs - assert result.outputs["items"] == ["item1", "item2"] - - def test_mock_code_node_custom_output(self): - """Test that MockCodeNode returns custom configured output.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config with custom output - mock_config = ( - MockConfigBuilder() - .with_node_output("code_node_1", {"result": "Custom code result", "status": "success"}) - .build() - ) - - # Create node config - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "result = 'test'", - "outputs": {}, # Empty outputs for default case - }, - } - - # Create mock node - mock_node = MockCodeNode( - id="code_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - code_executor=_NoopCodeExecutor(), - code_limits=DEFAULT_CODE_LIMITS, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "result" in result.outputs - assert result.outputs["result"] == "Custom code result" - assert "status" in result.outputs - assert result.outputs["status"] == "success" - - -class TestMockNodeFactory: - """Test cases for MockNodeFactory with new node types.""" - - def test_code_and_template_nodes_mocked_by_default(self): - """Test that CODE and TEMPLATE_TRANSFORM nodes are mocked by default (they require SSRF proxy).""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create factory - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - # Verify that CODE and TEMPLATE_TRANSFORM ARE mocked by default (they require SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Verify that other third-party service nodes ARE also mocked by default - assert factory.should_mock_node(BuiltinNodeTypes.LLM) - assert factory.should_mock_node(BuiltinNodeTypes.AGENT) - - def test_factory_creates_mock_template_transform_node(self): - """Test that MockNodeFactory creates MockTemplateTransformNode for template-transform type.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create factory - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create node through factory - node = factory.create_node(node_config) - - # Verify the correct mock type was created - assert isinstance(node, MockTemplateTransformNode) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - def test_factory_creates_mock_code_node(self): - """Test that MockNodeFactory creates MockCodeNode for code type.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create factory - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - # Create node config - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "result = 42", - "outputs": {}, # Required field for CodeNodeData - }, - } - - # Create node through factory - node = factory.create_node(node_config) - - # Verify the correct mock type was created - assert isinstance(node, MockCodeNode) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py deleted file mode 100644 index cb5200f8dc..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +++ /dev/null @@ -1,231 +0,0 @@ -""" -Simple test to validate the auto-mock system without external dependencies. -""" - -import sys - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY -from graphon.enums import BuiltinNodeTypes -from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig -from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory - - -def test_mock_config_builder(): - """Test the MockConfigBuilder fluent interface.""" - print("Testing MockConfigBuilder...") - - config = ( - MockConfigBuilder() - .with_llm_response("LLM response") - .with_agent_response("Agent response") - .with_tool_response({"tool": "output"}) - .with_retrieval_response("Retrieval content") - .with_http_response({"status_code": 201, "body": "created"}) - .with_node_output("node1", {"output": "value"}) - .with_node_error("node2", "error message") - .with_delays(True) - .build() - ) - - assert config.default_llm_response == "LLM response" - assert config.default_agent_response == "Agent response" - assert config.default_tool_response == {"tool": "output"} - assert config.default_retrieval_response == "Retrieval content" - assert config.default_http_response == {"status_code": 201, "body": "created"} - assert config.simulate_delays is True - - node1_config = config.get_node_config("node1") - assert node1_config is not None - assert node1_config.outputs == {"output": "value"} - - node2_config = config.get_node_config("node2") - assert node2_config is not None - assert node2_config.error == "error message" - - print("✓ MockConfigBuilder test passed") - - -def test_mock_config_operations(): - """Test MockConfig operations.""" - print("Testing MockConfig operations...") - - config = MockConfig() - - # Test setting node outputs - config.set_node_outputs("test_node", {"result": "test_value"}) - node_config = config.get_node_config("test_node") - assert node_config is not None - assert node_config.outputs == {"result": "test_value"} - - # Test setting node error - config.set_node_error("error_node", "Test error") - error_config = config.get_node_config("error_node") - assert error_config is not None - assert error_config.error == "Test error" - - # Test default configs by node type - config.set_default_config(BuiltinNodeTypes.LLM, {"temperature": 0.7}) - llm_config = config.get_default_config(BuiltinNodeTypes.LLM) - assert llm_config == {"temperature": 0.7} - - print("✓ MockConfig operations test passed") - - -def test_node_mock_config(): - """Test NodeMockConfig.""" - print("Testing NodeMockConfig...") - - # Test with custom handler - def custom_handler(node): - return {"custom": "output"} - - node_config = NodeMockConfig( - node_id="test_node", outputs={"text": "test"}, error=None, delay=0.5, custom_handler=custom_handler - ) - - assert node_config.node_id == "test_node" - assert node_config.outputs == {"text": "test"} - assert node_config.delay == 0.5 - assert node_config.custom_handler is not None - - # Test custom handler - result = node_config.custom_handler(None) - assert result == {"custom": "output"} - - print("✓ NodeMockConfig test passed") - - -def test_mock_factory_detection(): - """Test MockNodeFactory node type detection.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - print("Testing MockNodeFactory detection...") - - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # Test that third-party service nodes are identified for mocking - assert factory.should_mock_node(BuiltinNodeTypes.LLM) - assert factory.should_mock_node(BuiltinNodeTypes.AGENT) - assert factory.should_mock_node(BuiltinNodeTypes.TOOL) - assert factory.should_mock_node(BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL) - assert factory.should_mock_node(BuiltinNodeTypes.HTTP_REQUEST) - assert factory.should_mock_node(BuiltinNodeTypes.PARAMETER_EXTRACTOR) - assert factory.should_mock_node(BuiltinNodeTypes.DOCUMENT_EXTRACTOR) - - # Test that CODE and TEMPLATE_TRANSFORM are mocked (they require SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Test that non-service nodes are not mocked - assert not factory.should_mock_node(BuiltinNodeTypes.START) - assert not factory.should_mock_node(BuiltinNodeTypes.END) - assert not factory.should_mock_node(BuiltinNodeTypes.IF_ELSE) - assert not factory.should_mock_node(BuiltinNodeTypes.VARIABLE_AGGREGATOR) - - print("✓ MockNodeFactory detection test passed") - - -def test_mock_factory_registration(): - """Test registering and unregistering mock node types.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - print("Testing MockNodeFactory registration...") - - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # TEMPLATE_TRANSFORM is mocked by default (requires SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Unregister mock - factory.unregister_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - assert not factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Register custom mock (using a dummy class for testing) - class DummyMockNode: - pass - - factory.register_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM, DummyMockNode) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - print("✓ MockNodeFactory registration test passed") - - -def run_all_tests(): - """Run all tests.""" - print("\n=== Running Auto-Mock System Tests ===\n") - - try: - test_mock_config_builder() - test_mock_config_operations() - test_node_mock_config() - test_mock_factory_detection() - test_mock_factory_registration() - - print("\n=== All tests passed! ✅ ===\n") - return True - except AssertionError as e: - print(f"\n❌ Test failed: {e}") - return False - except Exception as e: - print(f"\n❌ Unexpected error: {e}") - import traceback - - traceback.print_exc() - return False - - -if __name__ == "__main__": - success = run_all_tests() - sys.exit(0 if success else 1) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index 37b43bd374..8311a1e847 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,18 +4,10 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol -from core.repositories.human_input_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.entities import WorkflowStartReason from graphon.graph import Graph -from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from graphon.graph_engine.config import GraphEngineConfig -from graphon.graph_engine.graph_engine import GraphEngine +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel from graphon.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, @@ -31,6 +23,14 @@ from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState, VariablePool + +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py deleted file mode 100644 index 59e54bd39a..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py +++ /dev/null @@ -1,336 +0,0 @@ -import time -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from core.repositories.human_input_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.graph import Graph -from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from graphon.graph_engine.config import GraphEngineConfig -from graphon.graph_engine.graph_engine import GraphEngine -from graphon.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.enums import HumanInputFormStatus -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig, NodeMockConfig -from .test_mock_nodes import MockLLMNode - - -@dataclass -class StaticForm(HumanInputFormEntity): - form_id: str - rendered: str - is_submitted: bool - action_id: str | None = None - data: Mapping[str, Any] | None = None - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() + timedelta(days=1) - - @property - def id(self) -> str: - return self.form_id - - @property - def submission_token(self) -> str | None: - return "token" - - @property - def recipients(self) -> list: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class StaticRepo(HumanInputFormRepository): - def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: - self._forms_by_node_id = dict(forms_by_node_id) - - def get_form(self, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_node_id.get(node_id) - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - raise AssertionError("create_form should not be called in resume scenario") - - -class DelayedHumanInputNode(HumanInputNode): - def __init__(self, delay_seconds: float, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._delay_seconds = delay_seconds - - def _run(self): - if self._delay_seconds > 0: - time.sleep(self._delay_seconds) - yield from super()._run() - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=build_system_variables( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - ) - - human_a_config = {"id": "human_a", "data": human_data.model_dump()} - human_a = HumanInputNode( - id=human_a_config["id"], - config=human_a_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), - ) - - human_b_config = {"id": "human_b", "data": human_data.model_dump()} - human_b = DelayedHumanInputNode( - id=human_b_config["id"], - config=human_b_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), - delay_seconds=0.2, - ) - - llm_data = LLMNodeData( - title="LLM A", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text="Prompt A", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - structured_output_enabled=False, - ) - llm_config = {"id": "llm_a", "data": llm_data.model_dump()} - llm_a = MockLLMNode( - id=llm_config["id"], - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(human_a, from_node_id="start") - .add_node(human_b, from_node_id="start") - .add_node(llm_a, from_node_id="human_a", source_handle="approve") - .build() - ) - - -def test_parallel_human_input_pause_preserves_node_finished() -> None: - runtime_state = _build_runtime_state() - - runtime_state.graph_execution.start() - runtime_state.register_paused_node("human_a") - runtime_state.register_paused_node("human_b") - - submitted = StaticForm( - form_id="form-a", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - pending = StaticForm( - form_id="form-b", - rendered="rendered", - is_submitted=False, - action_id=None, - data=None, - status_value=HumanInputFormStatus.WAITING, - ) - repo = StaticRepo({"human_a": submitted, "human_b": pending}) - - mock_config = MockConfig() - mock_config.simulate_delays = True - mock_config.set_node_config( - "llm_a", - NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - - graph = _build_graph(runtime_state, repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - events = list(engine.run()) - - llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events) - llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events) - human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events) - graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events) - graph_started = any(isinstance(e, GraphRunStartedEvent) for e in events) - - assert graph_started - assert graph_paused - assert human_b_pause - assert llm_started - assert llm_succeeded - - -def test_parallel_human_input_pause_preserves_node_finished_after_snapshot_resume() -> None: - base_state = _build_runtime_state() - base_state.graph_execution.start() - base_state.register_paused_node("human_a") - base_state.register_paused_node("human_b") - snapshot = base_state.dumps() - - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - - submitted = StaticForm( - form_id="form-a", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - pending = StaticForm( - form_id="form-b", - rendered="rendered", - is_submitted=False, - action_id=None, - data=None, - status_value=HumanInputFormStatus.WAITING, - ) - repo = StaticRepo({"human_a": submitted, "human_b": pending}) - - mock_config = MockConfig() - mock_config.simulate_delays = True - mock_config.set_node_config( - "llm_a", - NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - - graph = _build_graph(resumed_state, repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=resumed_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - events = list(engine.run()) - - start_event = next(e for e in events if isinstance(e, GraphRunStartedEvent)) - assert start_event.reason is WorkflowStartReason.RESUMPTION - - llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events) - llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events) - human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events) - graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events) - - assert graph_paused - assert human_b_pause - assert llm_started - assert llm_succeeded diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py deleted file mode 100644 index 1a43734462..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ /dev/null @@ -1,286 +0,0 @@ -""" -Test for parallel streaming workflow behavior. - -This test validates that: -- LLM 1 always speaks English -- LLM 2 always speaks Chinese -- 2 LLMs run parallel, but LLM 2 will output before LLM 1 -- All chunks should be sent before Answer Node started -""" - -import time -from unittest.mock import MagicMock, patch -from uuid import uuid4 - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.model_manager import ModelInstance -from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from core.workflow.system_variables import build_system_variables -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_events import ( - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from graphon.node_events import NodeRunResult, StreamCompletedEvent -from graphon.nodes.llm.node import LLMNode -from graphon.runtime import GraphRuntimeState, VariablePool -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_table_runner import TableTestRunner - - -def create_llm_generator_with_delay(chunks: list[str], delay: float = 0.1): - """Create a generator that simulates LLM streaming output with delay""" - - def llm_generator(self): - for i, chunk in enumerate(chunks): - time.sleep(delay) # Simulate network delay - yield NodeRunStreamChunkEvent( - id=str(uuid4()), - node_id=self.id, - node_type=self.node_type, - selector=[self.id, "text"], - chunk=chunk, - is_final=i == len(chunks) - 1, - ) - - # Complete response - full_text = "".join(chunks) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": full_text}, - ) - ) - - return llm_generator - - -def test_parallel_streaming_workflow(): - """ - Test parallel streaming workflow to verify: - 1. All chunks from LLM 2 are output before LLM 1 - 2. At least one chunk from LLM 2 is output before LLM 1 completes (Success) - 3. At least one chunk from LLM 1 is output before LLM 2 completes (EXPECTED TO FAIL) - 4. All chunks are output before End begins - 5. The final output content matches the order defined in the Answer - - Test setup: - - LLM 1 outputs English (slower) - - LLM 2 outputs Chinese (faster) - - Both run in parallel - - This test is expected to FAIL because chunks are currently buffered - until after node completion instead of streaming during execution. - """ - runner = TableTestRunner() - - # Load the workflow configuration - fixture_data = runner.workflow_runner.load_fixture("multilingual_parallel_llm_streaming_workflow") - workflow_config = fixture_data.get("workflow", {}) - graph_config = workflow_config.get("graph", {}) - - # Create graph initialization parameters - init_params = build_test_graph_init_params( - workflow_id="test_workflow", - graph_config=graph_config, - tenant_id="test_tenant", - app_id="test_app", - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - call_depth=0, - ) - - # Create variable pool with system variables - system_variables = build_system_variables( - user_id="test_user", - app_id="test_app", - workflow_id=init_params.workflow_id, - files=[], - query="Tell me about yourself", # User query - ) - variable_pool = VariablePool( - system_variables=system_variables, - user_inputs={}, - ) - - # Create graph runtime state - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - # Create node factory and graph - node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state) - with patch.object( - DifyNodeFactory, "_build_model_instance_for_llm_node", return_value=MagicMock(spec=ModelInstance), autospec=True - ): - graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=get_default_root_node_id(graph_config), - ) - - # Create the graph engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Define LLM outputs - llm1_chunks = ["Hello", ", ", "I", " ", "am", " ", "an", " ", "AI", " ", "assistant", "."] # English (slower) - llm2_chunks = ["你好", ",", "我", "是", "AI", "助手", "。"] # Chinese (faster) - - # Create generators with different delays (LLM 2 is faster) - llm1_generator = create_llm_generator_with_delay(llm1_chunks, delay=0.05) # Slower - llm2_generator = create_llm_generator_with_delay(llm2_chunks, delay=0.01) # Faster - - # Track which LLM node is being called - llm_call_order = [] - generators = { - "1754339718571": llm1_generator, # LLM 1 node ID - "1754339725656": llm2_generator, # LLM 2 node ID - } - - def mock_llm_run(self): - llm_call_order.append(self.id) - generator = generators.get(self.id) - if generator: - yield from generator(self) - else: - raise Exception(f"Unexpected LLM node ID: {self.id}") - - # Execute with mocked LLMs - with patch.object(LLMNode, "_run", new=mock_llm_run): - events = list(engine.run()) - - # Check for successful completion - success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] - assert len(success_events) > 0, "Workflow should complete successfully" - - # Get all streaming chunk events - stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] - - # Get Answer node start event - answer_start_events = [ - e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.ANSWER - ] - assert len(answer_start_events) == 1, f"Expected 1 Answer node start event, got {len(answer_start_events)}" - answer_start_event = answer_start_events[0] - - # Find the index of Answer node start - answer_start_index = events.index(answer_start_event) - - # Collect chunk events by node - llm1_chunks_events = [e for e in stream_chunk_events if e.node_id == "1754339718571"] - llm2_chunks_events = [e for e in stream_chunk_events if e.node_id == "1754339725656"] - - # Verify both LLMs produced chunks - assert len(llm1_chunks_events) == len(llm1_chunks), ( - f"Expected {len(llm1_chunks)} chunks from LLM 1, got {len(llm1_chunks_events)}" - ) - assert len(llm2_chunks_events) == len(llm2_chunks), ( - f"Expected {len(llm2_chunks)} chunks from LLM 2, got {len(llm2_chunks_events)}" - ) - - # 1. Verify chunk ordering based on actual implementation - llm1_chunk_indices = [events.index(e) for e in llm1_chunks_events] - llm2_chunk_indices = [events.index(e) for e in llm2_chunks_events] - - # In the current implementation, chunks may be interleaved or in a specific order - # Update this based on actual behavior observed - if llm1_chunk_indices and llm2_chunk_indices: - # Check the actual ordering - if LLM 2 chunks come first (as seen in debug) - assert max(llm2_chunk_indices) < min(llm1_chunk_indices), ( - f"All LLM 2 chunks should be output before LLM 1 chunks. " - f"LLM 2 chunk indices: {llm2_chunk_indices}, LLM 1 chunk indices: {llm1_chunk_indices}" - ) - - # Get indices of all chunk events - chunk_indices = [events.index(e) for e in stream_chunk_events if e in llm1_chunks_events + llm2_chunks_events] - - # 4. Verify all chunks were sent before Answer node started - assert all(idx < answer_start_index for idx in chunk_indices), ( - "All LLM chunks should be sent before Answer node starts" - ) - - # The test has successfully verified: - # 1. Both LLMs run in parallel (they start at the same time) - # 2. LLM 2 (Chinese) outputs all its chunks before LLM 1 (English) due to faster processing - # 3. All LLM chunks are sent before the Answer node starts - - # Get LLM completion events - llm_completed_events = [ - (i, e) - for i, e in enumerate(events) - if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM - ] - - # Check LLM completion order - in the current implementation, LLMs run sequentially - # LLM 1 completes first, then LLM 2 runs and completes - assert len(llm_completed_events) == 2, f"Expected 2 LLM completion events, got {len(llm_completed_events)}" - llm2_complete_idx = next((i for i, e in llm_completed_events if e.node_id == "1754339725656"), None) - llm1_complete_idx = next((i for i, e in llm_completed_events if e.node_id == "1754339718571"), None) - assert llm2_complete_idx is not None, "LLM 2 completion event not found" - assert llm1_complete_idx is not None, "LLM 1 completion event not found" - # In the actual implementation, LLM 1 completes before LLM 2 (sequential execution) - assert llm1_complete_idx < llm2_complete_idx, ( - f"LLM 1 should complete before LLM 2 in sequential execution, but LLM 1 completed at {llm1_complete_idx} " - f"and LLM 2 completed at {llm2_complete_idx}" - ) - - # 2. In sequential execution, LLM 2 chunks appear AFTER LLM 1 completes - if llm2_chunk_indices: - # LLM 1 completes first, then LLM 2 starts streaming - assert min(llm2_chunk_indices) > llm1_complete_idx, ( - f"LLM 2 chunks should appear after LLM 1 completes in sequential execution. " - f"First LLM 2 chunk at index {min(llm2_chunk_indices)}, LLM 1 completed at index {llm1_complete_idx}" - ) - - # 3. In the current implementation, LLM 1 chunks appear after LLM 2 completes - # This is because chunks are buffered and output after both nodes complete - if llm1_chunk_indices and llm2_complete_idx: - # Check if LLM 1 chunks exist and where they appear relative to LLM 2 completion - # In current behavior, LLM 1 chunks typically appear after LLM 2 completes - pass # Skipping this check as the chunk ordering is implementation-dependent - - # CURRENT BEHAVIOR: Chunks are buffered and appear after node completion - # In the sequential execution, LLM 1 completes first without streaming, - # then LLM 2 streams its chunks - assert stream_chunk_events, "Expected streaming events, but got none" - - first_chunk_index = events.index(stream_chunk_events[0]) - llm_success_indices = [i for i, e in llm_completed_events] - - # Current implementation: LLM 1 completes first, then chunks start appearing - # This is the actual behavior we're testing - if llm_success_indices: - # At least one LLM (LLM 1) completes before any chunks appear - assert min(llm_success_indices) < first_chunk_index, ( - f"In current implementation, LLM 1 completes before chunks start streaming. " - f"First chunk at index {first_chunk_index}, LLM 1 completed at index {min(llm_success_indices)}" - ) - - # 5. Verify final output content matches the order defined in Answer node - # According to Answer node configuration: '{{#1754339725656.text#}}{{#1754339718571.text#}}' - # This means LLM 2 output should come first, then LLM 1 output - answer_complete_events = [ - e for e in events if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.ANSWER - ] - assert len(answer_complete_events) == 1, f"Expected 1 Answer completion event, got {len(answer_complete_events)}" - - answer_outputs = answer_complete_events[0].node_run_result.outputs - expected_answer_text = "你好,我是AI助手。Hello, I am an AI assistant." - - if "answer" in answer_outputs: - actual_answer_text = answer_outputs["answer"] - assert actual_answer_text == expected_answer_text, ( - f"Answer content should match the order defined in Answer node. " - f"Expected: '{expected_answer_text}', Got: '{actual_answer_text}'" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py deleted file mode 100644 index bcf123ee80..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ /dev/null @@ -1,311 +0,0 @@ -import time -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from core.repositories.human_input_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.graph import Graph -from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from graphon.graph_engine.config import GraphEngineConfig -from graphon.graph_engine.graph_engine import GraphEngine -from graphon.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.enums import HumanInputFormStatus -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig, NodeMockConfig -from .test_mock_nodes import MockLLMNode - - -@dataclass -class StaticForm(HumanInputFormEntity): - form_id: str - rendered: str - is_submitted: bool - action_id: str | None = None - data: Mapping[str, Any] | None = None - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() + timedelta(days=1) - - @property - def id(self) -> str: - return self.form_id - - @property - def submission_token(self) -> str | None: - return "token" - - @property - def recipients(self) -> list: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class StaticRepo(HumanInputFormRepository): - def __init__(self, form: HumanInputFormEntity) -> None: - self._form = form - - def get_form(self, node_id: str) -> HumanInputFormEntity | None: - if node_id != "human_pause": - return None - return self._form - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - raise AssertionError("create_form should not be called in this test") - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=build_system_variables( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - llm_a_data = LLMNodeData( - title="LLM A", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text="Prompt A", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - structured_output_enabled=False, - ) - llm_a_config = {"id": "llm_a", "data": llm_a_data.model_dump()} - llm_a = MockLLMNode( - id=llm_a_config["id"], - config=llm_a_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - llm_b_data = LLMNodeData( - title="LLM B", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text="Prompt B", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - structured_output_enabled=False, - ) - llm_b_config = {"id": "llm_b", "data": llm_b_data.model_dump()} - llm_b = MockLLMNode( - id=llm_b_config["id"], - config=llm_b_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Pause here", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - ) - human_config = {"id": "human_pause", "data": human_data.model_dump()} - human_node = HumanInputNode( - id=human_config["id"], - config=human_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), - ) - - end_human_data = EndNodeData(title="End Human", outputs=[], desc=None) - end_human_config = {"id": "end_human", "data": end_human_data.model_dump()} - end_human = EndNode( - id=end_human_config["id"], - config=end_human_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(llm_a, from_node_id="start") - .add_node(human_node, from_node_id="start") - .add_node(llm_b, from_node_id="llm_a") - .add_node(end_human, from_node_id="human_pause", source_handle="approve") - .build() - ) - - -def _get_node_started_event(events: list[object], node_id: str) -> NodeRunStartedEvent | None: - for event in events: - if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id: - return event - return None - - -def test_pause_defers_ready_nodes_until_resume() -> None: - runtime_state = _build_runtime_state() - - paused_form = StaticForm( - form_id="form-pause", - rendered="rendered", - is_submitted=False, - status_value=HumanInputFormStatus.WAITING, - ) - pause_repo = StaticRepo(paused_form) - - mock_config = MockConfig() - mock_config.simulate_delays = True - mock_config.set_node_config( - "llm_a", - NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - mock_config.set_node_config( - "llm_b", - NodeMockConfig(node_id="llm_b", outputs={"text": "LLM B output"}, delay=0.0), - ) - - graph = _build_graph(runtime_state, pause_repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - paused_events = list(engine.run()) - - assert any(isinstance(e, GraphRunPausedEvent) for e in paused_events) - assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in paused_events) - assert _get_node_started_event(paused_events, "llm_b") is None - - snapshot = runtime_state.dumps() - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - - submitted_form = StaticForm( - form_id="form-pause", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - resume_repo = StaticRepo(submitted_form) - - resumed_graph = _build_graph(resumed_state, resume_repo, mock_config) - resumed_engine = GraphEngine( - workflow_id="workflow", - graph=resumed_graph, - graph_runtime_state=resumed_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - resumed_events = list(resumed_engine.run()) - - start_event = next(e for e in resumed_events if isinstance(e, GraphRunStartedEvent)) - assert start_event.reason is WorkflowStartReason.RESUMPTION - - llm_b_started = _get_node_started_event(resumed_events, "llm_b") - assert llm_b_started is not None - assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_b" for e in resumed_events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py deleted file mode 100644 index 79d3d5bcfe..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ /dev/null @@ -1,219 +0,0 @@ -import datetime -import time -from typing import Any -from unittest.mock import MagicMock - -from core.repositories.human_input_repository import ( - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.graph import Graph -from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from graphon.graph_engine.graph_engine import GraphEngine -from graphon.graph_events import ( - GraphEngineEvent, - GraphRunPausedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from graphon.graph_events.graph import GraphRunStartedEvent -from graphon.nodes.base.entities import OutputVariableEntity -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=build_system_variables( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="test-execution-id", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository: - repo = MagicMock(spec=HumanInputFormRepository) - form_entity = MagicMock(spec=HumanInputFormEntity) - form_entity.id = "test-form-id" - form_entity.submission_token = "test-form-token" - form_entity.recipients = [] - form_entity.rendered_content = "rendered" - form_entity.submitted = True - form_entity.selected_action_id = action_id - form_entity.submitted_data = {} - form_entity.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - repo.get_form.return_value = form_entity - return repo - - -def _mock_form_repository_without_submission() -> HumanInputFormRepository: - repo = MagicMock(spec=HumanInputFormRepository) - form_entity = MagicMock(spec=HumanInputFormEntity) - form_entity.id = "test-form-id" - form_entity.submission_token = "test-form-token" - form_entity.recipients = [] - form_entity.rendered_content = "rendered" - form_entity.submitted = False - repo.create_form.return_value = form_entity - repo.get_form.return_value = None - return repo - - -def _build_human_input_graph( - runtime_state: GraphRuntimeState, - form_repository: HumanInputFormRepository, -) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="service-api", - call_depth=0, - ) - - start_data = StartNodeData(title="start", variables=[]) - start_node = StartNode( - id="start", - config={"id": "start", "data": start_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - ) - - human_data = HumanInputNodeData( - title="human", - form_content="Awaiting human input", - inputs=[], - user_actions=[ - UserAction(id="continue", title="Continue"), - ], - ) - human_node = HumanInputNode( - id="human", - config={"id": "human", "data": human_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - form_repository=form_repository, - runtime=DifyHumanInputNodeRuntime(params.run_context), - ) - - end_data = EndNodeData( - title="end", - outputs=[ - OutputVariableEntity(variable="result", value_selector=["human", "action_id"]), - ], - desc=None, - ) - end_node = EndNode( - id="end", - config={"id": "end", "data": end_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(human_node) - .add_node(end_node, from_node_id="human", source_handle="continue") - .build() - ) - - -def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]: - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - ) - return list(engine.run()) - - -def _node_successes(events: list[GraphEngineEvent]) -> list[str]: - return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)] - - -def _node_start_event(events: list[GraphEngineEvent], node_id: str) -> NodeRunStartedEvent | None: - for event in events: - if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id: - return event - return None - - -def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any: - segment = variable_pool.get(selector) - assert segment is not None - return getattr(segment, "value", segment) - - -def test_engine_resume_restores_state_and_completion(): - # Baseline run without pausing - baseline_state = _build_runtime_state() - baseline_repo = _mock_form_repository_with_submission(action_id="continue") - baseline_graph = _build_human_input_graph(baseline_state, baseline_repo) - baseline_events = _run_graph(baseline_graph, baseline_state) - assert baseline_events - first_paused_event = baseline_events[0] - assert isinstance(first_paused_event, GraphRunStartedEvent) - assert first_paused_event.reason is WorkflowStartReason.INITIAL - assert isinstance(baseline_events[-1], GraphRunSucceededEvent) - baseline_success_nodes = _node_successes(baseline_events) - - # Run with pause - paused_state = _build_runtime_state() - pause_repo = _mock_form_repository_without_submission() - paused_graph = _build_human_input_graph(paused_state, pause_repo) - paused_events = _run_graph(paused_graph, paused_state) - assert paused_events - first_paused_event = paused_events[0] - assert isinstance(first_paused_event, GraphRunStartedEvent) - assert first_paused_event.reason is WorkflowStartReason.INITIAL - assert isinstance(paused_events[-1], GraphRunPausedEvent) - snapshot = paused_state.dumps() - - # Resume from snapshot - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resume_repo = _mock_form_repository_with_submission(action_id="continue") - resumed_graph = _build_human_input_graph(resumed_state, resume_repo) - resumed_events = _run_graph(resumed_graph, resumed_state) - assert resumed_events - first_resumed_event = resumed_events[0] - assert isinstance(first_resumed_event, GraphRunStartedEvent) - assert first_resumed_event.reason is WorkflowStartReason.RESUMPTION - assert isinstance(resumed_events[-1], GraphRunSucceededEvent) - - combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events) - assert combined_success_nodes == baseline_success_nodes - - paused_human_started = _node_start_event(paused_events, "human") - resumed_human_started = _node_start_event(resumed_events, "human") - assert paused_human_started is not None - assert resumed_human_started is not None - assert paused_human_started.id == resumed_human_started.id - - assert baseline_state.outputs == resumed_state.outputs - assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value( - resumed_state.variable_pool, ("human", "__action_id") - ) - assert baseline_state.graph_execution.completed - assert resumed_state.graph_execution.completed diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py deleted file mode 100644 index 146b728dc2..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py +++ /dev/null @@ -1,268 +0,0 @@ -""" -Unit tests for Redis-based stop functionality in GraphEngine. - -Tests the integration of Redis command channel for stopping workflows -without user permission checks. -""" - -import json -from unittest.mock import MagicMock, Mock, patch - -import pytest -import redis - -from core.app.apps.base_app_queue_manager import AppQueueManager -from graphon.graph_engine.command_channels.redis_channel import RedisChannel -from graphon.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand -from graphon.graph_engine.manager import GraphEngineManager - - -class TestRedisStopIntegration: - """Test suite for Redis-based workflow stop functionality.""" - - def test_graph_engine_manager_sends_abort_command(self): - """Test that GraphEngineManager correctly sends abort command through Redis.""" - # Setup - task_id = "test-task-123" - expected_channel_key = f"workflow:{task_id}:commands" - - # Mock redis client - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) - - manager = GraphEngineManager(mock_redis) - - # Execute - manager.send_stop_command(task_id, reason="Test stop") - - # Verify - mock_redis.pipeline.assert_called_once() - - # Check that rpush was called with correct arguments - calls = mock_pipeline.rpush.call_args_list - assert len(calls) == 1 - - # Verify the channel key - assert calls[0][0][0] == expected_channel_key - - # Verify the command data - command_json = calls[0][0][1] - command_data = json.loads(command_json) - assert command_data["command_type"] == CommandType.ABORT - assert command_data["reason"] == "Test stop" - - def test_graph_engine_manager_sends_pause_command(self): - """Test that GraphEngineManager correctly sends pause command through Redis.""" - task_id = "test-task-pause-123" - expected_channel_key = f"workflow:{task_id}:commands" - - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) - - manager = GraphEngineManager(mock_redis) - manager.send_pause_command(task_id, reason="Awaiting resources") - - mock_redis.pipeline.assert_called_once() - calls = mock_pipeline.rpush.call_args_list - assert len(calls) == 1 - assert calls[0][0][0] == expected_channel_key - - command_json = calls[0][0][1] - command_data = json.loads(command_json) - assert command_data["command_type"] == CommandType.PAUSE.value - assert command_data["reason"] == "Awaiting resources" - - def test_graph_engine_manager_handles_redis_failure_gracefully(self): - """Test that GraphEngineManager handles Redis failures without raising exceptions.""" - task_id = "test-task-456" - - # Mock redis client to raise exception - mock_redis = MagicMock() - mock_redis.pipeline.side_effect = redis.ConnectionError("Redis connection failed") - manager = GraphEngineManager(mock_redis) - - # Should not raise exception - try: - manager.send_stop_command(task_id) - except Exception as e: - pytest.fail(f"GraphEngineManager.send_stop_command raised {e} unexpectedly") - - def test_app_queue_manager_no_user_check(self): - """Test that AppQueueManager.set_stop_flag_no_user_check works without user validation.""" - task_id = "test-task-789" - expected_cache_key = f"generate_task_stopped:{task_id}" - - # Mock redis client - mock_redis = MagicMock() - - with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis): - # Execute - AppQueueManager.set_stop_flag_no_user_check(task_id) - - # Verify - mock_redis.setex.assert_called_once_with(expected_cache_key, 600, 1) - - def test_app_queue_manager_no_user_check_with_empty_task_id(self): - """Test that AppQueueManager.set_stop_flag_no_user_check handles empty task_id.""" - # Mock redis client - mock_redis = MagicMock() - - with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis): - # Execute with empty task_id - AppQueueManager.set_stop_flag_no_user_check("") - - # Verify redis was not called - mock_redis.setex.assert_not_called() - - def test_redis_channel_send_abort_command(self): - """Test RedisChannel correctly serializes and sends AbortCommand.""" - # Setup - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) - - channel_key = "workflow:test:commands" - channel = RedisChannel(mock_redis, channel_key) - - # Create commands - abort_command = AbortCommand(reason="User requested stop") - pause_command = PauseCommand(reason="User requested pause") - - # Execute - channel.send_command(abort_command) - channel.send_command(pause_command) - - # Verify - mock_redis.pipeline.assert_called() - - # Check rpush was called - calls = mock_pipeline.rpush.call_args_list - assert len(calls) == 2 - assert calls[0][0][0] == channel_key - assert calls[1][0][0] == channel_key - - # Verify serialized commands - abort_command_json = calls[0][0][1] - abort_command_data = json.loads(abort_command_json) - assert abort_command_data["command_type"] == CommandType.ABORT.value - assert abort_command_data["reason"] == "User requested stop" - - pause_command_json = calls[1][0][1] - pause_command_data = json.loads(pause_command_json) - assert pause_command_data["command_type"] == CommandType.PAUSE.value - assert pause_command_data["reason"] == "User requested pause" - - # Check expire was set for each - assert mock_pipeline.expire.call_count == 2 - mock_pipeline.expire.assert_any_call(channel_key, 3600) - - def test_redis_channel_fetch_commands(self): - """Test RedisChannel correctly fetches and deserializes commands.""" - # Setup - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Mock command data - abort_command_json = json.dumps( - {"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None} - ) - pause_command_json = json.dumps( - {"command_type": CommandType.PAUSE.value, "reason": "Pause requested", "payload": None} - ) - - # Mock pipeline execute to return commands - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [ - [abort_command_json.encode(), pause_command_json.encode()], # lrange result - True, # delete result - ] - - channel_key = "workflow:test:commands" - channel = RedisChannel(mock_redis, channel_key) - - # Execute - commands = channel.fetch_commands() - - # Verify - assert len(commands) == 2 - assert isinstance(commands[0], AbortCommand) - assert commands[0].command_type == CommandType.ABORT - assert commands[0].reason == "Test abort" - assert isinstance(commands[1], PauseCommand) - assert commands[1].command_type == CommandType.PAUSE - assert commands[1].reason == "Pause requested" - - # Verify Redis operations - pending_pipe.get.assert_called_once_with(f"{channel_key}:pending") - pending_pipe.delete.assert_called_once_with(f"{channel_key}:pending") - fetch_pipe.lrange.assert_called_once_with(channel_key, 0, -1) - fetch_pipe.delete.assert_called_once_with(channel_key) - assert mock_redis.pipeline.call_count == 2 - - def test_redis_channel_fetch_commands_handles_invalid_json(self): - """Test RedisChannel gracefully handles invalid JSON in commands.""" - # Setup - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Mock invalid command data - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [ - [b"invalid json", b'{"command_type": "invalid_type"}'], # lrange result - True, # delete result - ] - - channel_key = "workflow:test:commands" - channel = RedisChannel(mock_redis, channel_key) - - # Execute - commands = channel.fetch_commands() - - # Should return empty list due to invalid commands - assert len(commands) == 0 - - def test_dual_stop_mechanism_compatibility(self): - """Test that both stop mechanisms can work together.""" - task_id = "test-task-dual" - - # Mock redis client - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) - - with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis): - # Execute both stop mechanisms - AppQueueManager.set_stop_flag_no_user_check(task_id) - GraphEngineManager(mock_redis).send_stop_command(task_id) - - # Verify legacy stop flag was set - expected_stop_flag_key = f"generate_task_stopped:{task_id}" - mock_redis.setex.assert_called_once_with(expected_stop_flag_key, 600, 1) - - # Verify command was sent through Redis channel - mock_redis.pipeline.assert_called() - calls = mock_pipeline.rpush.call_args_list - assert len(calls) == 1 - assert calls[0][0][0] == f"workflow:{task_id}:commands" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py b/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py deleted file mode 100644 index 62ca7a630e..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Unit tests for response session creation.""" - -from __future__ import annotations - -import pytest - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, NodeType -from graphon.graph_engine.response_coordinator.session import ResponseSession -from graphon.nodes.base.template import Template, TextSegment - - -class DummyResponseNode: - """Minimal response-capable node for session tests.""" - - def __init__(self, *, node_id: str, node_type: NodeType, template: Template) -> None: - self.id = node_id - self.node_type = node_type - self.execution_type = NodeExecutionType.RESPONSE - self.state = NodeState.UNKNOWN - self._template = template - - def get_streaming_template(self) -> Template: - return self._template - - -class DummyNodeWithoutStreamingTemplate: - """Minimal node that violates the response-session contract.""" - - def __init__(self, *, node_id: str, node_type: NodeType) -> None: - self.id = node_id - self.node_type = node_type - self.execution_type = NodeExecutionType.RESPONSE - self.state = NodeState.UNKNOWN - - -def test_response_session_from_node_accepts_nodes_outside_previous_allowlist() -> None: - """Session creation depends on the streaming-template contract rather than node type.""" - node = DummyResponseNode( - node_id="llm-node", - node_type=BuiltinNodeTypes.LLM, - template=Template(segments=[TextSegment(text="hello")]), - ) - - session = ResponseSession.from_node(node) - - assert session.node_id == "llm-node" - assert session.template.segments == [TextSegment(text="hello")] - - -def test_response_session_from_node_requires_streaming_template_method() -> None: - """Allowed node types still need to implement the streaming-template contract.""" - node = DummyNodeWithoutStreamingTemplate(node_id="answer-node", node_type=BuiltinNodeTypes.ANSWER) - - with pytest.raises(TypeError, match="get_streaming_template"): - ResponseSession.from_node(node) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py deleted file mode 100644 index a359a5fef9..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py +++ /dev/null @@ -1,79 +0,0 @@ -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_streaming_conversation_variables(): - fixture_name = "test_streaming_conversation_variables" - - # The test expects the workflow to output the input query - # Since the workflow assigns sys.query to conversation variable "str" and then answers with it - input_query = "Hello, this is my test query" - - mock_config = MockConfigBuilder().build() - - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=False, # Don't use auto mock since we want to test actual variable assignment - mock_config=mock_config, - query=input_query, # Pass query as the sys.query value - inputs={}, # No additional inputs needed - expected_outputs={"answer": input_query}, # Expecting the input query to be output - expected_event_sequence=[ - GraphRunStartedEvent, - # START node - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Variable Assigner node - NodeRunStartedEvent, - NodeRunVariableUpdatedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - # ANSWER node - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" - - -def test_streaming_conversation_variables_v1_overwrite_waits_for_assignment(): - fixture_name = "test_streaming_conversation_variables_v1_overwrite" - input_query = "overwrite-value" - - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=False, - mock_config=MockConfigBuilder().build(), - query=input_query, - inputs={}, - expected_outputs={"answer": f"Current Value Of `conv_var` is:{input_query}"}, - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" - - events = result.events - conv_var_chunk_events = [ - event - for event in events - if isinstance(event, NodeRunStreamChunkEvent) and tuple(event.selector) == ("conversation", "conv_var") - ] - - assert conv_var_chunk_events, "Expected conversation variable chunk events to be emitted" - assert all(event.chunk == input_query for event in conv_var_chunk_events), ( - "Expected streamed conversation variable value to match the input query" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index 81d68ba2aa..b11f957677 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -19,12 +19,7 @@ from functools import lru_cache from pathlib import Path from typing import Any -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.tools.utils.yaml_utils import _load_yaml_file -from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool -from graphon.entities.graph_init_params import GraphInitParams +from graphon.entities import GraphInitParams from graphon.graph import Graph from graphon.graph_engine import GraphEngine, GraphEngineConfig from graphon.graph_engine.command_channels import InMemoryChannel @@ -44,6 +39,12 @@ from graphon.variables import ( StringVariable, ) +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.tools.utils.yaml_utils import _load_yaml_file +from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool + from .test_mock_config import MockConfig from .test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_update_conversation_variable_iteration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_update_conversation_variable_iteration.py deleted file mode 100644 index a7309f64de..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_update_conversation_variable_iteration.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Validate conversation variable updates inside an iteration workflow. - -This test uses the ``update-conversation-variable-in-iteration`` fixture, which -routes ``sys.query`` into the conversation variable ``answer`` from within an -iteration container. The workflow should surface that updated conversation -variable in the final answer output. - -Code nodes in the fixture are mocked because their concrete outputs are not -relevant to verifying variable propagation semantics. -""" - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_update_conversation_variable_in_iteration(): - fixture_name = "update-conversation-variable-in-iteration" - user_query = "ensure conversation variable syncs" - - mock_config = ( - MockConfigBuilder() - .with_node_output("1759032363865", {"result": [1]}) - .with_node_output("1759032476318", {"result": ""}) - .build() - ) - - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - query=user_query, - expected_outputs={"answer": user_query}, - description="Conversation variable updated within iteration should flow to answer output.", - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - - assert result.success, f"Workflow execution failed: {result.error}" - assert result.actual_outputs is not None - assert result.actual_outputs.get("answer") == user_query diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py deleted file mode 100644 index 2ad41037a9..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py +++ /dev/null @@ -1,58 +0,0 @@ -from unittest.mock import patch - -import pytest - -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode - -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -class TestVariableAggregator: - """Test cases for the variable aggregator workflow.""" - - @pytest.mark.parametrize( - ("switch1", "switch2", "expected_group1", "expected_group2", "description"), - [ - (0, 0, "switch 1 off", "switch 2 off", "Both switches off"), - (0, 1, "switch 1 off", "switch 2 on", "Switch1 off, Switch2 on"), - (1, 0, "switch 1 on", "switch 2 off", "Switch1 on, Switch2 off"), - (1, 1, "switch 1 on", "switch 2 on", "Both switches on"), - ], - ) - def test_variable_aggregator_combinations( - self, - switch1: int, - switch2: int, - expected_group1: str, - expected_group2: str, - description: str, - ) -> None: - """Test all four combinations of switch1 and switch2.""" - - def mock_template_transform_run(self): - """Mock the TemplateTransformNode._run() method to return results based on node title.""" - title = self._node_data.title - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"output": title}) - - with patch.object( - TemplateTransformNode, - "_run", - mock_template_transform_run, - ): - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="dual_switch_variable_aggregator_workflow", - inputs={"switch1": switch1, "switch2": switch2}, - expected_outputs={"group1": expected_group1, "group2": expected_group2}, - description=description, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs == test_case.expected_outputs, ( - f"Output mismatch: expected {test_case.expected_outputs}, got {result.actual_outputs}" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py deleted file mode 100644 index 60cab77c0a..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py +++ /dev/null @@ -1,129 +0,0 @@ -import time -import uuid -from uuid import uuid4 - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_variables_to_pool -from graphon.entities import GraphInitParams -from graphon.graph import Graph -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events import NodeRunVariableUpdatedEvent -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import StringVariable - -DEFAULT_NODE_ID = "node_id" - - -class CaptureVariableUpdateLayer(GraphEngineLayer): - def __init__(self) -> None: - super().__init__() - self.events: list[NodeRunVariableUpdatedEvent] = [] - self.observed_values: list[object | None] = [] - - def on_graph_start(self) -> None: - pass - - def on_event(self, event) -> None: - if not isinstance(event, NodeRunVariableUpdatedEvent): - return - - current_value = self.graph_runtime_state.variable_pool.get(event.variable.selector) - self.events.append(event) - self.observed_values.append(None if current_value is None else current_value.value) - - def on_graph_end(self, error: Exception | None) -> None: - pass - - -def test_graph_engine_applies_variable_updates_before_notifying_layers(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { - "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "over-write", - "input_variable_selector": ["node_id", "test_string_variable"], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - variable_pool = VariablePool() - add_variables_to_pool( - variable_pool, - build_bootstrap_variables( - system_variables=build_system_variables(conversation_id=str(uuid.uuid4())), - conversation_variables=[ - StringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value="the first value", - ) - ], - ), - ) - variable_pool.add( - [DEFAULT_NODE_ID, "test_string_variable"], - StringVariable( - id=str(uuid4()), - name="test_string_variable", - value="the second value", - ), - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - engine = GraphEngine( - workflow_id="workflow-id", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - capture_layer = CaptureVariableUpdateLayer() - engine.layer(capture_layer) - - events = list(engine.run()) - - update_events = [event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)] - assert len(update_events) == 1 - assert update_events[0].variable.value == "the second value" - - current_value = graph_runtime_state.variable_pool.get(["conversation", "test_conversation_variable"]) - assert current_value is not None - assert current_value.value == "the second value" - - assert len(capture_layer.events) == 1 - assert capture_layer.observed_values == ["the second value"] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py b/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py deleted file mode 100644 index 85132674b8..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py +++ /dev/null @@ -1,148 +0,0 @@ -import queue -from collections.abc import Generator -from datetime import UTC, datetime, timedelta -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.graph_engine.ready_queue import InMemoryReadyQueue -from graphon.graph_engine.worker import Worker -from graphon.graph_events import NodeRunFailedEvent, NodeRunStartedEvent - - -def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None: - fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - mock_datetime = mocker.patch("graphon.graph_engine.worker.datetime") - mock_datetime.now.return_value = fixed_time.replace(tzinfo=UTC) - - worker = Worker( - ready_queue=InMemoryReadyQueue(), - event_queue=queue.Queue(), - graph=MagicMock(), - layers=[], - ) - node = SimpleNamespace( - execution_id="exec-1", - id="node-1", - node_type=BuiltinNodeTypes.LLM, - ) - - event = worker._build_fallback_failure_event(node, RuntimeError("boom")) - - assert event.start_at == fixed_time - assert event.finished_at == fixed_time - assert event.error == "boom" - assert event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED - assert event.node_run_result.error == "boom" - assert event.node_run_result.error_type == "RuntimeError" - - -def test_worker_fallback_failure_event_reuses_observed_start_time() -> None: - start_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - failure_time = start_at + timedelta(seconds=5) - captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = [] - - class FakeNode: - execution_id = "exec-1" - id = "node-1" - node_type = BuiltinNodeTypes.LLM - - def ensure_execution_id(self) -> str: - return self.execution_id - - def run(self) -> Generator[NodeRunStartedEvent, None, None]: - yield NodeRunStartedEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - node_title="LLM", - start_at=start_at, - ) - - worker = Worker( - ready_queue=MagicMock(), - event_queue=MagicMock(), - graph=MagicMock(nodes={"node-1": FakeNode()}), - layers=[], - ) - - worker._ready_queue.get.side_effect = ["node-1"] - - def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None: - captured_events.append(event) - if len(captured_events) == 1: - raise RuntimeError("queue boom") - worker.stop() - - worker._event_queue.put.side_effect = put_side_effect - - with patch("graphon.graph_engine.worker.datetime") as mock_datetime: - mock_datetime.now.return_value = failure_time.replace(tzinfo=UTC) - worker.run() - - fallback_event = captured_events[-1] - - assert isinstance(fallback_event, NodeRunFailedEvent) - assert fallback_event.start_at == start_at - assert fallback_event.finished_at == failure_time - assert fallback_event.error == "queue boom" - assert fallback_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED - - -def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_times() -> None: - parent_start = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - child_start = parent_start + timedelta(seconds=3) - failure_time = parent_start + timedelta(seconds=5) - captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = [] - - class FakeIterationNode: - execution_id = "iteration-exec" - id = "iteration-node" - node_type = BuiltinNodeTypes.ITERATION - - def ensure_execution_id(self) -> str: - return self.execution_id - - def run(self) -> Generator[NodeRunStartedEvent, None, None]: - yield NodeRunStartedEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - node_title="Iteration", - start_at=parent_start, - ) - yield NodeRunStartedEvent( - id="child-exec", - node_id="child-node", - node_type=BuiltinNodeTypes.LLM, - node_title="LLM", - start_at=child_start, - in_iteration_id=self.id, - ) - - worker = Worker( - ready_queue=MagicMock(), - event_queue=MagicMock(), - graph=MagicMock(nodes={"iteration-node": FakeIterationNode()}), - layers=[], - ) - - worker._ready_queue.get.side_effect = ["iteration-node"] - - def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None: - captured_events.append(event) - if len(captured_events) == 2: - raise RuntimeError("queue boom") - worker.stop() - - worker._event_queue.put.side_effect = put_side_effect - - with patch("graphon.graph_engine.worker.datetime") as mock_datetime: - mock_datetime.now.return_value = failure_time.replace(tzinfo=UTC) - worker.run() - - fallback_event = captured_events[-1] - - assert isinstance(fallback_event, NodeRunFailedEvent) - assert fallback_event.start_at == parent_start - assert fallback_event.finished_at == failure_time diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py index 1f4509af9a..cbc920705c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py @@ -1,8 +1,9 @@ from unittest.mock import patch +from graphon.enums import BuiltinNodeTypes + from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer -from graphon.enums import BuiltinNodeTypes def test_transform_passes_conversation_id_to_tool_file_message_transformer() -> None: diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py index c86de7f6e6..59dd763b59 100644 --- a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py @@ -1,9 +1,10 @@ from types import SimpleNamespace from unittest.mock import Mock, patch -from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport from graphon.model_runtime.entities.model_entities import ModelType +from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport + def test_fetch_model_reuses_single_model_assembly(): provider_configuration = SimpleNamespace( diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 9c0ad25b58..7195471eb6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -2,14 +2,15 @@ import time import uuid from unittest.mock import MagicMock +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.answer.answer_node import AnswerNode +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables from extensions.ext_database import db -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.nodes.answer.answer_node import AnswerNode -from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index ec4cef1955..343bcd3919 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,10 +1,10 @@ import pytest - -from core.workflow.node_factory import get_node_type_classes_mapping from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node +from core.workflow.node_factory import get_node_type_classes_mapping + # Ensures that all production node classes are imported and registered. _ = get_node_type_classes_mapping() diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py index ef0df55995..b9371a34f4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -1,7 +1,6 @@ import types from collections.abc import Mapping -from core.workflow.node_factory import get_node_type_classes_mapping from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node @@ -14,6 +13,8 @@ from graphon.nodes.variable_assigner.v2.node import ( VariableAssignerNode as VariableAssignerV2, ) +from core.workflow.node_factory import get_node_type_classes_mapping + def test_variable_assigner_latest_prefers_highest_numeric_version(): # Act diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index ce0c9b79c6..d155124c50 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,4 +1,3 @@ -from configs import dify_config from graphon.nodes.code.code_node import CodeNode from graphon.nodes.code.entities import CodeLanguage, CodeNodeData from graphon.nodes.code.exc import ( @@ -9,6 +8,8 @@ from graphon.nodes.code.exc import ( from graphon.nodes.code.limits import CodeNodeLimits from graphon.variables.types import SegmentType +from configs import dify_config + CodeNode._limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, max_number=dify_config.CODE_MAX_NUMBER, diff --git a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py deleted file mode 100644 index 20fe2c1a74..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py +++ /dev/null @@ -1,352 +0,0 @@ -import pytest -from pydantic import ValidationError - -from graphon.nodes.code.entities import CodeLanguage, CodeNodeData -from graphon.variables.types import SegmentType - - -class TestCodeNodeDataOutput: - """Test suite for CodeNodeData.Output model.""" - - def test_output_with_string_type(self): - """Test Output with STRING type.""" - output = CodeNodeData.Output(type=SegmentType.STRING) - - assert output.type == SegmentType.STRING - assert output.children is None - - def test_output_with_number_type(self): - """Test Output with NUMBER type.""" - output = CodeNodeData.Output(type=SegmentType.NUMBER) - - assert output.type == SegmentType.NUMBER - assert output.children is None - - def test_output_with_boolean_type(self): - """Test Output with BOOLEAN type.""" - output = CodeNodeData.Output(type=SegmentType.BOOLEAN) - - assert output.type == SegmentType.BOOLEAN - - def test_output_with_object_type(self): - """Test Output with OBJECT type.""" - output = CodeNodeData.Output(type=SegmentType.OBJECT) - - assert output.type == SegmentType.OBJECT - - def test_output_with_array_string_type(self): - """Test Output with ARRAY_STRING type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_STRING) - - assert output.type == SegmentType.ARRAY_STRING - - def test_output_with_array_number_type(self): - """Test Output with ARRAY_NUMBER type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER) - - assert output.type == SegmentType.ARRAY_NUMBER - - def test_output_with_array_object_type(self): - """Test Output with ARRAY_OBJECT type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_OBJECT) - - assert output.type == SegmentType.ARRAY_OBJECT - - def test_output_with_array_boolean_type(self): - """Test Output with ARRAY_BOOLEAN type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN) - - assert output.type == SegmentType.ARRAY_BOOLEAN - - def test_output_with_nested_children(self): - """Test Output with nested children for OBJECT type.""" - child_output = CodeNodeData.Output(type=SegmentType.STRING) - parent_output = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={"name": child_output}, - ) - - assert parent_output.type == SegmentType.OBJECT - assert parent_output.children is not None - assert "name" in parent_output.children - assert parent_output.children["name"].type == SegmentType.STRING - - def test_output_with_deeply_nested_children(self): - """Test Output with deeply nested children.""" - inner_child = CodeNodeData.Output(type=SegmentType.NUMBER) - middle_child = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={"value": inner_child}, - ) - outer_output = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={"nested": middle_child}, - ) - - assert outer_output.children is not None - assert outer_output.children["nested"].children is not None - assert outer_output.children["nested"].children["value"].type == SegmentType.NUMBER - - def test_output_with_multiple_children(self): - """Test Output with multiple children.""" - output = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - "age": CodeNodeData.Output(type=SegmentType.NUMBER), - "active": CodeNodeData.Output(type=SegmentType.BOOLEAN), - }, - ) - - assert output.children is not None - assert len(output.children) == 3 - assert output.children["name"].type == SegmentType.STRING - assert output.children["age"].type == SegmentType.NUMBER - assert output.children["active"].type == SegmentType.BOOLEAN - - def test_output_rejects_invalid_type(self): - """Test Output rejects invalid segment types.""" - with pytest.raises(ValidationError): - CodeNodeData.Output(type=SegmentType.FILE) - - def test_output_rejects_array_file_type(self): - """Test Output rejects ARRAY_FILE type.""" - with pytest.raises(ValidationError): - CodeNodeData.Output(type=SegmentType.ARRAY_FILE) - - -class TestCodeNodeDataDependency: - """Test suite for CodeNodeData.Dependency model.""" - - def test_dependency_basic(self): - """Test Dependency with name and version.""" - dependency = CodeNodeData.Dependency(name="numpy", version="1.24.0") - - assert dependency.name == "numpy" - assert dependency.version == "1.24.0" - - def test_dependency_with_complex_version(self): - """Test Dependency with complex version string.""" - dependency = CodeNodeData.Dependency(name="pandas", version=">=2.0.0,<3.0.0") - - assert dependency.name == "pandas" - assert dependency.version == ">=2.0.0,<3.0.0" - - def test_dependency_with_empty_version(self): - """Test Dependency with empty version.""" - dependency = CodeNodeData.Dependency(name="requests", version="") - - assert dependency.name == "requests" - assert dependency.version == "" - - -class TestCodeNodeData: - """Test suite for CodeNodeData model.""" - - def test_code_node_data_python3(self): - """Test CodeNodeData with Python3 language.""" - data = CodeNodeData( - title="Test Code Node", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'result': 42}", - outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)}, - ) - - assert data.title == "Test Code Node" - assert data.code_language == CodeLanguage.PYTHON3 - assert data.code == "def main(): return {'result': 42}" - assert "result" in data.outputs - assert data.dependencies is None - - def test_code_node_data_javascript(self): - """Test CodeNodeData with JavaScript language.""" - data = CodeNodeData( - title="JS Code Node", - variables=[], - code_language=CodeLanguage.JAVASCRIPT, - code="function main() { return { result: 'hello' }; }", - outputs={"result": CodeNodeData.Output(type=SegmentType.STRING)}, - ) - - assert data.code_language == CodeLanguage.JAVASCRIPT - assert "result" in data.outputs - assert data.outputs["result"].type == SegmentType.STRING - - def test_code_node_data_with_dependencies(self): - """Test CodeNodeData with dependencies.""" - data = CodeNodeData( - title="Code with Deps", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="import numpy as np\ndef main(): return {'sum': 10}", - outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)}, - dependencies=[ - CodeNodeData.Dependency(name="numpy", version="1.24.0"), - CodeNodeData.Dependency(name="pandas", version="2.0.0"), - ], - ) - - assert data.dependencies is not None - assert len(data.dependencies) == 2 - assert data.dependencies[0].name == "numpy" - assert data.dependencies[1].name == "pandas" - - def test_code_node_data_with_multiple_outputs(self): - """Test CodeNodeData with multiple outputs.""" - data = CodeNodeData( - title="Multi Output", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'name': 'test', 'count': 5, 'items': ['a', 'b']}", - outputs={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - "count": CodeNodeData.Output(type=SegmentType.NUMBER), - "items": CodeNodeData.Output(type=SegmentType.ARRAY_STRING), - }, - ) - - assert len(data.outputs) == 3 - assert data.outputs["name"].type == SegmentType.STRING - assert data.outputs["count"].type == SegmentType.NUMBER - assert data.outputs["items"].type == SegmentType.ARRAY_STRING - - def test_code_node_data_with_object_output(self): - """Test CodeNodeData with nested object output.""" - data = CodeNodeData( - title="Object Output", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'user': {'name': 'John', 'age': 30}}", - outputs={ - "user": CodeNodeData.Output( - type=SegmentType.OBJECT, - children={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - "age": CodeNodeData.Output(type=SegmentType.NUMBER), - }, - ), - }, - ) - - assert data.outputs["user"].type == SegmentType.OBJECT - assert data.outputs["user"].children is not None - assert len(data.outputs["user"].children) == 2 - - def test_code_node_data_with_array_object_output(self): - """Test CodeNodeData with array of objects output.""" - data = CodeNodeData( - title="Array Object Output", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'users': [{'name': 'A'}, {'name': 'B'}]}", - outputs={ - "users": CodeNodeData.Output( - type=SegmentType.ARRAY_OBJECT, - children={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - }, - ), - }, - ) - - assert data.outputs["users"].type == SegmentType.ARRAY_OBJECT - assert data.outputs["users"].children is not None - - def test_code_node_data_empty_code(self): - """Test CodeNodeData with empty code.""" - data = CodeNodeData( - title="Empty Code", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="", - outputs={}, - ) - - assert data.code == "" - assert len(data.outputs) == 0 - - def test_code_node_data_multiline_code(self): - """Test CodeNodeData with multiline code.""" - multiline_code = """ -def main(): - result = 0 - for i in range(10): - result += i - return {'sum': result} -""" - data = CodeNodeData( - title="Multiline Code", - variables=[], - code_language=CodeLanguage.PYTHON3, - code=multiline_code, - outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)}, - ) - - assert "for i in range(10)" in data.code - assert "result += i" in data.code - - def test_code_node_data_with_special_characters_in_code(self): - """Test CodeNodeData with special characters in code.""" - code_with_special = "def main(): return {'msg': 'Hello\\nWorld\\t!'}" - data = CodeNodeData( - title="Special Chars", - variables=[], - code_language=CodeLanguage.PYTHON3, - code=code_with_special, - outputs={"msg": CodeNodeData.Output(type=SegmentType.STRING)}, - ) - - assert "\\n" in data.code - assert "\\t" in data.code - - def test_code_node_data_with_unicode_in_code(self): - """Test CodeNodeData with unicode characters in code.""" - unicode_code = "def main(): return {'greeting': '你好世界'}" - data = CodeNodeData( - title="Unicode Code", - variables=[], - code_language=CodeLanguage.PYTHON3, - code=unicode_code, - outputs={"greeting": CodeNodeData.Output(type=SegmentType.STRING)}, - ) - - assert "你好世界" in data.code - - def test_code_node_data_empty_dependencies_list(self): - """Test CodeNodeData with empty dependencies list.""" - data = CodeNodeData( - title="No Deps", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {}", - outputs={}, - dependencies=[], - ) - - assert data.dependencies is not None - assert len(data.dependencies) == 0 - - def test_code_node_data_with_boolean_array_output(self): - """Test CodeNodeData with boolean array output.""" - data = CodeNodeData( - title="Boolean Array", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'flags': [True, False, True]}", - outputs={"flags": CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN)}, - ) - - assert data.outputs["flags"].type == SegmentType.ARRAY_BOOLEAN - - def test_code_node_data_with_number_array_output(self): - """Test CodeNodeData with number array output.""" - data = CodeNodeData( - title="Number Array", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'values': [1, 2, 3, 4, 5]}", - outputs={"values": CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER)}, - ) - - assert data.outputs["values"].type == SegmentType.ARRAY_NUMBER diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index 1d76067ec2..fb03ae9998 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,7 +1,8 @@ +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent + from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent class _VarSeg: diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py deleted file mode 100644 index f1a48f49b9..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py +++ /dev/null @@ -1,33 +0,0 @@ -from graphon.nodes.http_request import build_http_request_config - - -def test_build_http_request_config_uses_literal_defaults(): - config = build_http_request_config() - - assert config.max_connect_timeout == 10 - assert config.max_read_timeout == 600 - assert config.max_write_timeout == 600 - assert config.max_binary_size == 10 * 1024 * 1024 - assert config.max_text_size == 1 * 1024 * 1024 - assert config.ssl_verify is True - assert config.ssrf_default_max_retries == 3 - - -def test_build_http_request_config_supports_explicit_overrides(): - config = build_http_request_config( - max_connect_timeout=5, - max_read_timeout=30, - max_write_timeout=40, - max_binary_size=2048, - max_text_size=1024, - ssl_verify=False, - ssrf_default_max_retries=8, - ) - - assert config.max_connect_timeout == 5 - assert config.max_read_timeout == 30 - assert config.max_write_timeout == 40 - assert config.max_binary_size == 2048 - assert config.max_text_size == 1024 - assert config.ssl_verify is False - assert config.ssrf_default_max_retries == 8 diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py deleted file mode 100644 index 88895608d9..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py +++ /dev/null @@ -1,233 +0,0 @@ -import json -from unittest.mock import Mock, PropertyMock, patch - -import httpx -import pytest - -from graphon.nodes.http_request.entities import Response - - -@pytest.fixture -def mock_response(): - response = Mock(spec=httpx.Response) - response.headers = {} - return response - - -def test_is_file_with_attachment_disposition(mock_response): - """Test is_file when content-disposition header contains 'attachment'""" - mock_response.headers = {"content-disposition": "attachment; filename=test.pdf", "content-type": "application/pdf"} - response = Response(mock_response) - assert response.is_file - - -def test_is_file_with_filename_disposition(mock_response): - """Test is_file when content-disposition header contains filename parameter""" - mock_response.headers = {"content-disposition": "inline; filename=test.pdf", "content-type": "application/pdf"} - response = Response(mock_response) - assert response.is_file - - -@pytest.mark.parametrize("content_type", ["application/pdf", "image/jpeg", "audio/mp3", "video/mp4"]) -def test_is_file_with_file_content_types(mock_response, content_type): - """Test is_file with various file content types""" - mock_response.headers = {"content-type": content_type} - # Mock binary content - type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) - response = Response(mock_response) - assert response.is_file, f"Content type {content_type} should be identified as a file" - - -@pytest.mark.parametrize( - "content_type", - [ - "application/json", - "application/xml", - "application/javascript", - "application/x-www-form-urlencoded", - "application/yaml", - "application/graphql", - ], -) -def test_text_based_application_types(mock_response, content_type): - """Test common text-based application types are not identified as files""" - mock_response.headers = {"content-type": content_type} - response = Response(mock_response) - assert not response.is_file, f"Content type {content_type} should not be identified as a file" - - -@pytest.mark.parametrize( - ("content", "content_type"), - [ - (b'{"key": "value"}', "application/octet-stream"), - (b"[1, 2, 3]", "application/unknown"), - (b"function test() {}", "application/x-unknown"), - (b"test", "application/binary"), - (b"var x = 1;", "application/data"), - ], -) -def test_content_based_detection(mock_response, content, content_type): - """Test content-based detection for text-like content""" - mock_response.headers = {"content-type": content_type} - type(mock_response).content = PropertyMock(return_value=content) - response = Response(mock_response) - assert not response.is_file, f"Content {content} with type {content_type} should not be identified as a file" - - -@pytest.mark.parametrize( - ("content", "content_type"), - [ - (bytes([0x00, 0xFF] * 512), "application/octet-stream"), - (bytes([0x89, 0x50, 0x4E, 0x47]), "application/unknown"), # PNG magic numbers - (bytes([0xFF, 0xD8, 0xFF]), "application/binary"), # JPEG magic numbers - ], -) -def test_binary_content_detection(mock_response, content, content_type): - """Test content-based detection for binary content""" - mock_response.headers = {"content-type": content_type} - type(mock_response).content = PropertyMock(return_value=content) - response = Response(mock_response) - assert response.is_file, f"Binary content with type {content_type} should be identified as a file" - - -@pytest.mark.parametrize( - ("content_type", "expected_main_type"), - [ - ("x-world/x-vrml", "model"), # VRML 3D model - ("font/ttf", "application"), # TrueType font - ("text/csv", "text"), # CSV text file - ("unknown/xyz", None), # Unknown type - ], -) -def test_mimetype_based_detection(mock_response, content_type, expected_main_type): - """Test detection using mimetypes.guess_type for non-application content types""" - mock_response.headers = {"content-type": content_type} - type(mock_response).content = PropertyMock(return_value=bytes([0x00])) # Dummy content - - with patch("graphon.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type: - # Mock the return value based on expected_main_type - if expected_main_type: - mock_guess_type.return_value = (f"{expected_main_type}/subtype", None) - else: - mock_guess_type.return_value = (None, None) - - response = Response(mock_response) - - # Check if the result matches our expectation - if expected_main_type in ("application", "image", "audio", "video"): - assert response.is_file, f"Content type {content_type} should be identified as a file" - else: - assert not response.is_file, f"Content type {content_type} should not be identified as a file" - - # Verify that guess_type was called - mock_guess_type.assert_called_once() - - -def test_is_file_with_inline_disposition(mock_response): - """Test is_file when content-disposition is 'inline'""" - mock_response.headers = {"content-disposition": "inline", "content-type": "application/pdf"} - # Mock binary content - type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) - response = Response(mock_response) - assert response.is_file - - -def test_is_file_with_no_content_disposition(mock_response): - """Test is_file when no content-disposition header is present""" - mock_response.headers = {"content-type": "application/pdf"} - # Mock binary content - type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) - response = Response(mock_response) - assert response.is_file - - -# UTF-8 Encoding Tests -@pytest.mark.parametrize( - ("content_bytes", "expected_text", "description"), - [ - # Chinese UTF-8 bytes - ( - b'{"message": "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c"}', - '{"message": "你好世界"}', - "Chinese characters UTF-8", - ), - # Japanese UTF-8 bytes - ( - b'{"message": "\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf"}', - '{"message": "こんにちは"}', - "Japanese characters UTF-8", - ), - # Korean UTF-8 bytes - ( - b'{"message": "\xec\x95\x88\xeb\x85\x95\xed\x95\x98\xec\x84\xb8\xec\x9a\x94"}', - '{"message": "안녕하세요"}', - "Korean characters UTF-8", - ), - # Arabic UTF-8 - (b'{"text": "\xd9\x85\xd8\xb1\xd8\xad\xd8\xa8\xd8\xa7"}', '{"text": "مرحبا"}', "Arabic characters UTF-8"), - # European characters UTF-8 - (b'{"text": "Caf\xc3\xa9 M\xc3\xbcnchen"}', '{"text": "Café München"}', "European accented characters"), - # Simple ASCII - (b'{"text": "Hello World"}', '{"text": "Hello World"}', "Simple ASCII text"), - ], -) -def test_text_property_utf8_decoding(mock_response, content_bytes, expected_text, description): - """Test that Response.text properly decodes UTF-8 content with charset_normalizer""" - mock_response.headers = {"content-type": "application/json; charset=utf-8"} - type(mock_response).content = PropertyMock(return_value=content_bytes) - # Mock httpx response.text to return something different (simulating potential encoding issues) - mock_response.text = "incorrect-fallback-text" # To ensure we are not falling back to httpx's text property - - response = Response(mock_response) - - # Our enhanced text property should decode properly using charset_normalizer - assert response.text == expected_text, ( - f"Failed for {description}: got {repr(response.text)}, expected {repr(expected_text)}" - ) - - -def test_text_property_fallback_to_httpx(mock_response): - """Test that Response.text falls back to httpx.text when charset_normalizer fails""" - mock_response.headers = {"content-type": "application/json"} - - # Create malformed UTF-8 bytes - malformed_bytes = b'{"text": "\xff\xfe\x00\x00 invalid"}' - type(mock_response).content = PropertyMock(return_value=malformed_bytes) - - # Mock httpx.text to return some fallback value - fallback_text = '{"text": "fallback"}' - mock_response.text = fallback_text - - response = Response(mock_response) - - # Should fall back to httpx's text when charset_normalizer fails - assert response.text == fallback_text - - -@pytest.mark.parametrize( - ("json_content", "description"), - [ - # JSON with escaped Unicode (like Flask jsonify()) - ('{"message": "\\u4f60\\u597d\\u4e16\\u754c"}', "JSON with escaped Unicode"), - # JSON with mixed escape sequences and UTF-8 - ('{"mixed": "Hello \\u4f60\\u597d"}', "Mixed escaped and regular text"), - # JSON with complex escape sequences - ('{"complex": "\\ud83d\\ude00\\u4f60\\u597d"}', "Emoji and Chinese escapes"), - ], -) -def test_text_property_with_escaped_unicode(mock_response, json_content, description): - """Test Response.text with JSON containing Unicode escape sequences""" - mock_response.headers = {"content-type": "application/json"} - - content_bytes = json_content.encode("utf-8") - type(mock_response).content = PropertyMock(return_value=content_bytes) - mock_response.text = json_content # httpx would return the same for valid UTF-8 - - response = Response(mock_response) - - # Should preserve the escape sequences (valid JSON) - assert response.text == json_content, f"Failed for {description}" - - # The text should be valid JSON that can be parsed back to proper Unicode - parsed = json.loads(response.text) - assert isinstance(parsed, dict), f"Invalid JSON for {description}" diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index be7cc073db..a5026b40cf 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -1,8 +1,4 @@ import pytest - -from configs import dify_config -from core.helper.ssrf_proxy import ssrf_proxy -from core.workflow.system_variables import default_system_variables from graphon.file.file_manager import file_manager from graphon.nodes.http_request import ( BodyData, @@ -16,6 +12,10 @@ from graphon.nodes.http_request.exc import AuthorizationConfigError from graphon.nodes.http_request.executor import Executor from graphon.runtime import VariablePool +from configs import dify_config +from core.helper.ssrf_proxy import ssrf_proxy +from core.workflow.system_variables import default_system_variables + HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index a3cadc0681..4705b3f76e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -3,17 +3,17 @@ from typing import Any import httpx import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_runtime import DifyFileReferenceFactory from core.workflow.system_variables import build_system_variables -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file.file_manager import file_manager -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response -from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index 1d6a4da7c4..d16e1233ac 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,6 +1,7 @@ -from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients from graphon.runtime import VariablePool +from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients + def test_render_body_template_replaces_variable_values(): config = EmailDeliveryConfig( diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index 5f28a07606..a2cdbbf132 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -2,14 +2,41 @@ Unit tests for human input node entities. """ +from collections.abc import Mapping +from dataclasses import dataclass, field +from datetime import datetime, timedelta from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock import pytest +from graphon.entities import GraphInitParams +from graphon.node_events import PauseRequestedEvent +from graphon.node_events.node import StreamCompletedEvent +from graphon.nodes.human_input.entities import ( + FormInput, + FormInputDefault, + HumanInputNodeData, + UserAction, +) +from graphon.nodes.human_input.enums import ( + ButtonStyle, + FormInputType, + HumanInputFormStatus, + PlaceholderType, + TimeoutUnit, +) +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.runtime import GraphRuntimeState, VariablePool from pydantic import ValidationError from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY -from core.repositories.human_input_repository import HumanInputFormRepository +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRecipientEntity, + HumanInputFormRepository, +) from core.workflow.human_input_compat import ( DeliveryMethodType, EmailDeliveryConfig, @@ -23,24 +50,90 @@ from core.workflow.human_input_compat import ( ) from core.workflow.node_runtime import DifyHumanInputNodeRuntime from core.workflow.system_variables import build_system_variables -from graphon.entities import GraphInitParams -from graphon.node_events import PauseRequestedEvent -from graphon.node_events.node import StreamCompletedEvent -from graphon.nodes.human_input.entities import ( - FormInput, - FormInputDefault, - HumanInputNodeData, - UserAction, -) -from graphon.nodes.human_input.enums import ( - ButtonStyle, - FormInputType, - PlaceholderType, - TimeoutUnit, -) -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.runtime import GraphRuntimeState, VariablePool -from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository +from libs.datetime_utils import naive_utc_now + + +@dataclass +class _InMemoryFormEntity(HumanInputFormEntity): + form_id: str + rendered: str + token: str | None = None + action_id: str | None = None + data: Mapping[str, Any] | None = None + is_submitted: bool = False + status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING + expiration: datetime = field(default_factory=lambda: naive_utc_now() + timedelta(days=1)) + + @property + def id(self) -> str: + return self.form_id + + @property + def submission_token(self) -> str | None: + return self.token + + @property + def recipients(self) -> list[HumanInputFormRecipientEntity]: + return [] + + @property + def rendered_content(self) -> str: + return self.rendered + + @property + def selected_action_id(self) -> str | None: + return self.action_id + + @property + def submitted_data(self) -> Mapping[str, Any] | None: + return self.data + + @property + def submitted(self) -> bool: + return self.is_submitted + + @property + def status(self) -> HumanInputFormStatus: + return self.status_value + + @property + def expiration_time(self) -> datetime: + return self.expiration + + +class InMemoryHumanInputFormRepository(HumanInputFormRepository): + """Minimal in-memory repository for Dify-owned HumanInputNode behavior tests.""" + + def __init__(self) -> None: + self._form_counter = 0 + self.created_params: list[FormCreateParams] = [] + self.created_forms: list[_InMemoryFormEntity] = [] + self._forms_by_node_id: dict[str, _InMemoryFormEntity] = {} + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: + self.created_params.append(params) + self._form_counter += 1 + form_id = f"form-{self._form_counter}" + entity = _InMemoryFormEntity( + form_id=form_id, + rendered=params.rendered_content, + token=f"token-{form_id}", + ) + self.created_forms.append(entity) + self._forms_by_node_id[params.node_id] = entity + return entity + + def get_form(self, node_id: str) -> HumanInputFormEntity | None: + return self._forms_by_node_id.get(node_id) + + def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None: + if not self.created_forms: + raise AssertionError("no form has been created to attach submission data") + entity = self.created_forms[-1] + entity.action_id = action_id + entity.data = form_data or {} + entity.is_submitted = True + entity.status_value = HumanInputFormStatus.SUBMITTED class TestDeliveryMethod: diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index fc4497f010..52802c7ce1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -1,10 +1,7 @@ import datetime from types import SimpleNamespace -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import default_system_variables -from graphon.entities.graph_init_params import GraphInitParams +from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes from graphon.graph_events import ( NodeRunHumanInputFormFilledEvent, @@ -14,6 +11,10 @@ from graphon.graph_events import ( from graphon.nodes.human_input.enums import HumanInputFormStatus from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import default_system_variables from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py deleted file mode 100644 index 8cc91bdb54..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py +++ /dev/null @@ -1,339 +0,0 @@ -from graphon.nodes.iteration.entities import ( - ErrorHandleMode, - IterationNodeData, - IterationStartNodeData, - IterationState, -) - - -class TestErrorHandleMode: - """Test suite for ErrorHandleMode enum.""" - - def test_terminated_value(self): - """Test TERMINATED enum value.""" - assert ErrorHandleMode.TERMINATED == "terminated" - assert ErrorHandleMode.TERMINATED.value == "terminated" - - def test_continue_on_error_value(self): - """Test CONTINUE_ON_ERROR enum value.""" - assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error" - assert ErrorHandleMode.CONTINUE_ON_ERROR.value == "continue-on-error" - - def test_remove_abnormal_output_value(self): - """Test REMOVE_ABNORMAL_OUTPUT enum value.""" - assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT == "remove-abnormal-output" - assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT.value == "remove-abnormal-output" - - def test_error_handle_mode_is_str_enum(self): - """Test ErrorHandleMode is a string enum.""" - assert isinstance(ErrorHandleMode.TERMINATED, str) - assert isinstance(ErrorHandleMode.CONTINUE_ON_ERROR, str) - assert isinstance(ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, str) - - def test_error_handle_mode_comparison(self): - """Test ErrorHandleMode can be compared with strings.""" - assert ErrorHandleMode.TERMINATED == "terminated" - assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error" - - def test_all_error_handle_modes(self): - """Test all ErrorHandleMode values are accessible.""" - modes = list(ErrorHandleMode) - - assert len(modes) == 3 - assert ErrorHandleMode.TERMINATED in modes - assert ErrorHandleMode.CONTINUE_ON_ERROR in modes - assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT in modes - - -class TestIterationNodeData: - """Test suite for IterationNodeData model.""" - - def test_iteration_node_data_basic(self): - """Test IterationNodeData with basic configuration.""" - data = IterationNodeData( - title="Test Iteration", - iterator_selector=["node1", "output"], - output_selector=["iteration", "result"], - ) - - assert data.title == "Test Iteration" - assert data.iterator_selector == ["node1", "output"] - assert data.output_selector == ["iteration", "result"] - - def test_iteration_node_data_default_values(self): - """Test IterationNodeData default values.""" - data = IterationNodeData( - title="Default Test", - iterator_selector=["start", "items"], - output_selector=["iter", "out"], - ) - - assert data.parent_loop_id is None - assert data.is_parallel is False - assert data.parallel_nums == 10 - assert data.error_handle_mode == ErrorHandleMode.TERMINATED - assert data.flatten_output is True - - def test_iteration_node_data_parallel_mode(self): - """Test IterationNodeData with parallel mode enabled.""" - data = IterationNodeData( - title="Parallel Iteration", - iterator_selector=["node", "list"], - output_selector=["iter", "output"], - is_parallel=True, - parallel_nums=5, - ) - - assert data.is_parallel is True - assert data.parallel_nums == 5 - - def test_iteration_node_data_custom_parallel_nums(self): - """Test IterationNodeData with custom parallel numbers.""" - data = IterationNodeData( - title="Custom Parallel", - iterator_selector=["a", "b"], - output_selector=["c", "d"], - parallel_nums=20, - ) - - assert data.parallel_nums == 20 - - def test_iteration_node_data_continue_on_error(self): - """Test IterationNodeData with continue on error mode.""" - data = IterationNodeData( - title="Continue Error", - iterator_selector=["x", "y"], - output_selector=["z", "w"], - error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR, - ) - - assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR - - def test_iteration_node_data_remove_abnormal_output(self): - """Test IterationNodeData with remove abnormal output mode.""" - data = IterationNodeData( - title="Remove Abnormal", - iterator_selector=["input", "array"], - output_selector=["output", "result"], - error_handle_mode=ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, - ) - - assert data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT - - def test_iteration_node_data_flatten_output_disabled(self): - """Test IterationNodeData with flatten output disabled.""" - data = IterationNodeData( - title="No Flatten", - iterator_selector=["a"], - output_selector=["b"], - flatten_output=False, - ) - - assert data.flatten_output is False - - def test_iteration_node_data_with_parent_loop_id(self): - """Test IterationNodeData with parent loop ID.""" - data = IterationNodeData( - title="Nested Loop", - iterator_selector=["parent", "items"], - output_selector=["child", "output"], - parent_loop_id="parent_loop_123", - ) - - assert data.parent_loop_id == "parent_loop_123" - - def test_iteration_node_data_complex_selectors(self): - """Test IterationNodeData with complex selectors.""" - data = IterationNodeData( - title="Complex Selectors", - iterator_selector=["node1", "output", "data", "items"], - output_selector=["iteration", "result", "value"], - ) - - assert len(data.iterator_selector) == 4 - assert len(data.output_selector) == 3 - - def test_iteration_node_data_all_options(self): - """Test IterationNodeData with all options configured.""" - data = IterationNodeData( - title="Full Config", - iterator_selector=["start", "list"], - output_selector=["end", "result"], - parent_loop_id="outer_loop", - is_parallel=True, - parallel_nums=15, - error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR, - flatten_output=False, - ) - - assert data.title == "Full Config" - assert data.parent_loop_id == "outer_loop" - assert data.is_parallel is True - assert data.parallel_nums == 15 - assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR - assert data.flatten_output is False - - -class TestIterationStartNodeData: - """Test suite for IterationStartNodeData model.""" - - def test_iteration_start_node_data_basic(self): - """Test IterationStartNodeData basic creation.""" - data = IterationStartNodeData(title="Iteration Start") - - assert data.title == "Iteration Start" - - def test_iteration_start_node_data_with_description(self): - """Test IterationStartNodeData with description.""" - data = IterationStartNodeData( - title="Start Node", - desc="This is the start of iteration", - ) - - assert data.title == "Start Node" - assert data.desc == "This is the start of iteration" - - -class TestIterationState: - """Test suite for IterationState model.""" - - def test_iteration_state_default_values(self): - """Test IterationState default values.""" - state = IterationState() - - assert state.outputs == [] - assert state.current_output is None - - def test_iteration_state_with_outputs(self): - """Test IterationState with outputs.""" - state = IterationState(outputs=["result1", "result2", "result3"]) - - assert len(state.outputs) == 3 - assert state.outputs[0] == "result1" - assert state.outputs[2] == "result3" - - def test_iteration_state_with_current_output(self): - """Test IterationState with current output.""" - state = IterationState(current_output="current_value") - - assert state.current_output == "current_value" - - def test_iteration_state_get_last_output_with_outputs(self): - """Test get_last_output with outputs present.""" - state = IterationState(outputs=["first", "second", "last"]) - - result = state.get_last_output() - - assert result == "last" - - def test_iteration_state_get_last_output_empty(self): - """Test get_last_output with empty outputs.""" - state = IterationState(outputs=[]) - - result = state.get_last_output() - - assert result is None - - def test_iteration_state_get_last_output_single(self): - """Test get_last_output with single output.""" - state = IterationState(outputs=["only_one"]) - - result = state.get_last_output() - - assert result == "only_one" - - def test_iteration_state_get_current_output(self): - """Test get_current_output method.""" - state = IterationState(current_output={"key": "value"}) - - result = state.get_current_output() - - assert result == {"key": "value"} - - def test_iteration_state_get_current_output_none(self): - """Test get_current_output when None.""" - state = IterationState() - - result = state.get_current_output() - - assert result is None - - def test_iteration_state_with_complex_outputs(self): - """Test IterationState with complex output types.""" - state = IterationState( - outputs=[ - {"id": 1, "name": "first"}, - {"id": 2, "name": "second"}, - [1, 2, 3], - "string_output", - ] - ) - - assert len(state.outputs) == 4 - assert state.outputs[0] == {"id": 1, "name": "first"} - assert state.outputs[2] == [1, 2, 3] - - def test_iteration_state_with_none_outputs(self): - """Test IterationState with None values in outputs.""" - state = IterationState(outputs=["value1", None, "value3"]) - - assert len(state.outputs) == 3 - assert state.outputs[1] is None - - def test_iteration_state_get_last_output_with_none(self): - """Test get_last_output when last output is None.""" - state = IterationState(outputs=["first", None]) - - result = state.get_last_output() - - assert result is None - - def test_iteration_state_metadata_class(self): - """Test IterationState.MetaData class.""" - metadata = IterationState.MetaData(iterator_length=10) - - assert metadata.iterator_length == 10 - - def test_iteration_state_metadata_different_lengths(self): - """Test IterationState.MetaData with different lengths.""" - metadata1 = IterationState.MetaData(iterator_length=0) - metadata2 = IterationState.MetaData(iterator_length=100) - metadata3 = IterationState.MetaData(iterator_length=1000000) - - assert metadata1.iterator_length == 0 - assert metadata2.iterator_length == 100 - assert metadata3.iterator_length == 1000000 - - def test_iteration_state_outputs_modification(self): - """Test modifying IterationState outputs.""" - state = IterationState(outputs=[]) - - state.outputs.append("new_output") - state.outputs.append("another_output") - - assert len(state.outputs) == 2 - assert state.get_last_output() == "another_output" - - def test_iteration_state_current_output_update(self): - """Test updating current_output.""" - state = IterationState() - - state.current_output = "first_value" - assert state.get_current_output() == "first_value" - - state.current_output = "updated_value" - assert state.get_current_output() == "updated_value" - - def test_iteration_state_with_numeric_outputs(self): - """Test IterationState with numeric outputs.""" - state = IterationState(outputs=[1, 2, 3, 4, 5]) - - assert state.get_last_output() == 5 - assert len(state.outputs) == 5 - - def test_iteration_state_with_boolean_outputs(self): - """Test IterationState with boolean outputs.""" - state = IterationState(outputs=[True, False, True]) - - assert state.get_last_output() is True - assert state.outputs[1] is False diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py deleted file mode 100644 index 58b82aa893..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py +++ /dev/null @@ -1,438 +0,0 @@ -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from graphon.nodes.iteration.exc import ( - InvalidIteratorValueError, - IterationGraphNotFoundError, - IterationIndexNotFoundError, - IterationNodeError, - IteratorVariableNotFoundError, - StartNodeIdNotFoundError, -) -from graphon.nodes.iteration.iteration_node import IterationNode - - -class TestIterationNodeExceptions: - """Test suite for iteration node exceptions.""" - - def test_iteration_node_error_is_value_error(self): - """Test IterationNodeError inherits from ValueError.""" - error = IterationNodeError("test error") - - assert isinstance(error, ValueError) - assert str(error) == "test error" - - def test_iterator_variable_not_found_error(self): - """Test IteratorVariableNotFoundError.""" - error = IteratorVariableNotFoundError("Iterator variable not found") - - assert isinstance(error, IterationNodeError) - assert isinstance(error, ValueError) - assert "Iterator variable not found" in str(error) - - def test_invalid_iterator_value_error(self): - """Test InvalidIteratorValueError.""" - error = InvalidIteratorValueError("Invalid iterator value") - - assert isinstance(error, IterationNodeError) - assert "Invalid iterator value" in str(error) - - def test_start_node_id_not_found_error(self): - """Test StartNodeIdNotFoundError.""" - error = StartNodeIdNotFoundError("Start node ID not found") - - assert isinstance(error, IterationNodeError) - assert "Start node ID not found" in str(error) - - def test_iteration_graph_not_found_error(self): - """Test IterationGraphNotFoundError.""" - error = IterationGraphNotFoundError("Iteration graph not found") - - assert isinstance(error, IterationNodeError) - assert "Iteration graph not found" in str(error) - - def test_iteration_index_not_found_error(self): - """Test IterationIndexNotFoundError.""" - error = IterationIndexNotFoundError("Iteration index not found") - - assert isinstance(error, IterationNodeError) - assert "Iteration index not found" in str(error) - - def test_exception_with_empty_message(self): - """Test exception with empty message.""" - error = IterationNodeError("") - - assert str(error) == "" - - def test_exception_with_detailed_message(self): - """Test exception with detailed message.""" - error = IteratorVariableNotFoundError("Variable 'items' not found in node 'start_node'") - - assert "items" in str(error) - assert "start_node" in str(error) - - def test_all_exceptions_inherit_from_base(self): - """Test all exceptions inherit from IterationNodeError.""" - exceptions = [ - IteratorVariableNotFoundError("test"), - InvalidIteratorValueError("test"), - StartNodeIdNotFoundError("test"), - IterationGraphNotFoundError("test"), - IterationIndexNotFoundError("test"), - ] - - for exc in exceptions: - assert isinstance(exc, IterationNodeError) - assert isinstance(exc, ValueError) - - -class TestIterationNodeClassAttributes: - """Test suite for IterationNode class attributes.""" - - def test_node_type(self): - """Test IterationNode node_type attribute.""" - assert IterationNode.node_type == BuiltinNodeTypes.ITERATION - - def test_version(self): - """Test IterationNode version method.""" - version = IterationNode.version() - - assert version == "1" - - -class TestIterationNodeDefaultConfig: - """Test suite for IterationNode get_default_config.""" - - def test_get_default_config_returns_dict(self): - """Test get_default_config returns a dictionary.""" - config = IterationNode.get_default_config() - - assert isinstance(config, dict) - - def test_get_default_config_type(self): - """Test get_default_config includes type.""" - config = IterationNode.get_default_config() - - assert config.get("type") == "iteration" - - def test_get_default_config_has_config_section(self): - """Test get_default_config has config section.""" - config = IterationNode.get_default_config() - - assert "config" in config - assert isinstance(config["config"], dict) - - def test_get_default_config_is_parallel_default(self): - """Test get_default_config is_parallel default value.""" - config = IterationNode.get_default_config() - - assert config["config"]["is_parallel"] is False - - def test_get_default_config_parallel_nums_default(self): - """Test get_default_config parallel_nums default value.""" - config = IterationNode.get_default_config() - - assert config["config"]["parallel_nums"] == 10 - - def test_get_default_config_error_handle_mode_default(self): - """Test get_default_config error_handle_mode default value.""" - config = IterationNode.get_default_config() - - assert config["config"]["error_handle_mode"] == ErrorHandleMode.TERMINATED - - def test_get_default_config_flatten_output_default(self): - """Test get_default_config flatten_output default value.""" - config = IterationNode.get_default_config() - - assert config["config"]["flatten_output"] is True - - def test_get_default_config_with_none_filters(self): - """Test get_default_config with None filters.""" - config = IterationNode.get_default_config(filters=None) - - assert config is not None - assert "type" in config - - def test_get_default_config_with_empty_filters(self): - """Test get_default_config with empty filters.""" - config = IterationNode.get_default_config(filters={}) - - assert config is not None - - -class TestIterationNodeInitialization: - """Test suite for IterationNode initialization.""" - - def test_init_node_data_basic(self): - """Test init_node_data with basic configuration.""" - node = IterationNode.__new__(IterationNode) - data = { - "title": "Test Iteration", - "iterator_selector": ["start", "items"], - "output_selector": ["iteration", "result"], - } - - node.init_node_data(data) - - assert node._node_data.title == "Test Iteration" - assert node._node_data.iterator_selector == ["start", "items"] - - def test_init_node_data_with_parallel(self): - """Test init_node_data with parallel configuration.""" - node = IterationNode.__new__(IterationNode) - data = { - "title": "Parallel Iteration", - "iterator_selector": ["node", "list"], - "output_selector": ["out", "result"], - "is_parallel": True, - "parallel_nums": 5, - } - - node.init_node_data(data) - - assert node._node_data.is_parallel is True - assert node._node_data.parallel_nums == 5 - - def test_init_node_data_with_error_handle_mode(self): - """Test init_node_data with error handle mode.""" - node = IterationNode.__new__(IterationNode) - data = { - "title": "Error Handle Test", - "iterator_selector": ["a", "b"], - "output_selector": ["c", "d"], - "error_handle_mode": "continue-on-error", - } - - node.init_node_data(data) - - assert node._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR - - def test_get_title(self): - """Test _get_title method.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="My Iteration", - iterator_selector=["x"], - output_selector=["y"], - ) - - assert node._get_title() == "My Iteration" - - def test_get_description_none(self): - """Test _get_description returns None when not set.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - iterator_selector=["a"], - output_selector=["b"], - ) - - assert node._get_description() is None - - def test_get_description_with_value(self): - """Test _get_description with value.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - desc="This is a description", - iterator_selector=["a"], - output_selector=["b"], - ) - - assert node._get_description() == "This is a description" - - def test_node_data_property(self): - """Test node_data property returns node data.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Base Test", - iterator_selector=["x"], - output_selector=["y"], - ) - - result = node.node_data - - assert result == node._node_data - - -class TestIterationNodeDataValidation: - """Test suite for IterationNodeData validation scenarios.""" - - def test_valid_iteration_node_data(self): - """Test valid IterationNodeData creation.""" - data = IterationNodeData( - title="Valid Iteration", - iterator_selector=["start", "items"], - output_selector=["end", "result"], - ) - - assert data.title == "Valid Iteration" - - def test_iteration_node_data_with_all_error_modes(self): - """Test IterationNodeData with all error handle modes.""" - modes = [ - ErrorHandleMode.TERMINATED, - ErrorHandleMode.CONTINUE_ON_ERROR, - ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, - ] - - for mode in modes: - data = IterationNodeData( - title=f"Test {mode}", - iterator_selector=["a"], - output_selector=["b"], - error_handle_mode=mode, - ) - assert data.error_handle_mode == mode - - def test_iteration_node_data_parallel_configuration(self): - """Test IterationNodeData parallel configuration combinations.""" - configs = [ - (False, 10), - (True, 1), - (True, 5), - (True, 20), - (True, 100), - ] - - for is_parallel, parallel_nums in configs: - data = IterationNodeData( - title="Parallel Test", - iterator_selector=["x"], - output_selector=["y"], - is_parallel=is_parallel, - parallel_nums=parallel_nums, - ) - assert data.is_parallel == is_parallel - assert data.parallel_nums == parallel_nums - - def test_iteration_node_data_flatten_output_options(self): - """Test IterationNodeData flatten_output options.""" - data_flatten = IterationNodeData( - title="Flatten True", - iterator_selector=["a"], - output_selector=["b"], - flatten_output=True, - ) - - data_no_flatten = IterationNodeData( - title="Flatten False", - iterator_selector=["a"], - output_selector=["b"], - flatten_output=False, - ) - - assert data_flatten.flatten_output is True - assert data_no_flatten.flatten_output is False - - def test_iteration_node_data_complex_selectors(self): - """Test IterationNodeData with complex selectors.""" - data = IterationNodeData( - title="Complex", - iterator_selector=["node1", "output", "data", "items", "list"], - output_selector=["iteration", "result", "value", "final"], - ) - - assert len(data.iterator_selector) == 5 - assert len(data.output_selector) == 4 - - def test_iteration_node_data_single_element_selectors(self): - """Test IterationNodeData with single element selectors.""" - data = IterationNodeData( - title="Single", - iterator_selector=["items"], - output_selector=["result"], - ) - - assert len(data.iterator_selector) == 1 - assert len(data.output_selector) == 1 - - -class TestIterationNodeErrorStrategies: - """Test suite for IterationNode error strategies.""" - - def test_get_error_strategy_default(self): - """Test _get_error_strategy with default value.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - iterator_selector=["a"], - output_selector=["b"], - ) - - result = node._get_error_strategy() - - assert result is None or result == node._node_data.error_strategy - - def test_get_retry_config(self): - """Test _get_retry_config method.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - iterator_selector=["a"], - output_selector=["b"], - ) - - result = node._get_retry_config() - - assert result is not None - - def test_get_default_value_dict(self): - """Test _get_default_value_dict method.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - iterator_selector=["a"], - output_selector=["b"], - ) - - result = node._get_default_value_dict() - - assert isinstance(result, dict) - - -def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: - seen_configs: list[object] = [] - original_validate_python = NodeConfigDictAdapter.validate_python - - def record_validate_python(value: object): - seen_configs.append(value) - return original_validate_python(value) - - monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) - - child_node_config = { - "id": "answer-node", - "data": { - "type": "answer", - "title": "Answer", - "answer": "", - "iteration_id": "iteration-node", - }, - } - - IterationNode._extract_variable_selector_to_variable_mapping( - graph_config={ - "nodes": [ - { - "id": "iteration-node", - "data": { - "type": "iteration", - "title": "Iteration", - "iterator_selector": ["start", "items"], - "output_selector": ["iteration", "result"], - }, - }, - child_node_config, - ], - "edges": [], - }, - node_id="iteration-node", - node_data=IterationNodeData( - title="Iteration", - iterator_selector=["start", "items"], - output_selector=["iteration", "result"], - ), - ) - - assert seen_configs == [child_node_config] diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py deleted file mode 100644 index 4c3ad85fcd..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py +++ /dev/null @@ -1,201 +0,0 @@ -from threading import Event -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph_events import GraphRunAbortedEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import IterationFailedEvent, IterationStartedEvent, StreamCompletedEvent -from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from graphon.nodes.iteration.exc import ChildGraphAbortedError -from graphon.nodes.iteration.iteration_node import IterationNode -from tests.workflow_test_utils import build_test_variable_pool - - -def _usage_with_tokens(total_tokens: int) -> LLMUsage: - usage = LLMUsage.empty_usage() - usage.total_tokens = total_tokens - return usage - - -class _AbortOnRequestGraphEngine: - def __init__(self, *, index: int, total_tokens: int) -> None: - variable_pool = build_test_variable_pool() - variable_pool.add(["iteration-node", "index"], index) - - self.started = Event() - self.abort_requested = Event() - self.finished = Event() - self.abort_reason: str | None = None - self.graph_runtime_state = SimpleNamespace( - variable_pool=variable_pool, - llm_usage=_usage_with_tokens(total_tokens), - ) - - def request_abort(self, reason: str | None = None) -> None: - self.abort_reason = reason - self.abort_requested.set() - - def run(self): - self.started.set() - assert self.abort_requested.wait(1), "parallel sibling never received an abort request" - self.finished.set() - yield GraphRunAbortedEvent(reason=self.abort_reason) - - -def _build_immediate_abort_graph_engine( - *, - index: int, - total_tokens: int, - wait_before_abort: Event | None = None, -) -> SimpleNamespace: - variable_pool = build_test_variable_pool() - variable_pool.add(["iteration-node", "index"], index) - - started = Event() - finished = Event() - - def run(): - started.set() - if wait_before_abort is not None: - assert wait_before_abort.wait(1), "parallel sibling never started" - finished.set() - yield GraphRunAbortedEvent(reason="quota exceeded") - - return SimpleNamespace( - graph_runtime_state=SimpleNamespace( - variable_pool=variable_pool, - llm_usage=_usage_with_tokens(total_tokens), - ), - run=run, - request_abort=lambda reason=None: None, - started=started, - finished=finished, - ) - - -def _build_iteration_node( - *, - error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED, - is_parallel: bool = False, -) -> IterationNode: - node = IterationNode.__new__(IterationNode) - node._node_id = "iteration-node" - node._node_data = IterationNodeData( - title="Iteration", - iterator_selector=["start", "items"], - output_selector=["iteration-node", "output"], - start_node_id="child-start", - is_parallel=is_parallel, - parallel_nums=2, - error_handle_mode=error_handle_mode, - ) - - variable_pool = build_test_variable_pool() - variable_pool.add(["start", "items"], ["first", "second"]) - node.graph_runtime_state = SimpleNamespace( - variable_pool=variable_pool, - llm_usage=LLMUsage.empty_usage(), - ) - return node - - -def test_run_single_iter_raises_child_graph_aborted_error_on_abort_event() -> None: - node = _build_iteration_node() - variable_pool = build_test_variable_pool() - variable_pool.add(["iteration-node", "index"], 0) - graph_engine = SimpleNamespace( - run=lambda: iter([GraphRunAbortedEvent(reason="quota exceeded")]), - ) - - with pytest.raises(ChildGraphAbortedError, match="quota exceeded"): - list( - node._run_single_iter( - variable_pool=variable_pool, - outputs=[], - graph_engine=graph_engine, - ) - ) - - -def test_iteration_run_fails_on_sequential_child_abort() -> None: - node = _build_iteration_node(error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR) - graph_engine = SimpleNamespace( - graph_runtime_state=SimpleNamespace( - variable_pool=build_test_variable_pool(), - llm_usage=LLMUsage.empty_usage(), - ) - ) - node._create_graph_engine = MagicMock(return_value=graph_engine) - node._run_single_iter = MagicMock(side_effect=ChildGraphAbortedError("quota exceeded")) - - events = list(node._run()) - - assert isinstance(events[0], IterationStartedEvent) - assert isinstance(events[-2], IterationFailedEvent) - assert events[-2].error == "quota exceeded" - assert isinstance(events[-1], StreamCompletedEvent) - assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED - assert events[-1].node_run_result.error == "quota exceeded" - node._create_graph_engine.assert_called_once() - node._run_single_iter.assert_called_once() - - -def test_iteration_run_merges_child_usage_before_failing_on_sequential_child_abort() -> None: - node = _build_iteration_node(error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR) - graph_engine = SimpleNamespace( - graph_runtime_state=SimpleNamespace( - variable_pool=build_test_variable_pool(), - llm_usage=_usage_with_tokens(7), - ) - ) - node._create_graph_engine = MagicMock(return_value=graph_engine) - node._run_single_iter = MagicMock(side_effect=ChildGraphAbortedError("quota exceeded")) - - events = list(node._run()) - - assert isinstance(events[-1], StreamCompletedEvent) - assert events[-1].node_run_result.llm_usage.total_tokens == 7 - assert node.graph_runtime_state.llm_usage.total_tokens == 7 - - -@pytest.mark.parametrize( - "error_handle_mode", - [ - ErrorHandleMode.CONTINUE_ON_ERROR, - ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, - ], -) -def test_iteration_run_fails_on_parallel_child_abort_regardless_of_error_mode( - error_handle_mode: ErrorHandleMode, -) -> None: - node = _build_iteration_node( - error_handle_mode=error_handle_mode, - is_parallel=True, - ) - blocking_engine = _AbortOnRequestGraphEngine(index=1, total_tokens=5) - aborting_engine = _build_immediate_abort_graph_engine( - index=0, - total_tokens=3, - wait_before_abort=blocking_engine.started, - ) - node._create_graph_engine = MagicMock( - side_effect=lambda index, item: {0: aborting_engine, 1: blocking_engine}[index] - ) - - events = list(node._run()) - - assert isinstance(events[0], IterationStartedEvent) - assert isinstance(events[-2], IterationFailedEvent) - assert events[-2].error == "quota exceeded" - assert isinstance(events[-1], StreamCompletedEvent) - assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED - assert events[-1].node_run_result.error == "quota exceeded" - assert events[-1].node_run_result.llm_usage.total_tokens == 8 - assert node.graph_runtime_state.llm_usage.total_tokens == 8 - assert blocking_engine.started.is_set() - assert blocking_engine.abort_requested.is_set() - assert blocking_engine.finished.is_set() - assert blocking_engine.abort_reason == "quota exceeded" diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py index 82cc734274..bbfe350f7e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -2,8 +2,6 @@ from collections.abc import Mapping from typing import Any import pytest - -from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams from graphon.nodes.iteration.exc import IterationGraphNotFoundError from graphon.nodes.iteration.iteration_node import IterationNode @@ -13,6 +11,8 @@ from graphon.runtime import ( GraphRuntimeState, VariablePool, ) + +from core.workflow.system_variables import default_system_variables from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py deleted file mode 100644 index 41d7c3193d..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py +++ /dev/null @@ -1,67 +0,0 @@ -import time -from datetime import UTC, datetime - -import pytest - -from graphon.enums import BuiltinNodeTypes -from graphon.graph_events import NodeRunSucceededEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from graphon.nodes.iteration.iteration_node import IterationNode - - -def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None: - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Parallel Iteration", - iterator_selector=["start", "items"], - output_selector=["iteration", "output"], - is_parallel=True, - parallel_nums=2, - error_handle_mode=ErrorHandleMode.TERMINATED, - ) - node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new) - - def fake_execute_tracked_iteration_parallel( - *, - index: int, - item: object, - started_child_engines: dict[int, object], - started_child_engines_lock: object, - ): - _ = started_child_engines - _ = started_child_engines_lock - return ( - 0.1 + (index * 0.1), - [ - NodeRunSucceededEvent( - id=f"exec-{index}", - node_id=f"llm-{index}", - node_type=BuiltinNodeTypes.LLM, - start_at=datetime.now(UTC).replace(tzinfo=None), - ), - ], - f"output-{item}", - LLMUsage.empty_usage(), - ) - - node._execute_tracked_iteration_parallel = fake_execute_tracked_iteration_parallel - - outputs: list[object] = [] - iter_run_map: dict[str, float] = {} - usage_accumulator = [LLMUsage.empty_usage()] - - generator = node._execute_parallel_iterations( - iterator_list_value=["a", "b"], - outputs=outputs, - iter_run_map=iter_run_map, - usage_accumulator=usage_accumulator, - ) - - for _ in generator: - # Simulate a slow consumer replaying buffered events. - time.sleep(0.02) - - assert outputs == ["output-a", "output-b"] - assert iter_run_map["0"] == pytest.approx(0.1) - assert iter_run_map["1"] == pytest.approx(0.2) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index a6fca1bfb4..f8802138b5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -3,6 +3,9 @@ import uuid from unittest.mock import Mock import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables.segments import StringSegment from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -16,9 +19,6 @@ from core.workflow.nodes.knowledge_index.protocols import ( SummaryIndexServiceProtocol, ) from core.workflow.system_variables import SystemVariableKey, build_system_variables -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables.segments import StringSegment from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index 45e8ae7d20..ab64be59ad 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -3,6 +3,10 @@ import uuid from unittest.mock import Mock import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import StringSegment from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.nodes.knowledge_retrieval.entities import ( @@ -17,10 +21,6 @@ from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source from core.workflow.system_variables import build_system_variables -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import StringSegment from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index eca34f05be..fdf1706765 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -1,14 +1,14 @@ from unittest.mock import MagicMock import pytest - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.nodes.list_operator.node import ListOperatorNode from graphon.runtime import GraphRuntimeState from graphon.variables import ArrayNumberSegment, ArrayStringSegment +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY + class TestListOperatorNode: """Comprehensive tests for ListOperatorNode.""" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py deleted file mode 100644 index 4f9ba0194a..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ /dev/null @@ -1,170 +0,0 @@ -import uuid -from typing import NamedTuple -from unittest import mock -from unittest.mock import MagicMock - -import httpx -import pytest - -from graphon.file import FileTransferMethod, FileType -from graphon.nodes.llm.file_saver import ( - FileSaverImpl, - _extract_content_type_and_extension, - _get_extension, - _validate_extension_override, -) -from graphon.nodes.protocols import ToolFileManagerProtocol - -_PNG_DATA = b"\x89PNG\r\n\x1a\n" - - -def _gen_id(): - return str(uuid.uuid4()) - - -class TestFileSaverImpl: - def test_save_binary_string(self, monkeypatch: pytest.MonkeyPatch): - file_type = FileType.IMAGE - mime_type = "image/png" - mock_tool_file = MagicMock() - mock_tool_file.id = _gen_id() - mock_tool_file.name = f"{_gen_id()}.png" - mock_tool_file.file_key = "test-file-key" - mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManagerProtocol) - mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file - file_reference = MagicMock() - file_reference_factory = MagicMock() - file_reference_factory.build_from_mapping.return_value = file_reference - http_client = MagicMock() - - file_saver = FileSaverImpl( - tool_file_manager=mocked_tool_file_manager, - file_reference_factory=file_reference_factory, - http_client=http_client, - ) - - file = file_saver.save_binary_string(_PNG_DATA, mime_type, file_type) - assert file is file_reference - - mocked_tool_file_manager.create_file_by_raw.assert_called_once_with( - file_binary=_PNG_DATA, - mimetype=mime_type, - ) - file_reference_factory.build_from_mapping.assert_called_once_with( - mapping={ - "type": file_type, - "transfer_method": FileTransferMethod.TOOL_FILE, - "filename": mock_tool_file.name, - "extension": ".png", - "mime_type": mime_type, - "size": len(_PNG_DATA), - "tool_file_id": mock_tool_file.id, - "related_id": mock_tool_file.id, - "storage_key": mock_tool_file.file_key, - } - ) - - def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch): - _TEST_URL = "https://example.com/image.png" - mock_request = httpx.Request("GET", _TEST_URL) - mock_response = httpx.Response( - status_code=401, - request=mock_request, - ) - http_client = MagicMock() - http_client.get.return_value = mock_response - - file_saver = FileSaverImpl( - tool_file_manager=MagicMock(), - file_reference_factory=MagicMock(), - http_client=http_client, - ) - - with pytest.raises(httpx.HTTPStatusError) as exc: - file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) - http_client.get.assert_called_once_with(_TEST_URL) - assert exc.value.response.status_code == 401 - - def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch): - _TEST_URL = "https://example.com/image.png" - mime_type = "image/png" - - mock_request = httpx.Request("GET", _TEST_URL) - mock_response = httpx.Response( - status_code=200, - content=b"test-data", - headers={"Content-Type": mime_type}, - request=mock_request, - ) - http_client = MagicMock() - http_client.get.return_value = mock_response - - file_saver = FileSaverImpl( - tool_file_manager=MagicMock(), - file_reference_factory=MagicMock(), - http_client=http_client, - ) - expected_file = MagicMock() - mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=expected_file) - monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string) - - file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) - mock_save_binary_string.assert_called_once_with( - mock_response.content, - mime_type, - FileType.IMAGE, - extension_override=".png", - ) - assert file is expected_file - - -def test_validate_extension_override(): - class TestCase(NamedTuple): - extension_override: str | None - expected: str | None - - cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"] - - for valid_ext_override in [None, "", ".png", ".tar.gz"]: - assert valid_ext_override == _validate_extension_override(valid_ext_override) - - for invalid_ext_override in ["png", "tar.gz"]: - with pytest.raises(ValueError) as exc: - _validate_extension_override(invalid_ext_override) - - -class TestExtractContentTypeAndExtension: - def test_with_both_content_type_and_extension(self): - content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png") - assert content_type == "image/png" - assert extension == ".png" - - def test_url_with_file_extension(self): - for content_type in [None, ""]: - content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type) - assert content_type == "image/png" - assert extension == ".png" - - def test_response_with_content_type(self): - content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png") - assert content_type == "image/png" - assert extension == ".png" - - def test_no_content_type_and_no_extension(self): - for content_type in [None, ""]: - content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type) - assert content_type == "application/octet-stream" - assert extension == ".bin" - - -class TestGetExtension: - def test_with_extension_override(self): - mime_type = "image/png" - for override in [".jpg", ""]: - extension = _get_extension(mime_type, override) - assert extension == override - - def test_without_extension_override(self): - mime_type = "image/png" - extension = _get_extension(mime_type) - assert extension == ".png" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py index dfc982f49c..c784f805c0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py @@ -1,10 +1,7 @@ from unittest import mock import pytest - -from core.model_manager import ModelInstance -from graphon.file import FileTransferMethod, FileType -from graphon.file.models import File +from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities import ( ImagePromptMessageContent, PromptMessageRole, @@ -36,6 +33,8 @@ from graphon.nodes.llm.exc import ( from graphon.runtime import VariablePool from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment +from core.model_manager import ModelInstance + def _build_model_schema( *, diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index a2fbc50392..a215e9d350 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -4,19 +4,6 @@ from collections.abc import Sequence from unittest import mock import pytest - -from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom -from core.app.llm.model_access import ( - DifyCredentialsProvider, - DifyModelFactory, - build_dify_model_access, - fetch_model_config, -) -from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -from core.entities.provider_entities import CustomConfiguration, SystemConfiguration -from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.common_entities import I18nObject @@ -79,6 +66,19 @@ from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from graphon.runtime import GraphRuntimeState, VariablePool from graphon.template_rendering import TemplateRenderError from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment + +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom +from core.app.llm.model_access import ( + DifyCredentialsProvider, + DifyModelFactory, + build_dify_model_access, + fetch_model_config, +) +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import CustomConfiguration, SystemConfiguration +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.workflow.system_variables import default_system_variables from models.provider import ProviderType from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py deleted file mode 100644 index af1cff4e81..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py +++ /dev/null @@ -1,25 +0,0 @@ -from collections.abc import Mapping, Sequence - -from pydantic import BaseModel, Field - -from graphon.file import File -from graphon.model_runtime.entities.message_entities import PromptMessage -from graphon.model_runtime.entities.model_entities import ModelFeature -from graphon.nodes.llm.entities import LLMNodeChatModelMessage - - -class LLMNodeTestScenario(BaseModel): - """Test scenario for LLM node testing.""" - - description: str = Field(..., description="Description of the test scenario") - sys_query: str = Field(..., description="User query input") - sys_files: Sequence[File] = Field(default_factory=list, description="List of user files") - vision_enabled: bool = Field(default=False, description="Whether vision is enabled") - vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") - features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features") - window_size: int = Field(..., description="Window size for memory") - prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages") - file_variables: Mapping[str, File | Sequence[File]] = Field( - default_factory=dict, description="List of file variables" - ) - expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing") diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py deleted file mode 100644 index ccf1077838..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py +++ /dev/null @@ -1,27 +0,0 @@ -from graphon.nodes.parameter_extractor.entities import ParameterConfig -from graphon.variables.types import SegmentType - - -class TestParameterConfig: - def test_select_type(self): - data = { - "name": "yes_or_no", - "type": "select", - "options": ["yes", "no"], - "description": "a simple select made of `yes` and `no`", - "required": True, - } - - pc = ParameterConfig.model_validate(data) - assert pc.type == SegmentType.STRING - assert pc.options == data["options"] - - def test_validate_bool_type(self): - data = { - "name": "boolean", - "type": "bool", - "description": "a simple boolean parameter", - "required": True, - } - pc = ParameterConfig.model_validate(data) - assert pc.type == SegmentType.BOOLEAN diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py index 8f8ec49f14..1c362a0a03 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py @@ -6,8 +6,6 @@ from dataclasses import dataclass from typing import Any import pytest - -from factories.variable_factory import build_segment_with_type from graphon.model_runtime.entities import LLMMode from graphon.nodes.llm import ModelConfig, VisionConfig from graphon.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData @@ -20,6 +18,8 @@ from graphon.nodes.parameter_extractor.exc import ( from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from graphon.variables.types import SegmentType +from factories.variable_factory import build_segment_with_type + @dataclass class ValidTestCase: diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py deleted file mode 100644 index 01878ed692..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py +++ /dev/null @@ -1,225 +0,0 @@ -import pytest -from pydantic import ValidationError - -from graphon.enums import ErrorStrategy -from graphon.nodes.template_transform.entities import TemplateTransformNodeData - - -class TestTemplateTransformNodeData: - """Test suite for TemplateTransformNodeData entity.""" - - def test_valid_template_transform_node_data(self): - """Test creating valid TemplateTransformNodeData.""" - data = { - "title": "Template Transform", - "desc": "Transform data using Jinja2 template", - "variables": [ - {"variable": "name", "value_selector": ["sys", "user_name"]}, - {"variable": "age", "value_selector": ["sys", "user_age"]}, - ], - "template": "Hello {{ name }}, you are {{ age }} years old!", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.title == "Template Transform" - assert node_data.desc == "Transform data using Jinja2 template" - assert len(node_data.variables) == 2 - assert node_data.variables[0].variable == "name" - assert node_data.variables[0].value_selector == ["sys", "user_name"] - assert node_data.variables[1].variable == "age" - assert node_data.variables[1].value_selector == ["sys", "user_age"] - assert node_data.template == "Hello {{ name }}, you are {{ age }} years old!" - - def test_template_transform_node_data_with_empty_variables(self): - """Test TemplateTransformNodeData with no variables.""" - data = { - "title": "Static Template", - "variables": [], - "template": "This is a static template with no variables.", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.title == "Static Template" - assert len(node_data.variables) == 0 - assert node_data.template == "This is a static template with no variables." - - def test_template_transform_node_data_with_complex_template(self): - """Test TemplateTransformNodeData with complex Jinja2 template.""" - data = { - "title": "Complex Template", - "variables": [ - {"variable": "items", "value_selector": ["sys", "item_list"]}, - {"variable": "total", "value_selector": ["sys", "total_count"]}, - ], - "template": ( - "{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}. Total: {{ total }}" - ), - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.title == "Complex Template" - assert len(node_data.variables) == 2 - assert "{% for item in items %}" in node_data.template - assert "{{ total }}" in node_data.template - - def test_template_transform_node_data_with_error_strategy(self): - """Test TemplateTransformNodeData with error handling strategy.""" - data = { - "title": "Template with Error Handling", - "variables": [{"variable": "value", "value_selector": ["sys", "input"]}], - "template": "{{ value }}", - "error_strategy": "fail-branch", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.error_strategy == ErrorStrategy.FAIL_BRANCH - - def test_template_transform_node_data_with_retry_config(self): - """Test TemplateTransformNodeData with retry configuration.""" - data = { - "title": "Template with Retry", - "variables": [{"variable": "data", "value_selector": ["sys", "data"]}], - "template": "{{ data }}", - "retry_config": {"enabled": True, "max_retries": 3, "retry_interval": 1000}, - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.retry_config.enabled is True - assert node_data.retry_config.max_retries == 3 - assert node_data.retry_config.retry_interval == 1000 - - def test_template_transform_node_data_missing_required_fields(self): - """Test that missing required fields raises ValidationError.""" - data = { - "title": "Incomplete Template", - # Missing 'variables' and 'template' - } - - with pytest.raises(ValidationError) as exc_info: - TemplateTransformNodeData.model_validate(data) - - errors = exc_info.value.errors() - assert len(errors) >= 2 - error_fields = {error["loc"][0] for error in errors} - assert "variables" in error_fields - assert "template" in error_fields - - def test_template_transform_node_data_invalid_variable_selector(self): - """Test that invalid variable selector format raises ValidationError.""" - data = { - "title": "Invalid Variable", - "variables": [ - {"variable": "name", "value_selector": "invalid_format"} # Should be list - ], - "template": "{{ name }}", - } - - with pytest.raises(ValidationError): - TemplateTransformNodeData.model_validate(data) - - def test_template_transform_node_data_with_default_value_dict(self): - """Test TemplateTransformNodeData with default value dictionary.""" - data = { - "title": "Template with Defaults", - "variables": [ - {"variable": "name", "value_selector": ["sys", "user_name"]}, - {"variable": "greeting", "value_selector": ["sys", "greeting"]}, - ], - "template": "{{ greeting }} {{ name }}!", - "default_value_dict": {"greeting": "Hello", "name": "Guest"}, - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.default_value_dict == {"greeting": "Hello", "name": "Guest"} - - def test_template_transform_node_data_with_nested_selectors(self): - """Test TemplateTransformNodeData with nested variable selectors.""" - data = { - "title": "Nested Selectors", - "variables": [ - {"variable": "user_info", "value_selector": ["sys", "user", "profile", "name"]}, - {"variable": "settings", "value_selector": ["sys", "config", "app", "theme"]}, - ], - "template": "User: {{ user_info }}, Theme: {{ settings }}", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert len(node_data.variables) == 2 - assert node_data.variables[0].value_selector == ["sys", "user", "profile", "name"] - assert node_data.variables[1].value_selector == ["sys", "config", "app", "theme"] - - def test_template_transform_node_data_with_multiline_template(self): - """Test TemplateTransformNodeData with multiline template.""" - data = { - "title": "Multiline Template", - "variables": [ - {"variable": "title", "value_selector": ["sys", "title"]}, - {"variable": "content", "value_selector": ["sys", "content"]}, - ], - "template": """ -# {{ title }} - -{{ content }} - ---- -Generated by Template Transform Node - """, - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert "# {{ title }}" in node_data.template - assert "{{ content }}" in node_data.template - assert "Generated by Template Transform Node" in node_data.template - - def test_template_transform_node_data_serialization(self): - """Test that TemplateTransformNodeData can be serialized and deserialized.""" - original_data = { - "title": "Serialization Test", - "desc": "Test serialization", - "variables": [{"variable": "test", "value_selector": ["sys", "test"]}], - "template": "{{ test }}", - } - - node_data = TemplateTransformNodeData.model_validate(original_data) - serialized = node_data.model_dump() - deserialized = TemplateTransformNodeData.model_validate(serialized) - - assert deserialized.title == node_data.title - assert deserialized.desc == node_data.desc - assert len(deserialized.variables) == len(node_data.variables) - assert deserialized.template == node_data.template - - def test_template_transform_node_data_with_special_characters(self): - """Test TemplateTransformNodeData with special characters in template.""" - data = { - "title": "Special Characters", - "variables": [{"variable": "text", "value_selector": ["sys", "input"]}], - "template": "Special: {{ text }} | Symbols: @#$%^&*() | Unicode: 你好 🎉", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert "@#$%^&*()" in node_data.template - assert "你好" in node_data.template - assert "🎉" in node_data.template - - def test_template_transform_node_data_empty_template(self): - """Test TemplateTransformNodeData with empty template string.""" - data = { - "title": "Empty Template", - "variables": [], - "template": "", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.template == "" - assert len(node_data.variables) == 0 diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index bc44ececd8..d86e0efe02 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -1,8 +1,6 @@ from unittest.mock import MagicMock import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.nodes.base.entities import VariableSelector @@ -10,6 +8,8 @@ from graphon.nodes.template_transform.entities import TemplateTransformNodeData from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode from graphon.runtime import GraphRuntimeState from graphon.template_rendering import TemplateRenderError + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py index 636237e56e..bd22a8e318 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py @@ -1,14 +1,14 @@ from unittest.mock import MagicMock import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.nodes.base.entities import VariableSelector from graphon.nodes.template_transform.template_transform_node import ( DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH, TemplateTransformNode, ) from graphon.runtime import GraphRuntimeState + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from tests.workflow_test_utils import build_test_graph_init_params from .template_transform_node_spec import TestTemplateTransformNode # noqa: F401 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 0522dd9d14..e11ebf6eb8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -1,16 +1,16 @@ from collections.abc import Mapping import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_runtime import resolve_dify_run_context -from core.workflow.system_variables import build_system_variables from graphon.entities import GraphInitParams from graphon.entities.base_node_data import BaseNodeData from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.node import Node from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_runtime import resolve_dify_run_context +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 87ec2d5bce..555ff0c945 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -4,8 +4,6 @@ from unittest.mock import Mock, patch import pandas as pd import pytest from docx.oxml.text.paragraph import CT_P - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod @@ -21,6 +19,8 @@ from graphon.nodes.document_extractor.node import ( from graphon.variables import ArrayFileSegment from graphon.variables.segments import ArrayStringSegment from graphon.variables.variables import StringVariable + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 782750e02e..1b14f0ab13 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -3,11 +3,6 @@ import uuid from unittest.mock import MagicMock, Mock import pytest - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables -from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, FileType from graphon.graph import Graph @@ -16,6 +11,11 @@ from graphon.nodes.if_else.if_else_node import IfElseNode from graphon.runtime import GraphRuntimeState, VariablePool from graphon.utils.condition.entities import Condition, SubCondition, SubVariableCondition from graphon.variables import ArrayFileSegment + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables +from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index b217e4e8e7..d28c3e01e5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -1,8 +1,6 @@ from unittest.mock import MagicMock import pytest - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from graphon.enums import WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, FileType from graphon.nodes.list_operator.entities import ( @@ -18,6 +16,8 @@ from graphon.nodes.list_operator.exc import InvalidKeyError from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func from graphon.variables import ArrayFileSegment +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom + @pytest.fixture def list_operator_node(): diff --git a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py deleted file mode 100644 index d613ba154a..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py +++ /dev/null @@ -1,150 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph_events import GraphRunAbortedEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import LoopFailedEvent, LoopStartedEvent, StreamCompletedEvent -from graphon.nodes.loop.entities import LoopNodeData -from graphon.nodes.loop.loop_node import LoopNode -from tests.workflow_test_utils import build_test_variable_pool - - -def _usage_with_tokens(total_tokens: int) -> LLMUsage: - usage = LLMUsage.empty_usage() - usage.total_tokens = total_tokens - return usage - - -def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: - seen_configs: list[object] = [] - original_validate_python = NodeConfigDictAdapter.validate_python - - def record_validate_python(value: object): - seen_configs.append(value) - return original_validate_python(value) - - monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) - - child_node_config = { - "id": "answer-node", - "data": { - "type": "answer", - "title": "Answer", - "answer": "", - "loop_id": "loop-node", - }, - } - - LoopNode._extract_variable_selector_to_variable_mapping( - graph_config={ - "nodes": [ - { - "id": "loop-node", - "data": { - "type": "loop", - "title": "Loop", - "loop_count": 1, - "break_conditions": [], - "logical_operator": "and", - }, - }, - child_node_config, - ], - "edges": [], - }, - node_id="loop-node", - node_data=LoopNodeData( - title="Loop", - loop_count=1, - break_conditions=[], - logical_operator="and", - ), - ) - - assert seen_configs == [child_node_config] - - -def test_run_single_loop_raises_on_child_abort_event() -> None: - node = LoopNode.__new__(LoopNode) - node._node_id = "loop-node" - node._node_data = LoopNodeData( - title="Loop", - loop_count=1, - break_conditions=[], - logical_operator="and", - start_node_id="child-start", - ) - - graph_engine = SimpleNamespace( - run=lambda: iter([GraphRunAbortedEvent(reason="quota exceeded")]), - ) - - with pytest.raises(RuntimeError, match="quota exceeded"): - list(node._run_single_loop(graph_engine=graph_engine, current_index=0)) - - -def test_loop_run_fails_on_child_abort_and_stops_subsequent_rounds() -> None: - node = LoopNode.__new__(LoopNode) - node._node_id = "loop-node" - node._node_data = LoopNodeData( - title="Loop", - loop_count=2, - break_conditions=[], - logical_operator="and", - start_node_id="child-start", - ) - node.graph_config = {"nodes": [], "edges": []} - node.graph_runtime_state = SimpleNamespace( - variable_pool=build_test_variable_pool(), - llm_usage=LLMUsage.empty_usage(), - ) - - aborting_engine = SimpleNamespace( - graph_runtime_state=SimpleNamespace(outputs={}, llm_usage=LLMUsage.empty_usage()), - ) - create_graph_engine = MagicMock(return_value=aborting_engine) - node._create_graph_engine = create_graph_engine - node._run_single_loop = lambda *, graph_engine, current_index: (_ for _ in ()).throw(RuntimeError("quota exceeded")) - - events = list(node._run()) - - assert isinstance(events[0], LoopStartedEvent) - assert isinstance(events[1], LoopFailedEvent) - assert events[1].error == "quota exceeded" - assert isinstance(events[2], StreamCompletedEvent) - assert events[2].node_run_result.status == WorkflowNodeExecutionStatus.FAILED - assert events[2].node_run_result.error == "quota exceeded" - create_graph_engine.assert_called_once() - - -def test_loop_run_merges_child_usage_before_failing_on_child_abort() -> None: - node = LoopNode.__new__(LoopNode) - node._node_id = "loop-node" - node._node_data = LoopNodeData( - title="Loop", - loop_count=1, - break_conditions=[], - logical_operator="and", - start_node_id="child-start", - ) - node.graph_config = {"nodes": [], "edges": []} - node.graph_runtime_state = SimpleNamespace( - variable_pool=build_test_variable_pool(), - llm_usage=LLMUsage.empty_usage(), - ) - - aborting_engine = SimpleNamespace( - graph_runtime_state=SimpleNamespace(outputs={}, llm_usage=_usage_with_tokens(7)), - ) - node._create_graph_engine = MagicMock(return_value=aborting_engine) - node._run_single_loop = lambda *, graph_engine, current_index: (_ for _ in ()).throw(RuntimeError("quota exceeded")) - - events = list(node._run()) - - assert isinstance(events[-1], StreamCompletedEvent) - assert events[-1].node_run_result.llm_usage.total_tokens == 7 - assert node.graph_runtime_state.llm_usage.total_tokens == 7 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py deleted file mode 100644 index efbf786a55..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py +++ /dev/null @@ -1,126 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock - -from graphon.model_runtime.entities import ImagePromptMessageContent -from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory -from graphon.nodes.protocols import HttpClientProtocol -from graphon.nodes.question_classifier import ( - QuestionClassifierNode, - QuestionClassifierNodeData, -) -from graphon.template_rendering import Jinja2TemplateRenderer -from tests.workflow_test_utils import build_test_graph_init_params - - -def test_init_question_classifier_node_data(): - data = { - "title": "test classifier node", - "query_variable_selector": ["id", "name"], - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, - "classes": [{"id": "1", "name": "class 1"}], - "instruction": "This is a test instruction", - "memory": { - "role_prefix": {"user": "Human:", "assistant": "AI:"}, - "window": {"enabled": True, "size": 5}, - "query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:", - }, - "vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}}, - } - - node_data = QuestionClassifierNodeData.model_validate(data) - - assert node_data.query_variable_selector == ["id", "name"] - assert node_data.model.provider == "openai" - assert node_data.classes[0].id == "1" - assert node_data.instruction == "This is a test instruction" - assert node_data.memory is not None - assert node_data.memory.role_prefix is not None - assert node_data.memory.role_prefix.user == "Human:" - assert node_data.memory.role_prefix.assistant == "AI:" - assert node_data.memory.window.enabled == True - assert node_data.memory.window.size == 5 - assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:" - assert node_data.vision.enabled == True - assert node_data.vision.configs.variable_selector == ["image"] - assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.LOW - - -def test_init_question_classifier_node_data_without_vision_config(): - data = { - "title": "test classifier node", - "query_variable_selector": ["id", "name"], - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, - "classes": [{"id": "1", "name": "class 1"}], - "instruction": "This is a test instruction", - "memory": { - "role_prefix": {"user": "Human:", "assistant": "AI:"}, - "window": {"enabled": True, "size": 5}, - "query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:", - }, - } - - node_data = QuestionClassifierNodeData.model_validate(data) - - assert node_data.query_variable_selector == ["id", "name"] - assert node_data.model.provider == "openai" - assert node_data.classes[0].id == "1" - assert node_data.instruction == "This is a test instruction" - assert node_data.memory is not None - assert node_data.memory.role_prefix is not None - assert node_data.memory.role_prefix.user == "Human:" - assert node_data.memory.role_prefix.assistant == "AI:" - assert node_data.memory.window.enabled == True - assert node_data.memory.window.size == 5 - assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:" - assert node_data.vision.enabled == False - assert node_data.vision.configs.variable_selector == ["sys", "files"] - assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH - - -def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(monkeypatch): - node_data = QuestionClassifierNodeData.model_validate( - { - "title": "test classifier node", - "query_variable_selector": ["id", "name"], - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, - "classes": [{"id": "1", "name": "class 1"}], - "instruction": "This is a test instruction", - } - ) - template_renderer = MagicMock(spec=Jinja2TemplateRenderer) - node = QuestionClassifierNode( - id="node-id", - config={"id": "node-id", "data": node_data.model_dump(mode="json")}, - graph_init_params=build_test_graph_init_params( - workflow_id="workflow-id", - graph_config={}, - tenant_id="tenant-id", - app_id="app-id", - user_id="user-id", - ), - graph_runtime_state=SimpleNamespace(variable_pool=MagicMock()), - credentials_provider=MagicMock(spec=CredentialsProvider), - model_factory=MagicMock(spec=ModelFactory), - model_instance=MagicMock(), - http_client=MagicMock(spec=HttpClientProtocol), - llm_file_saver=MagicMock(), - template_renderer=template_renderer, - ) - fetch_prompt_messages = MagicMock(return_value=([], None)) - monkeypatch.setattr( - "graphon.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages", - fetch_prompt_messages, - ) - monkeypatch.setattr( - "graphon.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema", - MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])), - ) - - node._calculate_rest_token( - node_data=node_data, - query="hello", - model_instance=MagicMock(stop=(), parameters={}), - context="", - ) - - assert fetch_prompt_messages.call_args.kwargs["template_renderer"] is template_renderer diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 543f9878de..833c303052 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -2,16 +2,16 @@ import json import time import pytest -from pydantic import ValidationError as PydanticValidationError - -from core.workflow.system_variables import build_system_variables -from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState from graphon.variables import build_segment, segment_to_variable from graphon.variables.input_entities import VariableEntity, VariableEntityType from graphon.variables.variables import Variable +from pydantic import ValidationError as PydanticValidationError + +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index c806181340..1587014802 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -8,14 +8,14 @@ from typing import TYPE_CHECKING, Any from unittest.mock import MagicMock import pytest - -from core.workflow.system_variables import build_system_variables from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variables.segments import ArrayFileSegment + +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params if TYPE_CHECKING: # pragma: no cover - imported for type checking only diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py index 438af211f3..c4dfc5a179 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py @@ -6,6 +6,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.nodes.tool.entities import ToolNodeData, ToolProviderType +from graphon.nodes.tool.exc import ToolRuntimeInvocationError +from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage +from graphon.runtime import VariablePool from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError @@ -17,11 +22,6 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.node_runtime import DifyToolNodeRuntime from core.workflow.system_variables import build_system_variables -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.nodes.tool.entities import ToolNodeData, ToolProviderType -from graphon.nodes.tool.exc import ToolRuntimeInvocationError -from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage -from graphon.runtime import VariablePool from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py index c8ddc53284..952e798430 100644 --- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -1,12 +1,13 @@ from collections.abc import Mapping -from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode -from core.workflow.system_variables import build_system_variables from graphon.entities import GraphInitParams from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.runtime import GraphRuntimeState + +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py deleted file mode 100644 index fabc8df73e..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ /dev/null @@ -1,312 +0,0 @@ -import time -import uuid -from uuid import uuid4 - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_variables_to_pool -from graphon.entities import GraphInitParams -from graphon.graph import Graph -from graphon.graph_events.node import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent -from graphon.nodes.variable_assigner.common import helpers as common_helpers -from graphon.nodes.variable_assigner.v1 import VariableAssignerNode -from graphon.nodes.variable_assigner.v1.node_data import WriteMode -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import ArrayStringVariable, StringVariable - -DEFAULT_NODE_ID = "node_id" - - -def _build_variable_pool( - *, - conversation_id: str, - conversation_variables: list[StringVariable | ArrayStringVariable], -) -> VariablePool: - variable_pool = VariablePool() - add_variables_to_pool( - variable_pool, - build_bootstrap_variables( - system_variables=build_system_variables(conversation_id=conversation_id), - conversation_variables=conversation_variables, - ), - ) - return variable_pool - - -def test_overwrite_string_variable(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { - "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "over-write", - "input_variable_selector": ["node_id", "test_string_variable"], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = StringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value="the first value", - ) - - input_variable = StringVariable( - id=str(uuid4()), - name="test_string_variable", - value="the second value", - ) - conversation_id = str(uuid.uuid4()) - - # construct variable pool - variable_pool = _build_variable_pool( - conversation_id=conversation_id, - conversation_variables=[conversation_variable], - ) - - variable_pool.add( - [DEFAULT_NODE_ID, input_variable.name], - input_variable, - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.OVER_WRITE, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - events = list(node.run()) - updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) - succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) - updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) - assert updated_variables is not None - assert updated_variables[0].name == conversation_variable.name - assert updated_variables[0].new_value == input_variable.value - assert updated_event.variable.value == "the second value" - assert tuple(updated_event.variable.selector) == ("conversation", conversation_variable.name) - - -def test_append_variable_to_array(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { - "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "append", - "input_variable_selector": ["node_id", "test_string_variable"], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["the first value"], - ) - - input_variable = StringVariable( - id=str(uuid4()), - name="test_string_variable", - value="the second value", - ) - conversation_id = str(uuid.uuid4()) - - variable_pool = _build_variable_pool( - conversation_id=conversation_id, - conversation_variables=[conversation_variable], - ) - variable_pool.add( - [DEFAULT_NODE_ID, input_variable.name], - input_variable, - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.APPEND, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - events = list(node.run()) - updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) - succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) - updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) - assert updated_variables is not None - assert updated_variables[0].name == conversation_variable.name - assert updated_variables[0].new_value == ["the first value", "the second value"] - assert updated_event.variable.value == ["the first value", "the second value"] - - -def test_clear_array(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { - "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "clear", - "input_variable_selector": [], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["the first value"], - ) - - conversation_id = str(uuid.uuid4()) - variable_pool = _build_variable_pool( - conversation_id=conversation_id, - conversation_variables=[conversation_variable], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.CLEAR, - "input_variable_selector": [], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - events = list(node.run()) - updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) - succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) - updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) - assert updated_variables is not None - assert updated_variables[0].name == conversation_variable.name - assert updated_variables[0].new_value == [] - assert updated_event.variable.value == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py deleted file mode 100644 index 8b13789179..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py deleted file mode 100644 index 9ac8bbe9c2..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py +++ /dev/null @@ -1,22 +0,0 @@ -from graphon.nodes.variable_assigner.v2.enums import Operation -from graphon.nodes.variable_assigner.v2.helpers import is_input_value_valid -from graphon.variables import SegmentType - - -def test_is_input_value_valid_overwrite_array_string(): - # Valid cases - assert is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["hello", "world"] - ) - assert is_input_value_valid(variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[]) - - # Invalid cases - assert not is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value="not an array" - ) - assert not is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[1, 2, 3] - ) - assert not is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["valid", 123, "invalid"] - ) diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py deleted file mode 100644 index 53346c4a90..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ /dev/null @@ -1,430 +0,0 @@ -import time -import uuid -from uuid import uuid4 - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_variables_to_pool -from graphon.entities import GraphInitParams -from graphon.graph import Graph -from graphon.graph_events import NodeRunVariableUpdatedEvent -from graphon.nodes.variable_assigner.v2 import VariableAssignerNode -from graphon.nodes.variable_assigner.v2.enums import InputType, Operation -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import ArrayStringVariable - -DEFAULT_NODE_ID = "node_id" - - -def _build_variable_pool(*, conversation_variables: list[ArrayStringVariable]) -> VariablePool: - variable_pool = VariablePool() - add_variables_to_pool( - variable_pool, - build_bootstrap_variables( - system_variables=build_system_variables(conversation_id="conversation_id"), - conversation_variables=conversation_variables, - ), - ) - return variable_pool - - -def test_handle_item_directly(): - """Test the _handle_item method directly for remove operations.""" - # Create variables - variable1 = ArrayStringVariable( - id=str(uuid4()), - name="test_variable1", - value=["first", "second", "third"], - ) - - variable2 = ArrayStringVariable( - id=str(uuid4()), - name="test_variable2", - value=["first", "second", "third"], - ) - - # Create a mock class with just the _handle_item method - class MockNode: - def _handle_item(self, *, variable, operation, value): - match operation: - case Operation.REMOVE_FIRST: - if not variable.value: - return variable.value - return variable.value[1:] - case Operation.REMOVE_LAST: - if not variable.value: - return variable.value - return variable.value[:-1] - - node = MockNode() - - # Test remove-first - result1 = node._handle_item( - variable=variable1, - operation=Operation.REMOVE_FIRST, - value=None, - ) - - # Test remove-last - result2 = node._handle_item( - variable=variable2, - operation=Operation.REMOVE_LAST, - value=None, - ) - - # Check the results - assert result1 == ["second", "third"] - assert result2 == ["first", "second"] - - -def test_remove_first_from_array(): - """Test removing the first element from an array.""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["first", "second", "third"], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_FIRST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - # Run the node - result = list(node.run()) - - updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) - assert updated_event.variable.value == ["second", "third"] - - -def test_remove_last_from_array(): - """Test removing the last element from an array.""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["first", "second", "third"], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_LAST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - result = list(node.run()) - updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) - assert updated_event.variable.value == ["first", "second"] - - -def test_remove_first_from_empty_array(): - """Test removing the first element from an empty array (should do nothing).""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=[], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_FIRST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - result = list(node.run()) - updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) - assert updated_event.variable.value == [] - - -def test_remove_last_from_empty_array(): - """Test removing the last element from an empty array (should do nothing).""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=[], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_LAST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - result = list(node.run()) - updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) - assert updated_event.variable.value == [] - - -def test_node_factory_creates_variable_assigner_node(): - graph_config = { - "edges": [], - "nodes": [ - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - variable_pool = _build_variable_pool(conversation_variables=[]) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - - node = node_factory.create_node(graph_config["nodes"][0]) - - assert isinstance(node, VariableAssignerNode) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py index 617554ee17..f1132af02b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py @@ -1,4 +1,5 @@ import pytest +from graphon.entities.exc import BaseNodeError from core.workflow.nodes.trigger_webhook.exc import ( WebhookConfigError, @@ -6,7 +7,6 @@ from core.workflow.nodes.trigger_webhook.exc import ( WebhookNotFoundError, WebhookTimeoutError, ) -from graphon.entities.exc import BaseNodeError def test_webhook_node_error_inheritance(): diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index 6fbd26131d..cccd3fb676 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -8,6 +8,10 @@ when passing files to downstream LLM nodes. from unittest.mock import Mock, patch +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.nodes.trigger_webhook.entities import ( ContentType, @@ -17,10 +21,6 @@ from core.workflow.nodes.trigger_webhook.entities import ( ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode from core.workflow.system_variables import default_system_variables -from graphon.entities.graph_init_params import GraphInitParams -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.runtime.graph_runtime_state import GraphRuntimeState -from graphon.runtime.variable_pool import VariablePool from tests.workflow_test_utils import build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 9f954b2090..34c66a4f9f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -1,6 +1,11 @@ from unittest.mock import patch import pytest +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import FileVariable, StringVariable from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE @@ -13,12 +18,6 @@ from core.workflow.nodes.trigger_webhook.entities import ( ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode from core.workflow.system_variables import default_system_variables -from graphon.entities.graph_init_params import GraphInitParams -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType -from graphon.runtime.graph_runtime_state import GraphRuntimeState -from graphon.runtime.variable_pool import VariablePool -from graphon.variables import FileVariable, StringVariable from tests.workflow_test_utils import build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/test_enums.py b/api/tests/unit_tests/core/workflow/test_enums.py deleted file mode 100644 index 453e0a8502..0000000000 --- a/api/tests/unit_tests/core/workflow/test_enums.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Tests for workflow pause related enums and constants.""" - -from graphon.enums import ( - WorkflowExecutionStatus, -) - - -class TestWorkflowExecutionStatus: - """Test WorkflowExecutionStatus enum.""" - - def test_is_ended_method(self): - """Test is_ended method for different statuses.""" - # Test ended statuses - ended_statuses = [ - WorkflowExecutionStatus.SUCCEEDED, - WorkflowExecutionStatus.FAILED, - WorkflowExecutionStatus.PARTIAL_SUCCEEDED, - WorkflowExecutionStatus.STOPPED, - ] - - for status in ended_statuses: - assert status.is_ended(), f"{status} should be considered ended" - - # Test non-ended statuses - non_ended_statuses = [ - WorkflowExecutionStatus.SCHEDULED, - WorkflowExecutionStatus.RUNNING, - WorkflowExecutionStatus.PAUSED, - ] - - for status in non_ended_statuses: - assert not status.is_ended(), f"{status} should not be considered ended" - - def test_ended_values(self): - """Test ended_values returns the expected status values.""" - assert set(WorkflowExecutionStatus.ended_values()) == { - WorkflowExecutionStatus.SUCCEEDED.value, - WorkflowExecutionStatus.FAILED.value, - WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value, - WorkflowExecutionStatus.STOPPED.value, - } diff --git a/api/tests/unit_tests/core/workflow/test_human_input_compat.py b/api/tests/unit_tests/core/workflow/test_human_input_compat.py index 0623800b30..cd41c43e4a 100644 --- a/api/tests/unit_tests/core/workflow/test_human_input_compat.py +++ b/api/tests/unit_tests/core/workflow/test_human_input_compat.py @@ -1,5 +1,6 @@ from types import SimpleNamespace +from graphon.enums import BuiltinNodeTypes from pydantic import BaseModel from core.workflow.human_input_compat import ( @@ -15,7 +16,6 @@ from core.workflow.human_input_compat import ( normalize_node_data_for_graph, parse_human_input_delivery_methods, ) -from graphon.enums import BuiltinNodeTypes def test_email_delivery_config_helpers_render_and_sanitize_text() -> None: diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py index 1db848a010..bc0b339fec 100644 --- a/api/tests/unit_tests/core/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -2,15 +2,15 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch, sentinel import pytest +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.code.entities import CodeLanguage +from graphon.variables.segments import StringSegment from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext, InvokeFrom, UserFrom from core.workflow import node_factory from core.workflow import template_rendering as workflow_template_rendering from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.code.entities import CodeLanguage -from graphon.variables.segments import StringSegment def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None: diff --git a/api/tests/unit_tests/core/workflow/test_node_runtime.py b/api/tests/unit_tests/core/workflow/test_node_runtime.py index 71a2afb28a..4f9c1dad59 100644 --- a/api/tests/unit_tests/core/workflow/test_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/test_node_runtime.py @@ -2,6 +2,10 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, sentinel import pytest +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.nodes.human_input.entities import HumanInputNodeData from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext, InvokeFrom, UserFrom from core.llm_generator.output_parser.errors import OutputParserError @@ -26,10 +30,6 @@ from core.workflow.node_runtime import ( build_dify_llm_file_saver, resolve_dify_run_context, ) -from graphon.file import FileTransferMethod, FileType -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.nodes.human_input.entities import HumanInputNodeData from tests.workflow_test_utils import build_test_run_context diff --git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py index 72a0557b7c..05ea3dc311 100644 --- a/api/tests/unit_tests/core/workflow/test_system_variable.py +++ b/api/tests/unit_tests/core/workflow/test_system_variable.py @@ -1,14 +1,14 @@ from types import SimpleNamespace +from graphon.file import File, FileTransferMethod, FileType +from graphon.nodes import BuiltinNodeTypes + from core.workflow.system_variables import ( build_system_variables, default_system_variables, get_node_creation_preload_selectors, system_variables_to_mapping, ) -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File -from graphon.nodes import BuiltinNodeTypes def test_build_system_variables_normalizes_workflow_execution_id(): diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index dddd6eb00c..e7b2b2914a 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -2,15 +2,6 @@ import uuid from collections import defaultdict import pytest - -from core.workflow.system_variables import build_system_variables, system_variables_to_mapping -from core.workflow.variable_pool_initializer import add_variables_to_pool -from core.workflow.variable_prefixes import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, - SYSTEM_VARIABLE_NODE_ID, -) -from factories.variable_factory import build_segment, segment_to_variable from graphon.file import File, FileTransferMethod, FileType from graphon.runtime import VariablePool from graphon.variables import FileSegment, StringSegment @@ -36,6 +27,15 @@ from graphon.variables.variables import ( Variable, ) +from core.workflow.system_variables import build_system_variables, system_variables_to_mapping +from core.workflow.variable_pool_initializer import add_variables_to_pool +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) +from factories.variable_factory import build_segment, segment_to_variable + @pytest.fixture def pool(): diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 4ae6ed1659..d8361d06c4 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -1,6 +1,12 @@ from types import SimpleNamespace import pytest +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.file import File, FileTransferMethod, FileType +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.runtime import VariablePool +from graphon.variables.variables import StringVariable from configs import dify_config from core.helper.code_executor.code_executor import CodeLanguage @@ -10,13 +16,6 @@ from core.workflow.variable_prefixes import ( ENVIRONMENT_VARIABLE_NODE_ID, ) from core.workflow.workflow_entry import WorkflowEntry -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.file.enums import FileType -from graphon.file.models import File, FileTransferMethod -from graphon.nodes.code.code_node import CodeNode -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.runtime import VariablePool -from graphon.variables.variables import StringVariable @pytest.fixture(autouse=True) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py index 456ab5da41..879c0bb721 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -4,18 +4,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch, sentinel import pytest - -from core.app.apps.exc import GenerateTaskStoppedError -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.model_manager import ModelInstance -from core.workflow import workflow_entry -from core.workflow.system_variables import default_system_variables from graphon.entities.base_node_data import BaseNodeData from graphon.entities.graph_config import NodeConfigDictAdapter from graphon.enums import NodeType, WorkflowNodeExecutionStatus from graphon.errors import WorkflowNodeRunFailedError -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File +from graphon.file import File, FileTransferMethod, FileType from graphon.graph import Graph from graphon.graph_events import GraphRunFailedEvent from graphon.model_runtime.entities.llm_entities import LLMUsage @@ -24,6 +17,12 @@ from graphon.nodes import BuiltinNodeTypes from graphon.nodes.base.node import Node from graphon.runtime import ChildGraphNotFoundError, VariablePool from graphon.variables.variables import StringVariable + +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.model_manager import ModelInstance +from core.workflow import workflow_entry +from core.workflow.system_variables import default_system_variables from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py index b3ecfe4bc9..4b2f98aeff 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py @@ -2,10 +2,11 @@ from unittest.mock import MagicMock, patch +from graphon.graph_engine.command_channels import RedisChannel +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.workflow_entry import WorkflowEntry -from graphon.graph_engine.command_channels.redis_channel import RedisChannel -from graphon.runtime import GraphRuntimeState, VariablePool class TestWorkflowEntryRedisChannel: diff --git a/api/tests/unit_tests/core/workflow/utils/test_condition.py b/api/tests/unit_tests/core/workflow/utils/test_condition.py deleted file mode 100644 index f4c86aa77a..0000000000 --- a/api/tests/unit_tests/core/workflow/utils/test_condition.py +++ /dev/null @@ -1,52 +0,0 @@ -from graphon.runtime import VariablePool -from graphon.utils.condition.entities import Condition -from graphon.utils.condition.processor import ConditionProcessor - - -def test_number_formatting(): - condition_processor = ConditionProcessor() - variable_pool = VariablePool() - variable_pool.add(["test_node_id", "zone"], 0) - variable_pool.add(["test_node_id", "one"], 1) - variable_pool.add(["test_node_id", "one_one"], 1.1) - # 0 <= 0.95 - assert ( - condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=[Condition(variable_selector=["test_node_id", "zone"], comparison_operator="≤", value="0.95")], - operator="or", - ).final_result - == True - ) - - # 1 >= 0.95 - assert ( - condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=[Condition(variable_selector=["test_node_id", "one"], comparison_operator="≥", value="0.95")], - operator="or", - ).final_result - == True - ) - - # 1.1 >= 0.95 - assert ( - condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=[ - Condition(variable_selector=["test_node_id", "one_one"], comparison_operator="≥", value="0.95") - ], - operator="or", - ).final_result - == True - ) - - # 1.1 > 0 - assert ( - condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=[Condition(variable_selector=["test_node_id", "one_one"], comparison_operator=">", value="0")], - operator="or", - ).final_result - == True - ) diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py deleted file mode 100644 index 009c860f16..0000000000 --- a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py +++ /dev/null @@ -1,48 +0,0 @@ -import dataclasses - -from graphon.nodes.base import variable_template_parser -from graphon.nodes.base.entities import VariableSelector - - -def test_extract_selectors_from_template(): - template = ( - "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}." - ) - selectors = variable_template_parser.extract_selectors_from_template(template) - assert selectors == [ - VariableSelector(variable="#sys.user_id#", value_selector=["sys", "user_id"]), - VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]), - VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]), - ] - - -def test_invalid_references(): - @dataclasses.dataclass - class TestCase: - name: str - template: str - - cases = [ - TestCase( - name="lack of closing brace", - template="Hello, {{#sys.user_id#", - ), - TestCase( - name="lack of opening brace", - template="Hello, #sys.user_id#}}", - ), - TestCase( - name="lack selector name", - template="Hello, {{#sys#}}", - ), - TestCase( - name="empty node name part", - template="Hello, {{#.user_id#}}", - ), - ] - for idx, c in enumerate(cases, 1): - fail_msg = f"Test case {c.name} failed, index={idx}" - selectors = variable_template_parser.extract_selectors_from_template(c.template) - assert selectors == [], fail_msg - parser = variable_template_parser.VariableTemplateParser(c.template) - assert parser.extract_variable_selectors() == [], fail_msg diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py index 511192001e..4fe3f2cb28 100644 --- a/api/tests/unit_tests/factories/test_build_from_mapping.py +++ b/api/tests/unit_tests/factories/test_build_from_mapping.py @@ -2,13 +2,13 @@ import uuid from unittest.mock import MagicMock, patch import pytest +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig from httpx import Response from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope from core.workflow.file_reference import build_file_reference, parse_file_reference, resolve_file_record_id from factories.file_factory.builders import build_from_mapping as _build_from_mapping -from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig from models import ToolFile, UploadFile # Test Data diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py index 70d7d8c575..8d573b1154 100644 --- a/api/tests/unit_tests/factories/test_variable_factory.py +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -4,11 +4,6 @@ from typing import Any from uuid import uuid4 import pytest -from hypothesis import HealthCheck, given, settings -from hypothesis import strategies as st - -from factories import variable_factory -from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type from graphon.file import File, FileTransferMethod, FileType from graphon.variables import ( ArrayNumberVariable, @@ -36,6 +31,11 @@ from graphon.variables.segments import ( StringSegment, ) from graphon.variables.types import SegmentType +from hypothesis import HealthCheck, given, settings +from hypothesis import strategies as st + +from factories import variable_factory +from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type def test_string_variable(): diff --git a/api/tests/unit_tests/fields/test_file_fields.py b/api/tests/unit_tests/fields/test_file_fields.py index 9d9f626b9e..0e848d6ef5 100644 --- a/api/tests/unit_tests/fields/test_file_fields.py +++ b/api/tests/unit_tests/fields/test_file_fields.py @@ -4,11 +4,11 @@ from datetime import datetime from types import SimpleNamespace import pytest +from graphon.file import File, FileTransferMethod, FileType from core.workflow.file_reference import build_file_reference from fields import conversation_fields, message_fields from fields.file_fields import FileResponse, FileWithSignedUrl, RemoteFileInfo, UploadConfig -from graphon.file import File, FileTransferMethod, FileType def test_file_response_serializes_datetime() -> None: diff --git a/api/tests/unit_tests/graphon/file/test_file_factory.py b/api/tests/unit_tests/graphon/file/test_file_factory.py deleted file mode 100644 index eeb537c28f..0000000000 --- a/api/tests/unit_tests/graphon/file/test_file_factory.py +++ /dev/null @@ -1,18 +0,0 @@ -from graphon.file import FileType -from graphon.file.file_factory import get_file_type_by_mime_type, standardize_file_type - - -def test_standardize_file_type_recognizes_case_insensitive_extension(): - assert standardize_file_type(extension=".PNG") == FileType.IMAGE - - -def test_standardize_file_type_recognizes_document_extension(): - assert standardize_file_type(extension=".txt") == FileType.DOCUMENT - - -def test_standardize_file_type_falls_back_to_mime_type(): - assert standardize_file_type(mime_type="video/mp4") == FileType.VIDEO - - -def test_get_file_type_by_mime_type_returns_custom_for_unknown_type(): - assert get_file_type_by_mime_type("application/octet-stream") == FileType.CUSTOM diff --git a/api/tests/unit_tests/graphon/file/test_file_manager.py b/api/tests/unit_tests/graphon/file/test_file_manager.py deleted file mode 100644 index 1eebb13f4e..0000000000 --- a/api/tests/unit_tests/graphon/file/test_file_manager.py +++ /dev/null @@ -1,133 +0,0 @@ -import base64 -from unittest.mock import MagicMock - -import pytest - -from core.workflow.file_reference import build_file_reference -from graphon.file import File, FileTransferMethod, FileType -from graphon.file.file_manager import download, to_prompt_message_content -from graphon.file.runtime import get_workflow_file_runtime, set_workflow_file_runtime -from graphon.model_runtime.entities import ( - DocumentPromptMessageContent, - ImagePromptMessageContent, - TextPromptMessageContent, -) - - -def _build_file( - *, - transfer_method: FileTransferMethod, - file_type: FileType = FileType.IMAGE, - reference: str | None = None, - remote_url: str | None = None, - filename: str = "image.png", - extension: str = ".png", - mime_type: str = "image/png", -) -> File: - return File( - id="file-id", - type=file_type, - transfer_method=transfer_method, - reference=reference, - remote_url=remote_url, - filename=filename, - extension=extension, - mime_type=mime_type, - size=128, - ) - - -@pytest.fixture -def workflow_file_runtime(): - previous_runtime = get_workflow_file_runtime() - runtime = MagicMock() - set_workflow_file_runtime(runtime) - try: - yield runtime - finally: - set_workflow_file_runtime(previous_runtime) - - -@pytest.mark.parametrize( - "transfer_method", - [ - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.TOOL_FILE, - FileTransferMethod.DATASOURCE_FILE, - ], -) -def test_download_delegates_storage_backed_files_to_runtime_loader(workflow_file_runtime, transfer_method) -> None: - workflow_file_runtime.load_file_bytes.return_value = b"payload" - file = _build_file( - transfer_method=transfer_method, - reference=build_file_reference(record_id="file-id", storage_key="files/payload.bin"), - ) - - assert download(file) == b"payload" - workflow_file_runtime.load_file_bytes.assert_called_once_with(file=file) - - -def test_download_remote_url_uses_runtime_http_get(workflow_file_runtime) -> None: - response = MagicMock() - response.content = b"remote-payload" - workflow_file_runtime.http_get.return_value = response - file = _build_file( - transfer_method=FileTransferMethod.REMOTE_URL, - remote_url="https://example.com/image.png", - ) - - assert download(file) == b"remote-payload" - workflow_file_runtime.http_get.assert_called_once_with("https://example.com/image.png", follow_redirects=True) - response.raise_for_status.assert_called_once_with() - - -def test_to_prompt_message_content_uses_runtime_url_resolution_for_images(workflow_file_runtime) -> None: - workflow_file_runtime.multimodal_send_format = "url" - workflow_file_runtime.resolve_file_url.return_value = "https://cdn.example.com/image.png" - file = _build_file( - transfer_method=FileTransferMethod.LOCAL_FILE, - reference=build_file_reference(record_id="upload-file-id", storage_key="files/image.png"), - ) - - content = to_prompt_message_content(file, image_detail_config=ImagePromptMessageContent.DETAIL.HIGH) - - assert isinstance(content, ImagePromptMessageContent) - assert content.url == "https://cdn.example.com/image.png" - assert content.base64_data == "" - assert content.detail == ImagePromptMessageContent.DETAIL.HIGH - - -def test_to_prompt_message_content_uses_runtime_file_loader_for_base64_documents(workflow_file_runtime) -> None: - workflow_file_runtime.multimodal_send_format = "base64" - workflow_file_runtime.load_file_bytes.return_value = b"document-bytes" - file = _build_file( - transfer_method=FileTransferMethod.TOOL_FILE, - file_type=FileType.DOCUMENT, - reference=build_file_reference(record_id="tool-file-id", storage_key="docs/report.pdf"), - filename="report.pdf", - extension=".pdf", - mime_type="application/pdf", - ) - - content = to_prompt_message_content(file) - - assert isinstance(content, DocumentPromptMessageContent) - assert content.base64_data == base64.b64encode(b"document-bytes").decode("utf-8") - assert content.url == "" - workflow_file_runtime.load_file_bytes.assert_called_once_with(file=file) - - -def test_to_prompt_message_content_returns_text_placeholder_for_custom_files() -> None: - file = _build_file( - transfer_method=FileTransferMethod.REMOTE_URL, - file_type=FileType.CUSTOM, - remote_url="https://example.com/archive.bin", - filename="archive.bin", - extension=".bin", - mime_type="application/octet-stream", - ) - - content = to_prompt_message_content(file) - - assert isinstance(content, TextPromptMessageContent) - assert content.data == "[Unsupported file type: archive.bin (custom)]" diff --git a/api/tests/unit_tests/graphon/file/test_models.py b/api/tests/unit_tests/graphon/file/test_models.py deleted file mode 100644 index 17d244da5f..0000000000 --- a/api/tests/unit_tests/graphon/file/test_models.py +++ /dev/null @@ -1,54 +0,0 @@ -from core.workflow.file_reference import build_file_reference -from graphon.file import File, FileTransferMethod, FileType, helpers - - -def _build_local_file(*, reference: str, storage_key: str | None = None) -> File: - return File( - id="file-id", - type=FileType.DOCUMENT, - transfer_method=FileTransferMethod.LOCAL_FILE, - reference=reference, - filename="report.pdf", - extension=".pdf", - mime_type="application/pdf", - size=128, - storage_key=storage_key, - ) - - -def test_file_exposes_legacy_aliases_from_opaque_reference() -> None: - reference = build_file_reference(record_id="upload-file-id", storage_key="files/report.pdf") - - file = _build_local_file(reference=reference) - - assert file.reference == reference - assert file.related_id == "upload-file-id" - assert file.storage_key == "files/report.pdf" - - -def test_file_falls_back_to_raw_reference_when_opaque_reference_is_invalid() -> None: - file = _build_local_file(reference="dify-file-ref:not-base64", storage_key="fallback-key") - - assert file.related_id == "dify-file-ref:not-base64" - assert file.storage_key == "fallback-key" - - -def test_file_to_dict_keeps_reference_and_legacy_related_id(monkeypatch) -> None: - reference = build_file_reference(record_id="upload-file-id", storage_key="files/report.pdf") - file = _build_local_file(reference=reference) - monkeypatch.setattr(helpers, "resolve_file_url", lambda _file, for_external=True: "https://example.com/report.pdf") - - serialized = file.to_dict() - - assert serialized["reference"] == reference - assert serialized["related_id"] == "upload-file-id" - assert serialized["url"] == "https://example.com/report.pdf" - - -def test_file_related_id_setter_updates_reference_alias() -> None: - file = _build_local_file(reference="upload-file-id", storage_key="files/report.pdf") - - file.related_id = "replacement-upload-id" - - assert file.reference == "replacement-upload-id" - assert file.related_id == "replacement-upload-id" diff --git a/api/tests/unit_tests/graphon/model_runtime/__base/__init__.py b/api/tests/unit_tests/graphon/model_runtime/__base/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/tests/unit_tests/graphon/model_runtime/__base/test_increase_tool_call.py b/api/tests/unit_tests/graphon/model_runtime/__base/test_increase_tool_call.py deleted file mode 100644 index 7b4fc5a04c..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/__base/test_increase_tool_call.py +++ /dev/null @@ -1,114 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage -from graphon.model_runtime.model_providers.__base.large_language_model import _increase_tool_call - -ToolCall = AssistantPromptMessage.ToolCall - -# CASE 1: Single tool call -INPUTS_CASE_1 = [ - ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), -] -EXPECTED_CASE_1 = [ - ToolCall( - id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}') - ), -] - -# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...) -INPUTS_CASE_2 = [ - ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), - ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), -] -EXPECTED_CASE_2 = [ - ToolCall( - id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}') - ), - ToolCall( - id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}') - ), -] - -# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...) -INPUTS_CASE_3 = [ - ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), - ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), - ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), - ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")), - ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')), - ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), -] -EXPECTED_CASE_3 = [ - ToolCall( - id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}') - ), - ToolCall( - id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}') - ), -] - -# CASE 4: Tool call sequences with no IDs -INPUTS_CASE_4 = [ - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), -] -EXPECTED_CASE_4 = [ - ToolCall( - id="RANDOM_ID_1", - type="function", - function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'), - ), - ToolCall( - id="RANDOM_ID_2", - type="function", - function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}'), - ), -] - - -def _run_case(inputs: list[ToolCall], expected: list[ToolCall]): - actual = [] - _increase_tool_call(inputs, actual) - assert actual == expected - - -def test__increase_tool_call(): - # case 1: - _run_case(INPUTS_CASE_1, EXPECTED_CASE_1) - - # case 2: - _run_case(INPUTS_CASE_2, EXPECTED_CASE_2) - - # case 3: - _run_case(INPUTS_CASE_3, EXPECTED_CASE_3) - - # case 4: - mock_id_generator = MagicMock() - mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4] - with patch( - "graphon.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator - ): - _run_case(INPUTS_CASE_4, EXPECTED_CASE_4) - - -def test__increase_tool_call__no_id_no_name_first_delta_should_raise(): - inputs = [ - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='"value"}')), - ] - actual: list[ToolCall] = [] - with patch("graphon.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", MagicMock()): - with pytest.raises(ValueError): - _increase_tool_call(inputs, actual) diff --git a/api/tests/unit_tests/graphon/model_runtime/__base/test_large_language_model_non_stream_parsing.py b/api/tests/unit_tests/graphon/model_runtime/__base/test_large_language_model_non_stream_parsing.py deleted file mode 100644 index c922fbaa60..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/__base/test_large_language_model_non_stream_parsing.py +++ /dev/null @@ -1,126 +0,0 @@ -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from graphon.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_runtime_result - - -def _make_chunk( - *, - model: str = "test-model", - content: str | list[TextPromptMessageContent] | None, - tool_calls: list[AssistantPromptMessage.ToolCall] | None = None, - usage: LLMUsage | None = None, - system_fingerprint: str | None = None, -) -> LLMResultChunk: - message = AssistantPromptMessage(content=content, tool_calls=tool_calls or []) - delta = LLMResultChunkDelta(index=0, message=message, usage=usage) - return LLMResultChunk(model=model, delta=delta, system_fingerprint=system_fingerprint) - - -def test__normalize_non_stream_runtime_result__from_first_chunk_str_content_and_tool_calls(): - prompt_messages = [UserPromptMessage(content="hi")] - - tool_calls = [ - AssistantPromptMessage.ToolCall( - id="1", - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments=""), - ), - AssistantPromptMessage.ToolCall( - id="", - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='{"arg1": '), - ), - AssistantPromptMessage.ToolCall( - id="", - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='"value"}'), - ), - ] - - usage = LLMUsage.empty_usage().model_copy(update={"prompt_tokens": 1, "total_tokens": 1}) - chunk = _make_chunk(content="hello", tool_calls=tool_calls, usage=usage, system_fingerprint="fp-1") - - result = _normalize_non_stream_runtime_result( - model="test-model", prompt_messages=prompt_messages, result=iter([chunk]) - ) - - assert result.model == "test-model" - assert result.prompt_messages == prompt_messages - assert result.message.content == "hello" - assert result.usage.prompt_tokens == 1 - assert result.system_fingerprint == "fp-1" - assert result.message.tool_calls == [ - AssistantPromptMessage.ToolCall( - id="1", - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'), - ) - ] - - -def test__normalize_non_stream_runtime_result__from_first_chunk_list_content(): - prompt_messages = [UserPromptMessage(content="hi")] - - content_list = [TextPromptMessageContent(data="a"), TextPromptMessageContent(data="b")] - chunk = _make_chunk(content=content_list, usage=LLMUsage.empty_usage()) - - result = _normalize_non_stream_runtime_result( - model="test-model", prompt_messages=prompt_messages, result=iter([chunk]) - ) - - assert result.message.content == content_list - - -def test__normalize_non_stream_runtime_result__passthrough_llm_result(): - prompt_messages = [UserPromptMessage(content="hi")] - llm_result = LLMResult( - model="test-model", - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content="ok"), - usage=LLMUsage.empty_usage(), - ) - - assert ( - _normalize_non_stream_runtime_result(model="test-model", prompt_messages=prompt_messages, result=llm_result) - == llm_result - ) - - -def test__normalize_non_stream_runtime_result__empty_iterator_defaults(): - prompt_messages = [UserPromptMessage(content="hi")] - - result = _normalize_non_stream_runtime_result(model="test-model", prompt_messages=prompt_messages, result=iter([])) - - assert result.model == "test-model" - assert result.prompt_messages == prompt_messages - assert result.message.content == [] - assert result.message.tool_calls == [] - assert result.usage == LLMUsage.empty_usage() - assert result.system_fingerprint is None - - -def test__normalize_non_stream_runtime_result__accumulates_all_chunks(): - """All chunks are accumulated from the iterator.""" - prompt_messages = [UserPromptMessage(content="hi")] - - closed: list[bool] = [] - - def _chunk_iter(): - try: - yield _make_chunk(content="hello", usage=LLMUsage.empty_usage()) - yield _make_chunk(content=" world", usage=LLMUsage.empty_usage()) - finally: - closed.append(True) - - result = _normalize_non_stream_runtime_result( - model="test-model", - prompt_messages=prompt_messages, - result=_chunk_iter(), - ) - - assert result.message.content == "hello world" - assert closed == [True] diff --git a/api/tests/unit_tests/graphon/model_runtime/__init__.py b/api/tests/unit_tests/graphon/model_runtime/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/tests/unit_tests/graphon/model_runtime/callbacks/test_base_callback.py b/api/tests/unit_tests/graphon/model_runtime/callbacks/test_base_callback.py deleted file mode 100644 index 776fc230cb..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/callbacks/test_base_callback.py +++ /dev/null @@ -1,964 +0,0 @@ -"""Comprehensive unit tests for core/model_runtime/callbacks/base_callback.py""" - -from unittest.mock import MagicMock, patch - -import pytest - -from graphon.model_runtime.callbacks.base_callback import ( - _TEXT_COLOR_MAPPING, - Callback, -) -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool - -# --------------------------------------------------------------------------- -# Concrete implementation of the abstract Callback for testing -# --------------------------------------------------------------------------- - - -class ConcreteCallback(Callback): - """A minimal concrete subclass that satisfies all abstract methods.""" - - def __init__(self, raise_error: bool = False): - self.raise_error = raise_error - # Track invocations - self.before_invoke_calls: list[dict] = [] - self.new_chunk_calls: list[dict] = [] - self.after_invoke_calls: list[dict] = [] - self.invoke_error_calls: list[dict] = [] - - def on_before_invoke( - self, - llm_instance, - model, - credentials, - prompt_messages, - model_parameters, - tools=None, - stop=None, - stream=True, - user=None, - ): - self.before_invoke_calls.append( - { - "llm_instance": llm_instance, - "model": model, - "credentials": credentials, - "prompt_messages": prompt_messages, - "model_parameters": model_parameters, - "tools": tools, - "stop": stop, - "stream": stream, - "user": user, - } - ) - # To cover the 'raise NotImplementedError()' in the base class - try: - super().on_before_invoke( - llm_instance, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user - ) - except NotImplementedError: - pass - - def on_new_chunk( - self, - llm_instance, - chunk, - model, - credentials, - prompt_messages, - model_parameters, - tools=None, - stop=None, - stream=True, - user=None, - ): - self.new_chunk_calls.append( - { - "llm_instance": llm_instance, - "chunk": chunk, - "model": model, - "credentials": credentials, - "prompt_messages": prompt_messages, - "model_parameters": model_parameters, - "tools": tools, - "stop": stop, - "stream": stream, - "user": user, - } - ) - try: - super().on_new_chunk( - llm_instance, chunk, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user - ) - except NotImplementedError: - pass - - def on_after_invoke( - self, - llm_instance, - result, - model, - credentials, - prompt_messages, - model_parameters, - tools=None, - stop=None, - stream=True, - user=None, - ): - self.after_invoke_calls.append( - { - "llm_instance": llm_instance, - "result": result, - "model": model, - "credentials": credentials, - "prompt_messages": prompt_messages, - "model_parameters": model_parameters, - "tools": tools, - "stop": stop, - "stream": stream, - "user": user, - } - ) - try: - super().on_after_invoke( - llm_instance, result, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user - ) - except NotImplementedError: - pass - - def on_invoke_error( - self, - llm_instance, - ex, - model, - credentials, - prompt_messages, - model_parameters, - tools=None, - stop=None, - stream=True, - user=None, - ): - self.invoke_error_calls.append( - { - "llm_instance": llm_instance, - "ex": ex, - "model": model, - "credentials": credentials, - "prompt_messages": prompt_messages, - "model_parameters": model_parameters, - "tools": tools, - "stop": stop, - "stream": stream, - "user": user, - } - ) - try: - super().on_invoke_error( - llm_instance, ex, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user - ) - except NotImplementedError: - pass - - -# --------------------------------------------------------------------------- -# A subclass that deliberately leaves abstract methods un-implemented, -# used to verify that instantiation raises TypeError. -# --------------------------------------------------------------------------- - - -# =========================================================================== -# Tests for _TEXT_COLOR_MAPPING module-level constant -# =========================================================================== - - -class TestTextColorMapping: - """Tests for the module-level _TEXT_COLOR_MAPPING dictionary.""" - - def test_contains_all_expected_colors(self): - expected_keys = {"blue", "yellow", "pink", "green", "red"} - assert set(_TEXT_COLOR_MAPPING.keys()) == expected_keys - - def test_blue_escape_code(self): - assert _TEXT_COLOR_MAPPING["blue"] == "36;1" - - def test_yellow_escape_code(self): - assert _TEXT_COLOR_MAPPING["yellow"] == "33;1" - - def test_pink_escape_code(self): - assert _TEXT_COLOR_MAPPING["pink"] == "38;5;200" - - def test_green_escape_code(self): - assert _TEXT_COLOR_MAPPING["green"] == "32;1" - - def test_red_escape_code(self): - assert _TEXT_COLOR_MAPPING["red"] == "31;1" - - def test_mapping_is_dict(self): - assert isinstance(_TEXT_COLOR_MAPPING, dict) - - def test_all_values_are_strings(self): - for key, value in _TEXT_COLOR_MAPPING.items(): - assert isinstance(value, str), f"Value for {key!r} should be str" - - -# =========================================================================== -# Tests for the Callback ABC itself -# =========================================================================== - - -class TestCallbackAbstract: - """Tests verifying Callback is a proper ABC.""" - - def test_cannot_instantiate_abstract_class_directly(self): - """Callback cannot be instantiated since it has abstract methods.""" - with pytest.raises(TypeError): - Callback() # type: ignore[abstract] - - def test_concrete_subclass_can_be_instantiated(self): - cb = ConcreteCallback() - assert isinstance(cb, Callback) - - def test_default_raise_error_is_false(self): - cb = ConcreteCallback() - assert cb.raise_error is False - - def test_raise_error_can_be_set_to_true(self): - cb = ConcreteCallback(raise_error=True) - assert cb.raise_error is True - - def test_subclass_missing_on_before_invoke_raises_type_error(self): - """A subclass missing any single abstract method cannot be instantiated.""" - - class IncompleteCallback(Callback): - def on_new_chunk(self, *a, **kw): ... - def on_after_invoke(self, *a, **kw): ... - def on_invoke_error(self, *a, **kw): ... - - with pytest.raises(TypeError): - IncompleteCallback() # type: ignore[abstract] - - def test_subclass_missing_on_new_chunk_raises_type_error(self): - class IncompleteCallback(Callback): - def on_before_invoke(self, *a, **kw): ... - def on_after_invoke(self, *a, **kw): ... - def on_invoke_error(self, *a, **kw): ... - - with pytest.raises(TypeError): - IncompleteCallback() # type: ignore[abstract] - - def test_subclass_missing_on_after_invoke_raises_type_error(self): - class IncompleteCallback(Callback): - def on_before_invoke(self, *a, **kw): ... - def on_new_chunk(self, *a, **kw): ... - def on_invoke_error(self, *a, **kw): ... - - with pytest.raises(TypeError): - IncompleteCallback() # type: ignore[abstract] - - def test_subclass_missing_on_invoke_error_raises_type_error(self): - class IncompleteCallback(Callback): - def on_before_invoke(self, *a, **kw): ... - def on_new_chunk(self, *a, **kw): ... - def on_after_invoke(self, *a, **kw): ... - - with pytest.raises(TypeError): - IncompleteCallback() # type: ignore[abstract] - - -# =========================================================================== -# Tests for on_before_invoke -# =========================================================================== - - -class TestOnBeforeInvoke: - """Tests for the on_before_invoke callback method.""" - - def setup_method(self): - self.cb = ConcreteCallback() - self.llm_instance = MagicMock() - self.model = "gpt-4" - self.credentials = {"api_key": "sk-test"} - self.prompt_messages = [MagicMock(spec=PromptMessage)] - self.model_parameters = {"temperature": 0.7} - - def test_on_before_invoke_called_with_required_args(self): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.before_invoke_calls) == 1 - call = self.cb.before_invoke_calls[0] - assert call["llm_instance"] is self.llm_instance - assert call["model"] == self.model - assert call["credentials"] == self.credentials - assert call["prompt_messages"] is self.prompt_messages - assert call["model_parameters"] is self.model_parameters - - def test_on_before_invoke_defaults_tools_none(self): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.before_invoke_calls[0]["tools"] is None - - def test_on_before_invoke_defaults_stop_none(self): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.before_invoke_calls[0]["stop"] is None - - def test_on_before_invoke_defaults_stream_true(self): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.before_invoke_calls[0]["stream"] is True - - def test_on_before_invoke_defaults_user_none(self): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.before_invoke_calls[0]["user"] is None - - def test_on_before_invoke_with_all_optional_args(self): - tools = [MagicMock(spec=PromptMessageTool)] - stop = ["stop1", "stop2"] - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - tools=tools, - stop=stop, - stream=False, - user="user-123", - ) - call = self.cb.before_invoke_calls[0] - assert call["tools"] is tools - assert call["stop"] == stop - assert call["stream"] is False - assert call["user"] == "user-123" - - def test_on_before_invoke_called_multiple_times(self): - for i in range(3): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=f"model-{i}", - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.before_invoke_calls) == 3 - assert self.cb.before_invoke_calls[2]["model"] == "model-2" - - -# =========================================================================== -# Tests for on_new_chunk -# =========================================================================== - - -class TestOnNewChunk: - """Tests for the on_new_chunk callback method.""" - - def setup_method(self): - self.cb = ConcreteCallback() - self.llm_instance = MagicMock() - self.chunk = MagicMock(spec=LLMResultChunk) - self.model = "gpt-3.5-turbo" - self.credentials = {"api_key": "sk-test"} - self.prompt_messages = [MagicMock(spec=PromptMessage)] - self.model_parameters = {"max_tokens": 256} - - def test_on_new_chunk_called_with_required_args(self): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.new_chunk_calls) == 1 - call = self.cb.new_chunk_calls[0] - assert call["llm_instance"] is self.llm_instance - assert call["chunk"] is self.chunk - assert call["model"] == self.model - assert call["credentials"] == self.credentials - - def test_on_new_chunk_defaults_tools_none(self): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.new_chunk_calls[0]["tools"] is None - - def test_on_new_chunk_defaults_stop_none(self): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.new_chunk_calls[0]["stop"] is None - - def test_on_new_chunk_defaults_stream_true(self): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.new_chunk_calls[0]["stream"] is True - - def test_on_new_chunk_defaults_user_none(self): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.new_chunk_calls[0]["user"] is None - - def test_on_new_chunk_with_all_optional_args(self): - tools = [MagicMock(spec=PromptMessageTool)] - stop = ["END"] - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - tools=tools, - stop=stop, - stream=False, - user="chunk-user", - ) - call = self.cb.new_chunk_calls[0] - assert call["tools"] is tools - assert call["stop"] == stop - assert call["stream"] is False - assert call["user"] == "chunk-user" - - def test_on_new_chunk_called_multiple_times(self): - for i in range(5): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.new_chunk_calls) == 5 - - -# =========================================================================== -# Tests for on_after_invoke -# =========================================================================== - - -class TestOnAfterInvoke: - """Tests for the on_after_invoke callback method.""" - - def setup_method(self): - self.cb = ConcreteCallback() - self.llm_instance = MagicMock() - self.result = MagicMock(spec=LLMResult) - self.model = "claude-3" - self.credentials = {"api_key": "anthropic-key"} - self.prompt_messages = [MagicMock(spec=PromptMessage)] - self.model_parameters = {"temperature": 1.0} - - def test_on_after_invoke_called_with_required_args(self): - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.after_invoke_calls) == 1 - call = self.cb.after_invoke_calls[0] - assert call["llm_instance"] is self.llm_instance - assert call["result"] is self.result - assert call["model"] == self.model - assert call["credentials"] is self.credentials - - def test_on_after_invoke_defaults_tools_none(self): - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.after_invoke_calls[0]["tools"] is None - - def test_on_after_invoke_defaults_stop_none(self): - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.after_invoke_calls[0]["stop"] is None - - def test_on_after_invoke_defaults_stream_true(self): - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.after_invoke_calls[0]["stream"] is True - - def test_on_after_invoke_defaults_user_none(self): - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.after_invoke_calls[0]["user"] is None - - def test_on_after_invoke_with_all_optional_args(self): - tools = [MagicMock(spec=PromptMessageTool)] - stop = ["STOP"] - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - tools=tools, - stop=stop, - stream=False, - user="after-user", - ) - call = self.cb.after_invoke_calls[0] - assert call["tools"] is tools - assert call["stop"] == stop - assert call["stream"] is False - assert call["user"] == "after-user" - - -# =========================================================================== -# Tests for on_invoke_error -# =========================================================================== - - -class TestOnInvokeError: - """Tests for the on_invoke_error callback method.""" - - def setup_method(self): - self.cb = ConcreteCallback() - self.llm_instance = MagicMock() - self.ex = ValueError("something went wrong") - self.model = "gemini-pro" - self.credentials = {"api_key": "google-key"} - self.prompt_messages = [MagicMock(spec=PromptMessage)] - self.model_parameters = {"top_p": 0.9} - - def test_on_invoke_error_called_with_required_args(self): - self.cb.on_invoke_error( - llm_instance=self.llm_instance, - ex=self.ex, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.invoke_error_calls) == 1 - call = self.cb.invoke_error_calls[0] - assert call["llm_instance"] is self.llm_instance - assert call["ex"] is self.ex - assert call["model"] == self.model - assert call["credentials"] is self.credentials - - def test_on_invoke_error_defaults_tools_none(self): - self.cb.on_invoke_error( - llm_instance=self.llm_instance, - ex=self.ex, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.invoke_error_calls[0]["tools"] is None - - def test_on_invoke_error_defaults_stop_none(self): - self.cb.on_invoke_error( - llm_instance=self.llm_instance, - ex=self.ex, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.invoke_error_calls[0]["stop"] is None - - def test_on_invoke_error_defaults_stream_true(self): - self.cb.on_invoke_error( - llm_instance=self.llm_instance, - ex=self.ex, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.invoke_error_calls[0]["stream"] is True - - def test_on_invoke_error_defaults_user_none(self): - self.cb.on_invoke_error( - llm_instance=self.llm_instance, - ex=self.ex, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.invoke_error_calls[0]["user"] is None - - def test_on_invoke_error_with_all_optional_args(self): - tools = [MagicMock(spec=PromptMessageTool)] - stop = ["HALT"] - self.cb.on_invoke_error( - llm_instance=self.llm_instance, - ex=self.ex, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - tools=tools, - stop=stop, - stream=False, - user="error-user", - ) - call = self.cb.invoke_error_calls[0] - assert call["tools"] is tools - assert call["stop"] == stop - assert call["stream"] is False - assert call["user"] == "error-user" - - def test_on_invoke_error_accepts_various_exception_types(self): - for exc in [RuntimeError("r"), KeyError("k"), Exception("e")]: - self.cb.on_invoke_error( - llm_instance=self.llm_instance, - ex=exc, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.invoke_error_calls) == 3 - - -# =========================================================================== -# Tests for print_text (concrete method on Callback) -# =========================================================================== - - -class TestPrintText: - """Tests for the concrete print_text method.""" - - def setup_method(self): - self.cb = ConcreteCallback() - - def test_print_text_without_color_prints_plain_text(self, capsys): - self.cb.print_text("hello world") - captured = capsys.readouterr() - assert captured.out == "hello world" - - def test_print_text_with_color_prints_colored_text(self, capsys): - self.cb.print_text("colored text", color="blue") - captured = capsys.readouterr() - # Should contain ANSI escape sequences - assert "colored text" in captured.out - assert "\001b[" in captured.out or "\033[" in captured.out or "\x1b[" in captured.out - - def test_print_text_without_color_no_ansi(self, capsys): - self.cb.print_text("plain text", color=None) - captured = capsys.readouterr() - assert captured.out == "plain text" - # No ANSI escape sequences - assert "\x1b" not in captured.out - - def test_print_text_default_end_is_empty_string(self, capsys): - self.cb.print_text("no newline") - captured = capsys.readouterr() - assert not captured.out.endswith("\n") - - def test_print_text_with_custom_end(self, capsys): - self.cb.print_text("with newline", end="\n") - captured = capsys.readouterr() - assert captured.out.endswith("\n") - - def test_print_text_with_empty_string(self, capsys): - self.cb.print_text("", color=None) - captured = capsys.readouterr() - assert captured.out == "" - - @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) - def test_print_text_all_colors_work(self, color, capsys): - """Verify no KeyError is thrown for any valid color.""" - self.cb.print_text("test", color=color) - captured = capsys.readouterr() - assert "test" in captured.out - - def test_print_text_calls_get_colored_text_when_color_given(self): - with patch.object(self.cb, "_get_colored_text", return_value="[COLORED]") as mock_gct: - with patch("builtins.print") as mock_print: - self.cb.print_text("hello", color="green") - mock_gct.assert_called_once_with("hello", "green") - mock_print.assert_called_once_with("[COLORED]", end="") - - def test_print_text_does_not_call_get_colored_text_when_no_color(self): - with patch.object(self.cb, "_get_colored_text") as mock_gct: - with patch("builtins.print"): - self.cb.print_text("hello", color=None) - mock_gct.assert_not_called() - - def test_print_text_passes_end_to_print(self): - with patch("builtins.print") as mock_print: - self.cb.print_text("text", end="---") - mock_print.assert_called_once_with("text", end="---") - - -# =========================================================================== -# Tests for _get_colored_text (private helper method) -# =========================================================================== - - -class TestGetColoredText: - """Tests for the _get_colored_text private method.""" - - def setup_method(self): - self.cb = ConcreteCallback() - - @pytest.mark.parametrize(("color", "expected_code"), list(_TEXT_COLOR_MAPPING.items())) - def test_get_colored_text_uses_correct_escape_code(self, color, expected_code): - result = self.cb._get_colored_text("text", color) - assert expected_code in result - - @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) - def test_get_colored_text_contains_input_text(self, color): - result = self.cb._get_colored_text("hello", color) - assert "hello" in result - - @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) - def test_get_colored_text_starts_with_escape(self, color): - result = self.cb._get_colored_text("text", color) - # Should start with an ANSI escape (\x1b or \u001b) - assert result.startswith("\x1b[") or result.startswith("\u001b[") - - @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) - def test_get_colored_text_ends_with_reset(self, color): - result = self.cb._get_colored_text("text", color) - # Should end with the ANSI reset code - assert result.endswith("\x1b[0m") or result.endswith("\u001b[0m") - - def test_get_colored_text_returns_string(self): - result = self.cb._get_colored_text("text", "blue") - assert isinstance(result, str) - - def test_get_colored_text_blue_exact_format(self): - result = self.cb._get_colored_text("hello", "blue") - expected = f"\u001b[{_TEXT_COLOR_MAPPING['blue']}m\033[1;3mhello\u001b[0m" - assert result == expected - - def test_get_colored_text_red_exact_format(self): - result = self.cb._get_colored_text("error", "red") - expected = f"\u001b[{_TEXT_COLOR_MAPPING['red']}m\033[1;3merror\u001b[0m" - assert result == expected - - def test_get_colored_text_green_exact_format(self): - result = self.cb._get_colored_text("ok", "green") - expected = f"\u001b[{_TEXT_COLOR_MAPPING['green']}m\033[1;3mok\u001b[0m" - assert result == expected - - def test_get_colored_text_yellow_exact_format(self): - result = self.cb._get_colored_text("warn", "yellow") - expected = f"\u001b[{_TEXT_COLOR_MAPPING['yellow']}m\033[1;3mwarn\u001b[0m" - assert result == expected - - def test_get_colored_text_pink_exact_format(self): - result = self.cb._get_colored_text("info", "pink") - expected = f"\u001b[{_TEXT_COLOR_MAPPING['pink']}m\033[1;3minfo\u001b[0m" - assert result == expected - - def test_get_colored_text_empty_string(self): - result = self.cb._get_colored_text("", "blue") - assert isinstance(result, str) - # Empty text should still have escape codes - assert _TEXT_COLOR_MAPPING["blue"] in result - - def test_get_colored_text_invalid_color_raises_key_error(self): - with pytest.raises(KeyError): - self.cb._get_colored_text("text", "purple") - - def test_get_colored_text_with_special_characters(self): - special = "hello\nworld\ttab" - result = self.cb._get_colored_text(special, "blue") - assert special in result - - def test_get_colored_text_with_long_text(self): - long_text = "a" * 10000 - result = self.cb._get_colored_text(long_text, "green") - assert long_text in result - - -# =========================================================================== -# Integration-style tests: full workflow through a ConcreteCallback -# =========================================================================== - - -class TestConcreteCallbackIntegration: - """End-to-end workflow tests using ConcreteCallback.""" - - def test_full_invocation_lifecycle(self): - """Simulate a complete LLM invocation lifecycle through all callbacks.""" - cb = ConcreteCallback() - llm_instance = MagicMock() - model = "gpt-4o" - credentials = {"api_key": "sk-xyz"} - prompt_messages = [MagicMock(spec=PromptMessage)] - model_parameters = {"temperature": 0.5} - tools = [MagicMock(spec=PromptMessageTool)] - stop = [""] - user = "user-abc" - - # 1. Before invoke - cb.on_before_invoke( - llm_instance=llm_instance, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=True, - user=user, - ) - - # 2. Multiple chunks during streaming - for i in range(3): - chunk = MagicMock(spec=LLMResultChunk) - cb.on_new_chunk( - llm_instance=llm_instance, - chunk=chunk, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=True, - user=user, - ) - - # 3. After invoke - result = MagicMock(spec=LLMResult) - cb.on_after_invoke( - llm_instance=llm_instance, - result=result, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=True, - user=user, - ) - - assert len(cb.before_invoke_calls) == 1 - assert len(cb.new_chunk_calls) == 3 - assert len(cb.after_invoke_calls) == 1 - assert len(cb.invoke_error_calls) == 0 - - def test_error_lifecycle(self): - """Simulate an invoke that results in an error.""" - cb = ConcreteCallback() - llm_instance = MagicMock() - model = "gpt-4" - credentials = {} - prompt_messages = [] - model_parameters = {} - - cb.on_before_invoke( - llm_instance=llm_instance, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - ) - - ex = RuntimeError("API timeout") - cb.on_invoke_error( - llm_instance=llm_instance, - ex=ex, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - ) - - assert len(cb.before_invoke_calls) == 1 - assert len(cb.invoke_error_calls) == 1 - assert cb.invoke_error_calls[0]["ex"] is ex - assert len(cb.after_invoke_calls) == 0 - - def test_print_text_with_color_in_integration(self, capsys): - """verify print_text works correctly in a concrete instance.""" - cb = ConcreteCallback() - cb.print_text("SUCCESS", color="green", end="\n") - captured = capsys.readouterr() - assert "SUCCESS" in captured.out - assert "\n" in captured.out - - def test_print_text_no_color_in_integration(self, capsys): - cb = ConcreteCallback() - cb.print_text("plain output") - captured = capsys.readouterr() - assert captured.out == "plain output" diff --git a/api/tests/unit_tests/graphon/model_runtime/callbacks/test_logging_callback.py b/api/tests/unit_tests/graphon/model_runtime/callbacks/test_logging_callback.py deleted file mode 100644 index df9215826c..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/callbacks/test_logging_callback.py +++ /dev/null @@ -1,700 +0,0 @@ -""" -Comprehensive unit tests for core/model_runtime/callbacks/logging_callback.py - -Coverage targets: - - LoggingCallback.on_before_invoke (all branches: stop, tools, user, stream, - prompt_message.name, model_parameters) - - LoggingCallback.on_new_chunk (writes to stdout) - - LoggingCallback.on_after_invoke (all branches: tool_calls present / absent) - - LoggingCallback.on_invoke_error (logs exception via logger.exception) -""" - -from __future__ import annotations - -import json -from collections.abc import Sequence -from decimal import Decimal -from unittest.mock import MagicMock, patch - -import pytest - -from graphon.model_runtime.callbacks.logging_callback import LoggingCallback -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMUsage, -) -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessageTool, - SystemPromptMessage, - UserPromptMessage, -) - -# --------------------------------------------------------------------------- -# Shared helpers -# --------------------------------------------------------------------------- - - -def _make_usage() -> LLMUsage: - """Return a minimal LLMUsage instance.""" - return LLMUsage( - prompt_tokens=10, - prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal("0.001"), - prompt_price=Decimal("0.01"), - completion_tokens=20, - completion_unit_price=Decimal("0.002"), - completion_price_unit=Decimal("0.002"), - completion_price=Decimal("0.04"), - total_tokens=30, - total_price=Decimal("0.05"), - currency="USD", - latency=0.5, - ) - - -def _make_llm_result( - content: str = "hello world", - tool_calls: list | None = None, - model: str = "gpt-4", - system_fingerprint: str | None = "fp-abc", -) -> LLMResult: - """Return an LLMResult with an AssistantPromptMessage.""" - assistant_msg = AssistantPromptMessage( - content=content, - tool_calls=tool_calls or [], - ) - return LLMResult( - model=model, - message=assistant_msg, - usage=_make_usage(), - system_fingerprint=system_fingerprint, - ) - - -def _make_chunk(content: str = "chunk-text") -> LLMResultChunk: - """Return a minimal LLMResultChunk.""" - return LLMResultChunk( - model="gpt-4", - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=content), - ), - ) - - -def _make_user_prompt(content: str = "Hello!", name: str | None = None) -> UserPromptMessage: - return UserPromptMessage(content=content, name=name) - - -def _make_system_prompt(content: str = "You are helpful.") -> SystemPromptMessage: - return SystemPromptMessage(content=content) - - -def _make_tool(name: str = "my_tool") -> PromptMessageTool: - return PromptMessageTool(name=name, description="A tool", parameters={}) - - -def _make_tool_call( - call_id: str = "call-1", - func_name: str = "some_func", - arguments: str = '{"key": "value"}', -) -> AssistantPromptMessage.ToolCall: - return AssistantPromptMessage.ToolCall( - id=call_id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=func_name, arguments=arguments), - ) - - -# --------------------------------------------------------------------------- -# Fixture: shared LoggingCallback instance (no heavy state) -# --------------------------------------------------------------------------- - - -@pytest.fixture -def cb() -> LoggingCallback: - return LoggingCallback() - - -@pytest.fixture -def llm_instance() -> MagicMock: - return MagicMock() - - -# =========================================================================== -# Tests for on_before_invoke -# =========================================================================== - - -class TestOnBeforeInvoke: - """Tests for LoggingCallback.on_before_invoke.""" - - def _invoke( - self, - cb: LoggingCallback, - llm_instance: MagicMock, - *, - model: str = "gpt-4", - credentials: dict | None = None, - prompt_messages: list | None = None, - model_parameters: dict | None = None, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - cb.on_before_invoke( - llm_instance=llm_instance, - model=model, - credentials=credentials or {}, - prompt_messages=prompt_messages or [], - model_parameters=model_parameters or {}, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - - def test_minimal_call_does_not_raise(self, cb: LoggingCallback, llm_instance: MagicMock): - """Calling with bare-minimum args should not raise.""" - self._invoke(cb, llm_instance) - - def test_model_name_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """The model name must appear in print_text calls.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, model="claude-3") - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "claude-3" in calls_text - - def test_model_parameters_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """Each key-value pair of model_parameters must be printed.""" - params = {"temperature": 0.7, "max_tokens": 512} - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, model_parameters=params) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "temperature" in calls_text - assert "0.7" in calls_text - assert "max_tokens" in calls_text - assert "512" in calls_text - - def test_empty_model_parameters(self, cb: LoggingCallback, llm_instance: MagicMock): - """Empty model_parameters dict should not raise.""" - self._invoke(cb, llm_instance, model_parameters={}) - - # ------------------------------------------------------------------ - # stop branch - # ------------------------------------------------------------------ - - def test_stop_branch_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): - """stop words must appear in output when provided.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, stop=["STOP", "END"]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "stop" in calls_text - - def test_stop_branch_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): - """When stop=None the stop line must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, stop=None) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "\tstop:" not in calls_text - - def test_stop_branch_skipped_when_empty_list(self, cb: LoggingCallback, llm_instance: MagicMock): - """When stop=[] (falsy) the stop line must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, stop=[]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "\tstop:" not in calls_text - - # ------------------------------------------------------------------ - # tools branch - # ------------------------------------------------------------------ - - def test_tools_branch_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): - """Tool names must appear in output when tools are provided.""" - tools = [_make_tool("search"), _make_tool("calculate")] - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, tools=tools) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "search" in calls_text - assert "calculate" in calls_text - - def test_tools_branch_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): - """When tools=None the Tools section must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, tools=None) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "Tools:" not in calls_text - - def test_tools_branch_skipped_when_empty_list(self, cb: LoggingCallback, llm_instance: MagicMock): - """When tools=[] (falsy) the Tools section must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, tools=[]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "Tools:" not in calls_text - - # ------------------------------------------------------------------ - # user branch - # ------------------------------------------------------------------ - - def test_user_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): - """User string must appear in output when provided.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, user="alice") - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "alice" in calls_text - - def test_user_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): - """When user=None the User line must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, user=None) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "User:" not in calls_text - - # ------------------------------------------------------------------ - # stream branch - # ------------------------------------------------------------------ - - def test_stream_true_prints_new_chunk_header(self, cb: LoggingCallback, llm_instance: MagicMock): - """When stream=True the [on_llm_new_chunk] marker must be printed.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, stream=True) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "[on_llm_new_chunk]" in calls_text - - def test_stream_false_no_new_chunk_header(self, cb: LoggingCallback, llm_instance: MagicMock): - """When stream=False the [on_llm_new_chunk] marker must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, stream=False) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "[on_llm_new_chunk]" not in calls_text - - # ------------------------------------------------------------------ - # prompt_messages branch - # ------------------------------------------------------------------ - - def test_prompt_message_with_name_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """When a PromptMessage has a name it must be printed.""" - msg = _make_user_prompt("hi", name="bob") - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, prompt_messages=[msg]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "bob" in calls_text - - def test_prompt_message_without_name_skips_name_line(self, cb: LoggingCallback, llm_instance: MagicMock): - """When a PromptMessage has no name the name line must NOT appear.""" - msg = _make_user_prompt("hi", name=None) - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, prompt_messages=[msg]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "\tname:" not in calls_text - - def test_prompt_message_role_and_content_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """Role and content of each PromptMessage must appear in output.""" - msg = _make_system_prompt("Be concise.") - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, prompt_messages=[msg]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "system" in calls_text - assert "Be concise." in calls_text - - def test_multiple_prompt_messages_all_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """All entries in prompt_messages are iterated and printed.""" - msgs = [ - _make_system_prompt("sys"), - _make_user_prompt("user msg"), - ] - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, prompt_messages=msgs) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "sys" in calls_text - assert "user msg" in calls_text - - # ------------------------------------------------------------------ - # Combination: everything provided - # ------------------------------------------------------------------ - - def test_all_optional_fields_combined(self, cb: LoggingCallback, llm_instance: MagicMock): - """Supply stop, tools, user, multiple params, named message – no exception.""" - msgs = [_make_user_prompt("question", name="alice")] - tools = [_make_tool("tool_a")] - with patch.object(cb, "print_text"): - self._invoke( - cb, - llm_instance, - model="gpt-3.5", - model_parameters={"temperature": 1.0, "top_p": 0.9}, - tools=tools, - stop=["DONE"], - stream=True, - user="alice", - prompt_messages=msgs, - ) - - -# =========================================================================== -# Tests for on_new_chunk -# =========================================================================== - - -class TestOnNewChunk: - """Tests for LoggingCallback.on_new_chunk.""" - - def test_chunk_content_written_to_stdout(self, cb: LoggingCallback, llm_instance: MagicMock): - """on_new_chunk must write the chunk's text content to sys.stdout.""" - chunk = _make_chunk("hello from LLM") - written = [] - - with patch("sys.stdout") as mock_stdout: - mock_stdout.write.side_effect = written.append - cb.on_new_chunk( - llm_instance=llm_instance, - chunk=chunk, - model="gpt-4", - credentials={}, - prompt_messages=[], - model_parameters={}, - ) - mock_stdout.write.assert_called_once_with("hello from LLM") - mock_stdout.flush.assert_called_once() - - def test_chunk_content_empty_string(self, cb: LoggingCallback, llm_instance: MagicMock): - """Works correctly even when the chunk content is an empty string.""" - chunk = _make_chunk("") - with patch("sys.stdout") as mock_stdout: - cb.on_new_chunk( - llm_instance=llm_instance, - chunk=chunk, - model="gpt-4", - credentials={}, - prompt_messages=[], - model_parameters={}, - ) - mock_stdout.write.assert_called_once_with("") - mock_stdout.flush.assert_called_once() - - def test_chunk_passes_all_optional_params(self, cb: LoggingCallback, llm_instance: MagicMock): - """All optional parameters are accepted without errors.""" - chunk = _make_chunk("data") - with patch("sys.stdout"): - cb.on_new_chunk( - llm_instance=llm_instance, - chunk=chunk, - model="gpt-4", - credentials={"key": "secret"}, - prompt_messages=[_make_user_prompt("q")], - model_parameters={"temperature": 0.5}, - tools=[_make_tool("t1")], - stop=["EOS"], - stream=True, - user="bob", - ) - - -# =========================================================================== -# Tests for on_after_invoke -# =========================================================================== - - -class TestOnAfterInvoke: - """Tests for LoggingCallback.on_after_invoke.""" - - def _invoke( - self, - cb: LoggingCallback, - llm_instance: MagicMock, - result: LLMResult, - **kwargs, - ): - cb.on_after_invoke( - llm_instance=llm_instance, - result=result, - model=result.model, - credentials={}, - prompt_messages=[], - model_parameters={}, - **kwargs, - ) - - def test_basic_result_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """After-invoke header, content, model, usage, fingerprint must be printed.""" - result = _make_llm_result() - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, result) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "[on_llm_after_invoke]" in calls_text - assert "hello world" in calls_text - assert "gpt-4" in calls_text - assert "fp-abc" in calls_text - - def test_no_tool_calls_skips_tool_call_block(self, cb: LoggingCallback, llm_instance: MagicMock): - """When there are no tool_calls the 'Tool calls:' block must NOT appear.""" - result = _make_llm_result(tool_calls=[]) - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, result) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "Tool calls:" not in calls_text - - def test_with_tool_calls_prints_all_fields(self, cb: LoggingCallback, llm_instance: MagicMock): - """When tool_calls exist their id, name, and JSON arguments must be printed.""" - tc = _make_tool_call( - call_id="call-xyz", - func_name="fetch_data", - arguments='{"url": "https://example.com"}', - ) - result = _make_llm_result(tool_calls=[tc]) - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, result) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "Tool calls:" in calls_text - assert "call-xyz" in calls_text - assert "fetch_data" in calls_text - # arguments should be JSON-dumped - assert "https://example.com" in calls_text - - def test_multiple_tool_calls_all_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """All tool calls in the list must be iterated.""" - tcs = [ - _make_tool_call("id-1", "func_a", '{"a": 1}'), - _make_tool_call("id-2", "func_b", '{"b": 2}'), - ] - result = _make_llm_result(tool_calls=tcs) - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, result) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "id-1" in calls_text - assert "func_a" in calls_text - assert "id-2" in calls_text - assert "func_b" in calls_text - - def test_system_fingerprint_none_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """When system_fingerprint is None it should still be printed (as None).""" - result = _make_llm_result(system_fingerprint=None) - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, result) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "System Fingerprint: None" in calls_text - - def test_usage_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """The usage object must appear in the printed output.""" - result = _make_llm_result() - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, result) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "Usage:" in calls_text - - def test_tool_call_arguments_are_json_dumped(self, cb: LoggingCallback, llm_instance: MagicMock): - """Verify json.dumps is applied to the arguments field (a string).""" - raw_args = '{"x": 42}' - tc = _make_tool_call(arguments=raw_args) - result = _make_llm_result(tool_calls=[tc]) - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, result) - - # Check if any call to print_text included the expected (json-encoded) arguments - # json.dumps(raw_args) produces a string starting and ending with quotes - expected_substring = json.dumps(raw_args) - found = any(expected_substring in str(call.args[0]) for call in mock_print.call_args_list) - assert found, f"Expected {expected_substring} to be printed in one of the calls" - - def test_optional_params_accepted(self, cb: LoggingCallback, llm_instance: MagicMock): - """All optional parameters should be accepted without error.""" - result = _make_llm_result() - cb.on_after_invoke( - llm_instance=llm_instance, - result=result, - model=result.model, - credentials={"key": "secret"}, - prompt_messages=[_make_user_prompt("q")], - model_parameters={"temperature": 0.9}, - tools=[_make_tool("t")], - stop=[""], - stream=False, - user="carol", - ) - - -# =========================================================================== -# Tests for on_invoke_error -# =========================================================================== - - -class TestOnInvokeError: - """Tests for LoggingCallback.on_invoke_error.""" - - def _invoke_error( - self, - cb: LoggingCallback, - llm_instance: MagicMock, - ex: Exception, - **kwargs, - ): - cb.on_invoke_error( - llm_instance=llm_instance, - ex=ex, - model="gpt-4", - credentials={}, - prompt_messages=[], - model_parameters={}, - **kwargs, - ) - - def test_prints_error_header(self, cb: LoggingCallback, llm_instance: MagicMock): - """The [on_llm_invoke_error] banner must be printed.""" - with patch.object(cb, "print_text") as mock_print: - with patch("graphon.model_runtime.callbacks.logging_callback.logger") as mock_logger: - self._invoke_error(cb, llm_instance, RuntimeError("boom")) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "[on_llm_invoke_error]" in calls_text - - def test_exception_logged_via_logger_exception(self, cb: LoggingCallback, llm_instance: MagicMock): - """logger.exception must be called with the exception.""" - ex = ValueError("something went wrong") - with patch.object(cb, "print_text"): - with patch("graphon.model_runtime.callbacks.logging_callback.logger") as mock_logger: - self._invoke_error(cb, llm_instance, ex) - mock_logger.exception.assert_called_once_with(ex) - - def test_exception_type_variety(self, cb: LoggingCallback, llm_instance: MagicMock): - """Works with any exception type (TypeError, IOError, etc.).""" - for exc_cls in (TypeError, IOError, KeyError, Exception): - ex = exc_cls("error") - with patch.object(cb, "print_text"): - with patch("graphon.model_runtime.callbacks.logging_callback.logger") as mock_logger: - self._invoke_error(cb, llm_instance, ex) - mock_logger.exception.assert_called_once_with(ex) - - def test_optional_params_accepted(self, cb: LoggingCallback, llm_instance: MagicMock): - """All optional parameters should be accepted without error.""" - ex = RuntimeError("fail") - with patch.object(cb, "print_text"): - with patch("graphon.model_runtime.callbacks.logging_callback.logger"): - cb.on_invoke_error( - llm_instance=llm_instance, - ex=ex, - model="gpt-4", - credentials={"key": "secret"}, - prompt_messages=[_make_user_prompt("q")], - model_parameters={"temperature": 0.7}, - tools=[_make_tool("t")], - stop=["STOP"], - stream=True, - user="dave", - ) - - -# =========================================================================== -# Tests for print_text (inherited from Callback, exercised through LoggingCallback) -# =========================================================================== - - -class TestPrintText: - """Verify that print_text from the Callback base class works correctly.""" - - def test_print_text_with_color(self, cb: LoggingCallback, capsys): - """print_text with a known colour should emit an ANSI escape sequence.""" - cb.print_text("hello", color="blue") - captured = capsys.readouterr() - assert "hello" in captured.out - # ANSI escape codes should be present - assert "\x1b[" in captured.out - - def test_print_text_without_color(self, cb: LoggingCallback, capsys): - """print_text without colour should print plain text.""" - cb.print_text("plain text") - captured = capsys.readouterr() - assert "plain text" in captured.out - - def test_print_text_all_colours(self, cb: LoggingCallback, capsys): - """Verify all supported colour keys don't raise.""" - for colour in ("blue", "yellow", "pink", "green", "red"): - cb.print_text("x", color=colour) - captured = capsys.readouterr() - # All outputs should contain 'x' (5 calls) - assert captured.out.count("x") >= 5 - - -# =========================================================================== -# Integration-style test: real print_text called (no mocking) -# =========================================================================== - - -class TestLoggingCallbackIntegration: - """Light integration tests – real print_text calls, just checking no exceptions.""" - - def test_on_before_invoke_full_run(self, capsys): - """Full on_before_invoke run with all optional fields – verifies real output.""" - cb = LoggingCallback() - llm = MagicMock() - msgs = [_make_user_prompt("Who are you?", name="tester")] - tools = [_make_tool("calculator")] - cb.on_before_invoke( - llm_instance=llm, - model="gpt-4-turbo", - credentials={"api_key": "sk-xxx"}, - prompt_messages=msgs, - model_parameters={"temperature": 0.8}, - tools=tools, - stop=["STOP"], - stream=True, - user="test_user", - ) - captured = capsys.readouterr() - assert "gpt-4-turbo" in captured.out - assert "calculator" in captured.out - assert "test_user" in captured.out - assert "STOP" in captured.out - assert "tester" in captured.out - - def test_on_new_chunk_full_run(self, capsys): - """Full on_new_chunk run – verifies real stdout write.""" - cb = LoggingCallback() - chunk = _make_chunk("streaming token") - cb.on_new_chunk( - llm_instance=MagicMock(), - chunk=chunk, - model="gpt-4", - credentials={}, - prompt_messages=[], - model_parameters={}, - ) - captured = capsys.readouterr() - assert "streaming token" in captured.out - - def test_on_after_invoke_full_run_with_tool_calls(self, capsys): - """Full on_after_invoke run with tool calls – verifies real output.""" - cb = LoggingCallback() - tc = _make_tool_call("call-99", "do_thing", '{"n": 5}') - result = _make_llm_result(content="result content", tool_calls=[tc], system_fingerprint="fp-xyz") - cb.on_after_invoke( - llm_instance=MagicMock(), - result=result, - model=result.model, - credentials={}, - prompt_messages=[], - model_parameters={}, - ) - captured = capsys.readouterr() - assert "result content" in captured.out - assert "call-99" in captured.out - assert "do_thing" in captured.out - assert "fp-xyz" in captured.out - - def test_on_invoke_error_full_run(self, capsys): - """Full on_invoke_error run – just verifies no exception is raised.""" - cb = LoggingCallback() - ex = RuntimeError("something bad happened") - # logger.exception writes to stderr; we just confirm it doesn't crash - cb.on_invoke_error( - llm_instance=MagicMock(), - ex=ex, - model="gpt-4", - credentials={}, - prompt_messages=[], - model_parameters={}, - ) - captured = capsys.readouterr() - assert "[on_llm_invoke_error]" in captured.out diff --git a/api/tests/unit_tests/graphon/model_runtime/entities/test_common_entities.py b/api/tests/unit_tests/graphon/model_runtime/entities/test_common_entities.py deleted file mode 100644 index 7d6255c37a..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/entities/test_common_entities.py +++ /dev/null @@ -1,35 +0,0 @@ -from graphon.model_runtime.entities.common_entities import I18nObject - - -class TestI18nObject: - def test_i18n_object_with_both_languages(self): - """ - Test I18nObject when both zh_Hans and en_US are provided. - """ - i18n = I18nObject(zh_Hans="你好", en_US="Hello") - assert i18n.zh_Hans == "你好" - assert i18n.en_US == "Hello" - - def test_i18n_object_fallback_to_en_us(self): - """ - Test I18nObject when zh_Hans is missing, it should fallback to en_US. - """ - i18n = I18nObject(en_US="Hello") - assert i18n.zh_Hans == "Hello" - assert i18n.en_US == "Hello" - - def test_i18n_object_with_none_zh_hans(self): - """ - Test I18nObject when zh_Hans is None, it should fallback to en_US. - """ - i18n = I18nObject(zh_Hans=None, en_US="Hello") - assert i18n.zh_Hans == "Hello" - assert i18n.en_US == "Hello" - - def test_i18n_object_with_empty_zh_hans(self): - """ - Test I18nObject when zh_Hans is an empty string, it should fallback to en_US. - """ - i18n = I18nObject(zh_Hans="", en_US="Hello") - assert i18n.zh_Hans == "Hello" - assert i18n.en_US == "Hello" diff --git a/api/tests/unit_tests/graphon/model_runtime/entities/test_llm_entities.py b/api/tests/unit_tests/graphon/model_runtime/entities/test_llm_entities.py deleted file mode 100644 index 51a6c38fa9..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/entities/test_llm_entities.py +++ /dev/null @@ -1,148 +0,0 @@ -"""Tests for LLMUsage entity.""" - -from decimal import Decimal - -from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata - - -class TestLLMUsage: - """Test cases for LLMUsage class.""" - - def test_from_metadata_with_all_tokens(self): - """Test from_metadata when all token types are provided.""" - metadata: LLMUsageMetadata = { - "prompt_tokens": 100, - "completion_tokens": 50, - "total_tokens": 150, - "prompt_unit_price": 0.001, - "completion_unit_price": 0.002, - "total_price": 0.2, - "currency": "USD", - "latency": 1.5, - } - - usage = LLMUsage.from_metadata(metadata) - - assert usage.prompt_tokens == 100 - assert usage.completion_tokens == 50 - assert usage.total_tokens == 150 - assert usage.prompt_unit_price == Decimal("0.001") - assert usage.completion_unit_price == Decimal("0.002") - assert usage.total_price == Decimal("0.2") - assert usage.currency == "USD" - assert usage.latency == 1.5 - - def test_from_metadata_with_prompt_tokens_only(self): - """Test from_metadata when only prompt_tokens is provided.""" - metadata: LLMUsageMetadata = { - "prompt_tokens": 100, - "total_tokens": 100, - } - - usage = LLMUsage.from_metadata(metadata) - - assert usage.prompt_tokens == 100 - assert usage.completion_tokens == 0 - assert usage.total_tokens == 100 - - def test_from_metadata_with_completion_tokens_only(self): - """Test from_metadata when only completion_tokens is provided.""" - metadata: LLMUsageMetadata = { - "completion_tokens": 50, - "total_tokens": 50, - } - - usage = LLMUsage.from_metadata(metadata) - - assert usage.prompt_tokens == 0 - assert usage.completion_tokens == 50 - assert usage.total_tokens == 50 - - def test_from_metadata_calculates_total_when_missing(self): - """Test from_metadata calculates total_tokens when not provided.""" - metadata: LLMUsageMetadata = { - "prompt_tokens": 100, - "completion_tokens": 50, - } - - usage = LLMUsage.from_metadata(metadata) - - assert usage.prompt_tokens == 100 - assert usage.completion_tokens == 50 - assert usage.total_tokens == 150 # Should be calculated - - def test_from_metadata_with_total_but_no_completion(self): - """ - Test from_metadata when total_tokens is provided but completion_tokens is 0. - This tests the fix for issue #24360 - prompt tokens should NOT be assigned to completion_tokens. - """ - metadata: LLMUsageMetadata = { - "prompt_tokens": 479, - "completion_tokens": 0, - "total_tokens": 521, - } - - usage = LLMUsage.from_metadata(metadata) - - # This is the key fix - prompt tokens should remain as prompt tokens - assert usage.prompt_tokens == 479 - assert usage.completion_tokens == 0 - assert usage.total_tokens == 521 - - def test_from_metadata_with_empty_metadata(self): - """Test from_metadata with empty metadata.""" - metadata: LLMUsageMetadata = {} - - usage = LLMUsage.from_metadata(metadata) - - assert usage.prompt_tokens == 0 - assert usage.completion_tokens == 0 - assert usage.total_tokens == 0 - assert usage.currency == "USD" - assert usage.latency == 0.0 - - def test_from_metadata_preserves_zero_completion_tokens(self): - """ - Test that zero completion_tokens are preserved when explicitly set. - This is important for agent nodes that only use prompt tokens. - """ - metadata: LLMUsageMetadata = { - "prompt_tokens": 1000, - "completion_tokens": 0, - "total_tokens": 1000, - "prompt_unit_price": 0.15, - "completion_unit_price": 0.60, - "prompt_price": 0.00015, - "completion_price": 0, - "total_price": 0.00015, - } - - usage = LLMUsage.from_metadata(metadata) - - assert usage.prompt_tokens == 1000 - assert usage.completion_tokens == 0 - assert usage.total_tokens == 1000 - assert usage.prompt_price == Decimal("0.00015") - assert usage.completion_price == Decimal(0) - assert usage.total_price == Decimal("0.00015") - - def test_from_metadata_with_decimal_values(self): - """Test from_metadata handles decimal values correctly.""" - metadata: LLMUsageMetadata = { - "prompt_tokens": 100, - "completion_tokens": 50, - "total_tokens": 150, - "prompt_unit_price": "0.001", - "completion_unit_price": "0.002", - "prompt_price": "0.1", - "completion_price": "0.1", - "total_price": "0.2", - } - - usage = LLMUsage.from_metadata(metadata) - - assert usage.prompt_unit_price == Decimal("0.001") - assert usage.completion_unit_price == Decimal("0.002") - assert usage.prompt_price == Decimal("0.1") - assert usage.completion_price == Decimal("0.1") - assert usage.total_price == Decimal("0.2") diff --git a/api/tests/unit_tests/graphon/model_runtime/entities/test_message_entities.py b/api/tests/unit_tests/graphon/model_runtime/entities/test_message_entities.py deleted file mode 100644 index 1918c324cc..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/entities/test_message_entities.py +++ /dev/null @@ -1,210 +0,0 @@ -import pytest - -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - AudioPromptMessageContent, - DocumentPromptMessageContent, - ImagePromptMessageContent, - PromptMessageContent, - PromptMessageContentType, - PromptMessageFunction, - PromptMessageRole, - PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, - VideoPromptMessageContent, -) - - -class TestPromptMessageRole: - def test_value_of(self): - assert PromptMessageRole.value_of("system") == PromptMessageRole.SYSTEM - assert PromptMessageRole.value_of("user") == PromptMessageRole.USER - assert PromptMessageRole.value_of("assistant") == PromptMessageRole.ASSISTANT - assert PromptMessageRole.value_of("tool") == PromptMessageRole.TOOL - - with pytest.raises(ValueError, match="invalid prompt message type value invalid"): - PromptMessageRole.value_of("invalid") - - -class TestPromptMessageEntities: - def test_prompt_message_tool(self): - tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"}) - assert tool.name == "test_tool" - assert tool.description == "test desc" - assert tool.parameters == {"foo": "bar"} - - def test_prompt_message_function(self): - tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"}) - func = PromptMessageFunction(function=tool) - assert func.type == "function" - assert func.function == tool - - -class TestPromptMessageContent: - def test_text_content(self): - content = TextPromptMessageContent(data="hello") - assert content.type == PromptMessageContentType.TEXT - assert content.data == "hello" - - def test_image_content(self): - content = ImagePromptMessageContent( - format="jpg", base64_data="abc", mime_type="image/jpeg", detail=ImagePromptMessageContent.DETAIL.HIGH - ) - assert content.type == PromptMessageContentType.IMAGE - assert content.detail == ImagePromptMessageContent.DETAIL.HIGH - assert content.data == "data:image/jpeg;base64,abc" - - def test_image_content_url(self): - content = ImagePromptMessageContent(format="jpg", url="https://example.com/image.jpg", mime_type="image/jpeg") - assert content.data == "https://example.com/image.jpg" - - def test_audio_content(self): - content = AudioPromptMessageContent(format="mp3", base64_data="abc", mime_type="audio/mpeg") - assert content.type == PromptMessageContentType.AUDIO - assert content.data == "data:audio/mpeg;base64,abc" - - def test_video_content(self): - content = VideoPromptMessageContent(format="mp4", base64_data="abc", mime_type="video/mp4") - assert content.type == PromptMessageContentType.VIDEO - assert content.data == "data:video/mp4;base64,abc" - - def test_document_content(self): - content = DocumentPromptMessageContent(format="pdf", base64_data="abc", mime_type="application/pdf") - assert content.type == PromptMessageContentType.DOCUMENT - assert content.data == "data:application/pdf;base64,abc" - - -class TestPromptMessages: - def test_user_prompt_message(self): - msg = UserPromptMessage(content="hello") - assert msg.role == PromptMessageRole.USER - assert msg.content == "hello" - assert msg.is_empty() is False - assert msg.get_text_content() == "hello" - - def test_user_prompt_message_complex_content(self): - content = [TextPromptMessageContent(data="hello "), TextPromptMessageContent(data="world")] - msg = UserPromptMessage(content=content) - assert msg.get_text_content() == "hello world" - - # Test validation from dict - msg2 = UserPromptMessage(content=[{"type": "text", "data": "hi"}]) - assert isinstance(msg2.content[0], TextPromptMessageContent) - assert msg2.content[0].data == "hi" - - def test_prompt_message_empty(self): - msg = UserPromptMessage(content=None) - assert msg.is_empty() is True - assert msg.get_text_content() == "" - - def test_assistant_prompt_message(self): - msg = AssistantPromptMessage(content="thinking...") - assert msg.role == PromptMessageRole.ASSISTANT - assert msg.is_empty() is False - - tool_call = AssistantPromptMessage.ToolCall( - id="call_1", - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="test", arguments="{}"), - ) - msg_with_tools = AssistantPromptMessage(content=None, tool_calls=[tool_call]) - assert msg_with_tools.is_empty() is False - assert msg_with_tools.role == PromptMessageRole.ASSISTANT - - def test_assistant_tool_call_id_transform(self): - tool_call = AssistantPromptMessage.ToolCall( - id=123, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="test", arguments="{}"), - ) - assert tool_call.id == "123" - - def test_system_prompt_message(self): - msg = SystemPromptMessage(content="you are a bot") - assert msg.role == PromptMessageRole.SYSTEM - assert msg.content == "you are a bot" - - def test_tool_prompt_message(self): - # Case 1: Both content and tool_call_id are present - msg = ToolPromptMessage(content="result", tool_call_id="call_1") - assert msg.role == PromptMessageRole.TOOL - assert msg.tool_call_id == "call_1" - assert msg.is_empty() is False - - # Case 2: Content is present, but tool_call_id is empty - msg_content_only = ToolPromptMessage(content="result", tool_call_id="") - assert msg_content_only.is_empty() is False - - # Case 3: Content is None, but tool_call_id is present - msg_id_only = ToolPromptMessage(content=None, tool_call_id="call_1") - assert msg_id_only.is_empty() is False - - # Case 4: Both content and tool_call_id are empty - msg_empty = ToolPromptMessage(content=None, tool_call_id="") - assert msg_empty.is_empty() is True - - def test_prompt_message_validation_errors(self): - with pytest.raises(KeyError): - # Invalid content type in list - UserPromptMessage(content=[{"type": "invalid", "data": "foo"}]) - - with pytest.raises(ValueError, match="invalid prompt message"): - # Not a dict or PromptMessageContent - UserPromptMessage(content=[123]) - - def test_prompt_message_serialization(self): - # Case: content is None - assert UserPromptMessage(content=None).serialize_content(None) is None - - # Case: content is str - assert UserPromptMessage(content="hello").serialize_content("hello") == "hello" - - # Case: content is list of dict - content_list = [{"type": "text", "data": "hi"}] - msg = UserPromptMessage(content=content_list) - assert msg.serialize_content(msg.content) == [{"type": PromptMessageContentType.TEXT, "data": "hi"}] - - # Case: content is Sequence but not list (e.g. tuple) - # To hit line 204, we can call serialize_content manually or - # try to pass a type that pydantic doesn't convert to list in its internal state. - # Actually, let's just call it manually on the instance. - msg = UserPromptMessage(content="test") - content_tuple = (TextPromptMessageContent(data="hi"),) - assert msg.serialize_content(content_tuple) == content_tuple - - def test_prompt_message_mixed_content_validation(self): - # Test branch: isinstance(prompt, PromptMessageContent) - # but not (TextPromptMessageContent | MultiModalPromptMessageContent) - # Line 187: prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump()) - - # We need a PromptMessageContent that is NOT Text or MultiModal. - # But PromptMessageContentUnionTypes discriminator handles this usually. - # We can bypass high-level validation by passing the object directly in a list. - - class MockContent(PromptMessageContent): - type: PromptMessageContentType = PromptMessageContentType.TEXT - data: str - - mock_item = MockContent(data="test") - msg = UserPromptMessage(content=[mock_item]) - # It should hit line 187 and convert to TextPromptMessageContent - assert isinstance(msg.content[0], TextPromptMessageContent) - assert msg.content[0].data == "test" - - def test_prompt_message_get_text_content_branches(self): - # content is None - msg_none = UserPromptMessage(content=None) - assert msg_none.get_text_content() == "" - - # content is list but no text content - image = ImagePromptMessageContent(format="jpg", base64_data="abc", mime_type="image/jpeg") - msg_image = UserPromptMessage(content=[image]) - assert msg_image.get_text_content() == "" - - # content is list with mixed - text = TextPromptMessageContent(data="hello") - msg_mixed = UserPromptMessage(content=[text, image]) - assert msg_mixed.get_text_content() == "hello" diff --git a/api/tests/unit_tests/graphon/model_runtime/entities/test_model_entities.py b/api/tests/unit_tests/graphon/model_runtime/entities/test_model_entities.py deleted file mode 100644 index 1988709faa..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/entities/test_model_entities.py +++ /dev/null @@ -1,220 +0,0 @@ -from decimal import Decimal - -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ( - AIModelEntity, - DefaultParameterName, - FetchFrom, - ModelFeature, - ModelPropertyKey, - ModelType, - ModelUsage, - ParameterRule, - ParameterType, - PriceConfig, - PriceInfo, - PriceType, - ProviderModel, -) - - -class TestModelType: - def test_value_of(self): - assert ModelType.value_of("text-generation") == ModelType.LLM - assert ModelType.value_of(ModelType.LLM) == ModelType.LLM - assert ModelType.value_of("embeddings") == ModelType.TEXT_EMBEDDING - assert ModelType.value_of(ModelType.TEXT_EMBEDDING) == ModelType.TEXT_EMBEDDING - assert ModelType.value_of("reranking") == ModelType.RERANK - assert ModelType.value_of(ModelType.RERANK) == ModelType.RERANK - assert ModelType.value_of("speech2text") == ModelType.SPEECH2TEXT - assert ModelType.value_of(ModelType.SPEECH2TEXT) == ModelType.SPEECH2TEXT - assert ModelType.value_of("tts") == ModelType.TTS - assert ModelType.value_of(ModelType.TTS) == ModelType.TTS - assert ModelType.value_of(ModelType.MODERATION) == ModelType.MODERATION - - with pytest.raises(ValueError, match="invalid origin model type invalid"): - ModelType.value_of("invalid") - - def test_to_origin_model_type(self): - assert ModelType.LLM.to_origin_model_type() == "text-generation" - assert ModelType.TEXT_EMBEDDING.to_origin_model_type() == "embeddings" - assert ModelType.RERANK.to_origin_model_type() == "reranking" - assert ModelType.SPEECH2TEXT.to_origin_model_type() == "speech2text" - assert ModelType.TTS.to_origin_model_type() == "tts" - assert ModelType.MODERATION.to_origin_model_type() == "moderation" - - # Testing the else branch in to_origin_model_type - # Since it's a StrEnum, it's hard to get an invalid value here unless we mock or Force it. - # But if we look at the implementation: - # if self == self.LLM: ... elif ... else: raise ValueError - # We can try to create a "dummy" member if possible, or just skip it if we have 100% coverage otherwise. - # Actually, adding a new member to an enum at runtime is possible but messy. - # Let's see if we can trigger it. - - -class TestFetchFrom: - def test_values(self): - assert FetchFrom.PREDEFINED_MODEL == "predefined-model" - assert FetchFrom.CUSTOMIZABLE_MODEL == "customizable-model" - - -class TestModelFeature: - def test_values(self): - assert ModelFeature.TOOL_CALL == "tool-call" - assert ModelFeature.MULTI_TOOL_CALL == "multi-tool-call" - assert ModelFeature.AGENT_THOUGHT == "agent-thought" - assert ModelFeature.VISION == "vision" - assert ModelFeature.STREAM_TOOL_CALL == "stream-tool-call" - assert ModelFeature.DOCUMENT == "document" - assert ModelFeature.VIDEO == "video" - assert ModelFeature.AUDIO == "audio" - assert ModelFeature.STRUCTURED_OUTPUT == "structured-output" - - -class TestDefaultParameterName: - def test_value_of(self): - assert DefaultParameterName.value_of("temperature") == DefaultParameterName.TEMPERATURE - assert DefaultParameterName.value_of("top_p") == DefaultParameterName.TOP_P - - with pytest.raises(ValueError, match="invalid parameter name invalid"): - DefaultParameterName.value_of("invalid") - - -class TestParameterType: - def test_values(self): - assert ParameterType.FLOAT == "float" - assert ParameterType.INT == "int" - assert ParameterType.STRING == "string" - assert ParameterType.BOOLEAN == "boolean" - assert ParameterType.TEXT == "text" - - -class TestModelPropertyKey: - def test_values(self): - assert ModelPropertyKey.MODE == "mode" - assert ModelPropertyKey.CONTEXT_SIZE == "context_size" - - -class TestProviderModel: - def test_provider_model(self): - model = ProviderModel( - model="gpt-4", - label=I18nObject(en_US="GPT-4"), - model_type=ModelType.LLM, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, - ) - assert model.model == "gpt-4" - assert model.support_structure_output is False - - model_with_features = ProviderModel( - model="gpt-4", - label=I18nObject(en_US="GPT-4"), - model_type=ModelType.LLM, - features=[ModelFeature.STRUCTURED_OUTPUT], - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, - ) - assert model_with_features.support_structure_output is True - - -class TestParameterRule: - def test_parameter_rule(self): - rule = ParameterRule( - name="temperature", - label=I18nObject(en_US="Temperature"), - type=ParameterType.FLOAT, - default=0.7, - min=0.0, - max=1.0, - precision=2, - ) - assert rule.name == "temperature" - assert rule.default == 0.7 - - -class TestPriceConfig: - def test_price_config(self): - config = PriceConfig(input=Decimal("0.01"), output=Decimal("0.02"), unit=Decimal("0.001"), currency="USD") - assert config.input == Decimal("0.01") - assert config.output == Decimal("0.02") - - -class TestAIModelEntity: - def test_ai_model_entity_no_json_schema(self): - entity = AIModelEntity( - model="gpt-4", - label=I18nObject(en_US="GPT-4"), - model_type=ModelType.LLM, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, - parameter_rules=[ - ParameterRule(name="temperature", label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT) - ], - ) - assert ModelFeature.STRUCTURED_OUTPUT not in (entity.features or []) - - def test_ai_model_entity_with_json_schema(self): - # Case: json_schema in parameter rules, features is None - entity = AIModelEntity( - model="gpt-4", - label=I18nObject(en_US="GPT-4"), - model_type=ModelType.LLM, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, - parameter_rules=[ - ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING) - ], - ) - assert ModelFeature.STRUCTURED_OUTPUT in entity.features - - def test_ai_model_entity_with_json_schema_and_features_empty(self): - # Case: json_schema in parameter rules, features is empty list - entity = AIModelEntity( - model="gpt-4", - label=I18nObject(en_US="GPT-4"), - model_type=ModelType.LLM, - features=[], - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, - parameter_rules=[ - ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING) - ], - ) - assert ModelFeature.STRUCTURED_OUTPUT in entity.features - - def test_ai_model_entity_with_json_schema_and_other_features(self): - # Case: json_schema in parameter rules, features has other things - entity = AIModelEntity( - model="gpt-4", - label=I18nObject(en_US="GPT-4"), - model_type=ModelType.LLM, - features=[ModelFeature.VISION], - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, - parameter_rules=[ - ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING) - ], - ) - assert ModelFeature.STRUCTURED_OUTPUT in entity.features - assert ModelFeature.VISION in entity.features - - -class TestModelUsage: - def test_model_usage(self): - usage = ModelUsage() - assert isinstance(usage, ModelUsage) - - -class TestPriceType: - def test_values(self): - assert PriceType.INPUT == "input" - assert PriceType.OUTPUT == "output" - - -class TestPriceInfo: - def test_price_info(self): - info = PriceInfo(unit_price=Decimal("0.01"), unit=Decimal(1000), total_amount=Decimal("0.05"), currency="USD") - assert info.total_amount == Decimal("0.05") diff --git a/api/tests/unit_tests/graphon/model_runtime/errors/test_invoke.py b/api/tests/unit_tests/graphon/model_runtime/errors/test_invoke.py deleted file mode 100644 index 2004822230..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/errors/test_invoke.py +++ /dev/null @@ -1,63 +0,0 @@ -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) - - -class TestInvokeErrors: - def test_invoke_error_with_description(self): - error = InvokeError("Custom description") - assert error.description == "Custom description" - assert str(error) == "Custom description" - assert isinstance(error, ValueError) - - def test_invoke_error_without_description(self): - error = InvokeError() - assert error.description is None - assert str(error) == "InvokeError" - - def test_invoke_connection_error(self): - # Now preserves class-level description - error = InvokeConnectionError() - assert error.description == "Connection Error" - assert str(error) == "Connection Error" - assert isinstance(error, InvokeError) - - # Test with explicit description - error_with_desc = InvokeConnectionError("Connection Error") - assert error_with_desc.description == "Connection Error" - assert str(error_with_desc) == "Connection Error" - - def test_invoke_server_unavailable_error(self): - error = InvokeServerUnavailableError() - assert error.description == "Server Unavailable Error" - assert str(error) == "Server Unavailable Error" - assert isinstance(error, InvokeError) - - def test_invoke_rate_limit_error(self): - error = InvokeRateLimitError() - assert error.description == "Rate Limit Error" - assert str(error) == "Rate Limit Error" - assert isinstance(error, InvokeError) - - def test_invoke_authorization_error(self): - error = InvokeAuthorizationError() - assert error.description == "Incorrect model credentials provided, please check and try again. " - assert str(error) == "Incorrect model credentials provided, please check and try again. " - assert isinstance(error, InvokeError) - - def test_invoke_bad_request_error(self): - error = InvokeBadRequestError() - assert error.description == "Bad Request Error" - assert str(error) == "Bad Request Error" - assert isinstance(error, InvokeError) - - def test_invoke_error_inheritance(self): - # Test that we can override the default description in subclasses - error = InvokeBadRequestError("Overridden Error") - assert error.description == "Overridden Error" - assert str(error) == "Overridden Error" diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_ai_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_ai_model.py deleted file mode 100644 index 64edd69789..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_ai_model.py +++ /dev/null @@ -1,254 +0,0 @@ -import decimal -from unittest.mock import MagicMock, patch - -import pytest -from pydantic import BaseModel - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ( - AIModelEntity, - DefaultParameterName, - FetchFrom, - ModelPropertyKey, - ModelType, - ParameterRule, - ParameterType, - PriceConfig, - PriceType, -) -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - - -class _ConcreteAIModel(AIModel): - model_type = ModelType.LLM - - -class TestAIModel: - @pytest.fixture - def provider_schema(self) -> ProviderEntity: - return ProviderEntity( - provider="langgenius/openai/openai", - provider_name="openai", - label=I18nObject(en_US="OpenAI"), - supported_model_types=[ModelType.LLM], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - @pytest.fixture - def model_runtime(self) -> MagicMock: - return MagicMock() - - @pytest.fixture - def ai_model(self, provider_schema: ProviderEntity, model_runtime: MagicMock) -> AIModel: - return _ConcreteAIModel( - provider_schema=provider_schema, - model_runtime=model_runtime, - ) - - def test_init_stores_runtime_state_and_is_not_pydantic_model( - self, ai_model: AIModel, provider_schema: ProviderEntity, model_runtime: MagicMock - ) -> None: - assert ai_model.model_type == ModelType.LLM - assert ai_model.provider_schema is provider_schema - assert ai_model.model_runtime is model_runtime - assert ai_model.provider == "langgenius/openai/openai" - assert ai_model.provider_display_name == "OpenAI" - assert ai_model.started_at == 0 - assert not isinstance(ai_model, BaseModel) - - def test_direct_base_class_requires_subclass_model_type( - self, provider_schema: ProviderEntity, model_runtime: MagicMock - ) -> None: - with pytest.raises(TypeError, match="subclasses must define model_type"): - AIModel(provider_schema=provider_schema, model_runtime=model_runtime) - - def test_subclass_uses_class_level_model_type( - self, provider_schema: ProviderEntity, model_runtime: MagicMock - ) -> None: - model = _ConcreteAIModel(provider_schema=provider_schema, model_runtime=model_runtime) - assert model.model_type == ModelType.LLM - - def test_invoke_error_mapping(self, ai_model: AIModel) -> None: - mapping = ai_model._invoke_error_mapping - assert InvokeConnectionError in mapping - assert InvokeServerUnavailableError in mapping - assert InvokeRateLimitError in mapping - assert InvokeAuthorizationError in mapping - assert InvokeBadRequestError in mapping - assert ValueError in mapping - - def test_transform_invoke_error(self, ai_model: AIModel) -> None: - err = Exception("Original error") - - with patch.object(AIModel, "_invoke_error_mapping", {InvokeAuthorizationError: [Exception]}): - transformed = ai_model._transform_invoke_error(err) - assert isinstance(transformed, InvokeAuthorizationError) - assert "Incorrect model credentials provided" in str(transformed.description) - - class CustomNonInvokeError(Exception): - pass - - with patch.object(AIModel, "_invoke_error_mapping", {CustomNonInvokeError: [Exception]}): - transformed = ai_model._transform_invoke_error(err) - assert transformed == err - - transformed = ai_model._transform_invoke_error(Exception("Unmapped")) - assert isinstance(transformed, InvokeError) - assert transformed.description == "[OpenAI] Error: Unmapped" - - def test_get_price(self, ai_model: AIModel) -> None: - model_name = "test_model" - credentials = {"key": "value"} - - mock_schema = MagicMock(spec=AIModelEntity) - mock_schema.pricing = PriceConfig( - input=decimal.Decimal("0.002"), - output=decimal.Decimal("0.004"), - unit=decimal.Decimal(1000), - currency="USD", - ) - - with patch.object(AIModel, "get_model_schema", return_value=mock_schema): - price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 2000) - assert price_info.unit_price == decimal.Decimal("0.002") - - price_info = ai_model.get_price(model_name, credentials, PriceType.OUTPUT, 2000) - assert price_info.unit_price == decimal.Decimal("0.004") - - mock_schema.pricing = None - with patch.object(AIModel, "get_model_schema", return_value=mock_schema): - price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 1000) - assert price_info.total_amount == decimal.Decimal("0.0") - - def test_get_price_no_price_config_error(self, ai_model: AIModel) -> None: - class ChangingPriceConfig: - def __init__(self) -> None: - self.input = decimal.Decimal("0.01") - self.unit = decimal.Decimal(1) - self.currency = "USD" - self.called = 0 - - def __bool__(self) -> bool: - self.called += 1 - return self.called <= 2 - - mock_schema = MagicMock() - mock_schema.pricing = ChangingPriceConfig() - - with patch.object(AIModel, "get_model_schema", return_value=mock_schema): - with pytest.raises(ValueError, match="Price config not found"): - ai_model.get_price("test_model", {}, PriceType.INPUT, 1000) - - def test_get_model_schema_delegates_to_runtime( - self, ai_model: AIModel, model_runtime: MagicMock, provider_schema: ProviderEntity - ) -> None: - model_name = "test_model" - credentials = {"api_key": "abc"} - - mock_schema = AIModelEntity( - model="test_model", - label=I18nObject(en_US="Test Model"), - model_type=ModelType.LLM, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, - parameter_rules=[], - ) - model_runtime.get_model_schema.return_value = mock_schema - - schema = ai_model.get_model_schema(model_name, credentials) - - assert schema == mock_schema - model_runtime.get_model_schema.assert_called_once_with( - provider=provider_schema.provider, - model_type=ModelType.LLM, - model=model_name, - credentials=credentials, - ) - - def test_get_customizable_model_schema_from_credentials_template_mapping_value_error( - self, ai_model: AIModel - ) -> None: - mock_schema = AIModelEntity( - model="test_model", - label=I18nObject(en_US="Test Model"), - model_type=ModelType.LLM, - fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, - parameter_rules=[ - ParameterRule( - name="invalid", - use_template="invalid_template_name", - label=I18nObject(en_US="Invalid"), - type=ParameterType.FLOAT, - ) - ], - ) - - with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema): - schema = ai_model.get_customizable_model_schema_from_credentials("test_model", {}) - assert schema is not None - assert schema.parameter_rules[0].use_template == "invalid_template_name" - - def test_get_customizable_model_schema_from_credentials(self, ai_model: AIModel) -> None: - mock_schema = AIModelEntity( - model="test_model", - label=I18nObject(en_US="Test Model"), - model_type=ModelType.LLM, - fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, - parameter_rules=[ - ParameterRule( - name="temp", use_template="temperature", label=I18nObject(en_US="Temp"), type=ParameterType.FLOAT - ), - ParameterRule( - name="top_p", - use_template="top_p", - label=I18nObject(en_US="Top P"), - type=ParameterType.FLOAT, - help=I18nObject(en_US=""), - ), - ParameterRule( - name="max_tokens", - use_template="max_tokens", - label=I18nObject(en_US="Max Tokens"), - type=ParameterType.INT, - help=I18nObject(en_US="", zh_Hans=""), - ), - ParameterRule(name="custom", label=I18nObject(en_US="Custom"), type=ParameterType.STRING), - ], - ) - - with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema): - schema = ai_model.get_customizable_model_schema_from_credentials("test_model", {}) - - assert schema is not None - assert schema.parameter_rules[0].max == 1.0 - assert schema.parameter_rules[1].help is not None - assert schema.parameter_rules[1].help.en_US != "" - assert schema.parameter_rules[2].help is not None - assert schema.parameter_rules[2].help.zh_Hans != "" - assert schema.parameter_rules[3].use_template is None - - def test_get_customizable_model_schema_from_credentials_none(self, ai_model: AIModel) -> None: - with patch.object(AIModel, "get_customizable_model_schema", return_value=None): - schema = ai_model.get_customizable_model_schema_from_credentials("model", {}) - assert schema is None - - def test_get_customizable_model_schema_default(self, ai_model: AIModel) -> None: - assert ai_model.get_customizable_model_schema("model", {}) is None - - def test_get_default_parameter_rule_variable_map(self, ai_model: AIModel) -> None: - result = ai_model._get_default_parameter_rule_variable_map(DefaultParameterName.TEMPERATURE) - assert result["default"] == 0.0 - - with pytest.raises(Exception, match="Invalid model parameter rule name"): - ai_model._get_default_parameter_rule_variable_map("invalid_name") diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_large_language_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_large_language_model.py deleted file mode 100644 index 668a7e3476..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_large_language_model.py +++ /dev/null @@ -1,452 +0,0 @@ -import logging -from collections.abc import Generator, Iterator, Sequence -from dataclasses import dataclass, field -from decimal import Decimal -from types import SimpleNamespace -from typing import Any -from unittest.mock import MagicMock - -import pytest - -import graphon.model_runtime.model_providers.__base.large_language_model as llm_module - -# Access large_language_model members via llm_module to avoid partial import issues in CI -from graphon.model_runtime.callbacks.base_callback import Callback -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMUsage, -) -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelType, PriceInfo -from graphon.model_runtime.model_providers.__base.large_language_model import _build_llm_result_from_chunks - - -def _usage(prompt_tokens: int = 1, completion_tokens: int = 2) -> LLMUsage: - return LLMUsage( - prompt_tokens=prompt_tokens, - prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal(1), - prompt_price=Decimal(prompt_tokens) * Decimal("0.001"), - completion_tokens=completion_tokens, - completion_unit_price=Decimal("0.002"), - completion_price_unit=Decimal(1), - completion_price=Decimal(completion_tokens) * Decimal("0.002"), - total_tokens=prompt_tokens + completion_tokens, - total_price=Decimal(prompt_tokens) * Decimal("0.001") + Decimal(completion_tokens) * Decimal("0.002"), - currency="USD", - latency=0.0, - ) - - -def _tool_call_delta( - *, - tool_call_id: str, - tool_type: str = "function", - function_name: str = "", - function_arguments: str = "", -) -> AssistantPromptMessage.ToolCall: - return AssistantPromptMessage.ToolCall( - id=tool_call_id, - type=tool_type, - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=function_name, arguments=function_arguments), - ) - - -def _chunk( - *, - model: str = "test-model", - content: str | list[Any] | None = None, - tool_calls: list[AssistantPromptMessage.ToolCall] | None = None, - usage: LLMUsage | None = None, - system_fingerprint: str | None = None, -) -> LLMResultChunk: - return LLMResultChunk( - model=model, - system_fingerprint=system_fingerprint, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=content, tool_calls=tool_calls or []), - usage=usage, - ), - ) - - -@dataclass -class SpyCallback(Callback): - raise_error: bool = False - before: list[dict[str, Any]] = field(default_factory=list) - new_chunk: list[dict[str, Any]] = field(default_factory=list) - after: list[dict[str, Any]] = field(default_factory=list) - error: list[dict[str, Any]] = field(default_factory=list) - - def on_before_invoke(self, **kwargs: Any) -> None: # type: ignore[override] - self.before.append(kwargs) - - def on_new_chunk(self, **kwargs: Any) -> None: # type: ignore[override] - self.new_chunk.append(kwargs) - - def on_after_invoke(self, **kwargs: Any) -> None: # type: ignore[override] - self.after.append(kwargs) - - def on_invoke_error(self, **kwargs: Any) -> None: # type: ignore[override] - self.error.append(kwargs) - - -class _TestLLM(llm_module.LargeLanguageModel): - def get_price(self, model: str, credentials: dict, price_type: Any, tokens: int) -> PriceInfo: # type: ignore[override] - return PriceInfo( - unit_price=Decimal("0.01"), - unit=Decimal(1), - total_amount=Decimal(tokens) * Decimal("0.01"), - currency="USD", - ) - - def _transform_invoke_error(self, error: Exception) -> Exception: # type: ignore[override] - return RuntimeError(f"transformed: {error}") - - -@pytest.fixture -def llm() -> _TestLLM: - provider_schema = SimpleNamespace(provider="provider", label=SimpleNamespace(en_US="Provider")) - model_runtime = MagicMock() - model_runtime.get_llm_num_tokens.return_value = 0 - return _TestLLM(provider_schema=provider_schema, model_runtime=model_runtime, started_at=1.0) - - -def test_gen_tool_call_id_is_uuid_based(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="abc123")) - assert llm_module._gen_tool_call_id() == "chatcmpl-tool-abc123" - - -def test_run_callbacks_no_callbacks_noop() -> None: - invoked: list[int] = [] - llm_module._run_callbacks(None, event="x", invoke=lambda _: invoked.append(1)) - llm_module._run_callbacks([], event="x", invoke=lambda _: invoked.append(1)) - assert invoked == [] - - -def test_run_callbacks_swallows_error_when_raise_error_false(caplog: pytest.LogCaptureFixture) -> None: - class Boom: - raise_error = False - - caplog.set_level(logging.WARNING) - llm_module._run_callbacks( - [Boom()], event="on_before_invoke", invoke=lambda _: (_ for _ in ()).throw(ValueError("boom")) - ) - assert any("Callback" in record.message and "failed with error" in record.message for record in caplog.records) - - -def test_run_callbacks_reraises_when_raise_error_true() -> None: - class Boom: - raise_error = True - - with pytest.raises(ValueError, match="boom"): - llm_module._run_callbacks( - [Boom()], event="on_before_invoke", invoke=lambda _: (_ for _ in ()).throw(ValueError("boom")) - ) - - -def test_get_or_create_tool_call_empty_id_returns_last() -> None: - calls = [ - _tool_call_delta(tool_call_id="id1", function_name="a"), - _tool_call_delta(tool_call_id="id2", function_name="b"), - ] - assert llm_module._get_or_create_tool_call(calls, "") is calls[-1] - - -def test_get_or_create_tool_call_empty_id_without_existing_raises() -> None: - with pytest.raises(ValueError, match="tool_call_id is empty"): - llm_module._get_or_create_tool_call([], "") - - -def test_get_or_create_tool_call_creates_if_missing() -> None: - calls: list[AssistantPromptMessage.ToolCall] = [] - tool_call = llm_module._get_or_create_tool_call(calls, "new-id") - assert tool_call.id == "new-id" - assert tool_call.function.name == "" - assert tool_call.function.arguments == "" - assert calls == [tool_call] - - -def test_get_or_create_tool_call_returns_existing_when_found() -> None: - existing = _tool_call_delta(tool_call_id="same-id", function_name="fn", function_arguments="{}") - calls = [existing] - assert llm_module._get_or_create_tool_call(calls, "same-id") is existing - - -def test_merge_tool_call_delta_updates_fields_and_appends_arguments() -> None: - tool_call = _tool_call_delta(tool_call_id="id", tool_type="function", function_name="x", function_arguments="{") - delta = _tool_call_delta(tool_call_id="id2", tool_type="function", function_name="y", function_arguments="}") - llm_module._merge_tool_call_delta(tool_call, delta) - assert tool_call.id == "id2" - assert tool_call.type == "function" - assert tool_call.function.name == "y" - assert tool_call.function.arguments == "{}" - - -def test_increase_tool_call_generates_id_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="fixed")) - delta = _tool_call_delta(tool_call_id="", function_name="fn", function_arguments="{") - existing: list[AssistantPromptMessage.ToolCall] = [] - llm_module._increase_tool_call([delta], existing) - assert len(existing) == 1 - assert existing[0].id == "chatcmpl-tool-fixed" - assert existing[0].function.name == "fn" - assert existing[0].function.arguments == "{" - - -def test_increase_tool_call_merges_incremental_arguments() -> None: - existing: list[AssistantPromptMessage.ToolCall] = [] - llm_module._increase_tool_call( - [_tool_call_delta(tool_call_id="id", function_name="fn", function_arguments="{")], existing - ) - llm_module._increase_tool_call( - [_tool_call_delta(tool_call_id="id", function_name="", function_arguments="}")], existing - ) - assert len(existing) == 1 - assert existing[0].function.name == "fn" - assert existing[0].function.arguments == "{}" - - -@pytest.mark.parametrize( - ("content", "expected_type"), - [ - ("hello", str), - ([TextPromptMessageContent(data="hello")], list), - ], -) -def test_build_llm_result_from_chunks_accumulates_and_raises_error( - content: str | list[TextPromptMessageContent], - expected_type: type, - monkeypatch: pytest.MonkeyPatch, - caplog: pytest.LogCaptureFixture, -) -> None: - monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="drain")) - caplog.set_level(logging.DEBUG) - - tool_delta = _tool_call_delta(tool_call_id="", function_name="fn", function_arguments="{}") - first = _chunk(content=content, tool_calls=[tool_delta], usage=_usage(3, 4), system_fingerprint="fp1") - - def iter_with_error() -> Iterator[LLMResultChunk]: - yield first - raise RuntimeError("drain boom") - - with pytest.raises(RuntimeError, match="drain boom"): - _build_llm_result_from_chunks( - model="m", prompt_messages=[UserPromptMessage(content="u")], chunks=iter_with_error() - ) - - assert any("Error while consuming non-stream plugin chunk iterator" in record.message for record in caplog.records) - - -def test_build_llm_result_from_chunks_empty_iterator() -> None: - def empty() -> Iterator[LLMResultChunk]: - if False: # pragma: no cover - yield _chunk() - return - - result = _build_llm_result_from_chunks(model="m", prompt_messages=[], chunks=empty()) - assert result.message.content == [] - assert result.usage.total_tokens == 0 - assert result.system_fingerprint is None - - -def test_build_llm_result_from_chunks_accumulates_all_chunks() -> None: - chunks = iter([_chunk(content="first"), _chunk(content="second")]) - result = _build_llm_result_from_chunks(model="m", prompt_messages=[], chunks=chunks) - assert result.message.content == "firstsecond" - - -def test_invoke_llm_via_runtime_passes_list_converted_stop(llm: _TestLLM) -> None: - llm.model_runtime = MagicMock() - prompt_messages: Sequence[PromptMessage] = (UserPromptMessage(content="hi"),) - result = llm_module._invoke_llm_via_runtime( - llm_model=llm, - provider="prov", - model="m", - credentials={"k": "v"}, - model_parameters={"temp": 1}, - prompt_messages=prompt_messages, - tools=None, - stop=("a", "b"), - stream=True, - ) - - llm.model_runtime.invoke_llm.assert_called_once_with( - provider="prov", - model="m", - credentials={"k": "v"}, - model_parameters={"temp": 1}, - prompt_messages=list(prompt_messages), - tools=None, - stop=("a", "b"), - stream=True, - ) - assert result is llm.model_runtime.invoke_llm.return_value - - -def test_normalize_non_stream_runtime_result_passthrough_llmresult() -> None: - llm_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()) - assert ( - llm_module._normalize_non_stream_runtime_result(model="m", prompt_messages=[], result=llm_result) is llm_result - ) - - -def test_normalize_non_stream_runtime_result_builds_from_chunks() -> None: - chunks = iter([_chunk(content="hello", usage=_usage(1, 1))]) - result = llm_module._normalize_non_stream_runtime_result( - model="m", prompt_messages=[UserPromptMessage(content="u")], result=chunks - ) - assert isinstance(result, LLMResult) - assert result.message.content == "hello" - - -def test_invoke_non_stream_normalizes_and_sets_prompt_messages(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: - plugin_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()) - monkeypatch.setattr( - "graphon.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime", - lambda **_: plugin_result, - ) - cb = SpyCallback() - prompt_messages = [UserPromptMessage(content="hi")] - result = llm.invoke(model="m", credentials={}, prompt_messages=prompt_messages, stream=False, callbacks=[cb]) - assert isinstance(result, LLMResult) - assert result.prompt_messages == prompt_messages - assert len(cb.before) == 1 - assert len(cb.after) == 1 - assert cb.after[0]["result"].prompt_messages == prompt_messages - - -def test_invoke_stream_wraps_generator_and_triggers_callbacks(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: - plugin_chunks = iter( - [ - _chunk(model="m1", content="a"), - _chunk( - model="m2", content=[TextPromptMessageContent(data="b")], usage=_usage(2, 3), system_fingerprint="fp" - ), - _chunk(model="m3", content=None), - ] - ) - monkeypatch.setattr( - "graphon.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime", - lambda **_: plugin_chunks, - ) - - cb = SpyCallback() - prompt_messages = [UserPromptMessage(content="hi")] - gen = llm.invoke(model="m", credentials={}, prompt_messages=prompt_messages, stream=True, callbacks=[cb]) - - assert isinstance(gen, Generator) - chunks = list(gen) - assert len(chunks) == 3 - assert all(chunk.prompt_messages == prompt_messages for chunk in chunks) - assert len(cb.before) == 1 - assert len(cb.new_chunk) == 3 - assert len(cb.after) == 1 - final_result: LLMResult = cb.after[0]["result"] - assert final_result.model == "m3" - assert final_result.system_fingerprint == "fp" - assert isinstance(final_result.message.content, list) - assert [c.data for c in final_result.message.content] == ["a", "b"] - assert final_result.usage.total_tokens == 5 - - -def test_invoke_triggers_error_callbacks_and_raises_transformed(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: - def boom(**_: Any) -> Any: - raise ValueError("plugin down") - - monkeypatch.setattr( - "graphon.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime", boom - ) - cb = SpyCallback() - with pytest.raises(RuntimeError, match="transformed: plugin down"): - llm.invoke( - model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False, callbacks=[cb] - ) - assert len(cb.error) == 1 - assert isinstance(cb.error[0]["ex"], ValueError) - - -def test_invoke_raises_not_implemented_for_unsupported_result_type( - llm: _TestLLM, monkeypatch: pytest.MonkeyPatch -) -> None: - monkeypatch.setattr(llm_module, "_invoke_llm_via_runtime", lambda **_: "not-a-result") - monkeypatch.setattr(llm_module, "_normalize_non_stream_runtime_result", lambda **_: "not-a-result") - with pytest.raises(NotImplementedError, match="unsupported invoke result type"): - llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False) - - -def test_invoke_appends_logging_callback_in_debug(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: - captured_callbacks: list[list[Callback]] = [] - - class FakeLoggingCallback(SpyCallback): - pass - - monkeypatch.setattr(llm_module, "LoggingCallback", FakeLoggingCallback) - monkeypatch.setattr(llm_module.logger, "isEnabledFor", lambda level: level == logging.DEBUG) - monkeypatch.setattr( - "graphon.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime", - lambda **_: LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()), - ) - - original_trigger = llm._trigger_before_invoke_callbacks - - def spy_trigger(*args: Any, **kwargs: Any) -> None: - captured_callbacks.append(list(kwargs["callbacks"])) - original_trigger(*args, **kwargs) - - monkeypatch.setattr(llm, "_trigger_before_invoke_callbacks", spy_trigger) - llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False) - assert any(isinstance(cb, FakeLoggingCallback) for cb in captured_callbacks[0]) - - -def test_get_num_tokens_returns_0_when_runtime_returns_0(llm: _TestLLM) -> None: - llm.model_runtime.get_llm_num_tokens.return_value = 0 - assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 0 - - -def test_get_num_tokens_uses_runtime(llm: _TestLLM) -> None: - llm.model_runtime.get_llm_num_tokens.return_value = 42 - assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 42 - llm.model_runtime.get_llm_num_tokens.assert_called_once_with( - provider="provider", - model_type=ModelType.LLM, - model="m", - credentials={}, - prompt_messages=[UserPromptMessage(content="x")], - tools=None, - ) - - -def test_calc_response_usage_uses_prices_and_latency(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(llm_module.time, "perf_counter", lambda: 4.5) - llm.started_at = 1.0 - usage = llm.calc_response_usage(model="m", credentials={}, prompt_tokens=10, completion_tokens=5) - assert usage.total_tokens == 15 - assert usage.total_price == Decimal("0.15") - assert usage.latency == 3.5 - - -def test_invoke_result_generator_raises_transformed_on_iteration_error(llm: _TestLLM) -> None: - def broken() -> Iterator[LLMResultChunk]: - yield _chunk(content="ok") - raise ValueError("chunk stream broken") - - gen = llm._invoke_result_generator( - model="m", - result=broken(), - credentials={}, - prompt_messages=[UserPromptMessage(content="u")], - model_parameters={}, - callbacks=[SpyCallback()], - ) - - with pytest.raises(RuntimeError, match="transformed: chunk stream broken"): - list(gen) diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_moderation_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_moderation_model.py deleted file mode 100644 index a42a930806..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_moderation_model.py +++ /dev/null @@ -1,56 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.errors.invoke import InvokeError -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel - - -@pytest.fixture -def provider_schema() -> ProviderEntity: - return ProviderEntity( - provider="test_provider", - label=I18nObject(en_US="test_provider"), - supported_model_types=[ModelType.MODERATION], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - -@pytest.fixture -def model_runtime() -> MagicMock: - return MagicMock() - - -@pytest.fixture -def moderation_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> ModerationModel: - return ModerationModel(provider_schema=provider_schema, model_runtime=model_runtime) - - -def test_model_type(moderation_model: ModerationModel) -> None: - assert moderation_model.model_type == ModelType.MODERATION - - -def test_invoke_success(moderation_model: ModerationModel, model_runtime: MagicMock) -> None: - with patch("time.perf_counter", return_value=1.0): - model_runtime.invoke_moderation.return_value = True - - result = moderation_model.invoke(model="test_model", credentials={"api_key": "abc"}, text="test text") - - assert result is True - assert moderation_model.started_at == 1.0 - model_runtime.invoke_moderation.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - text="test text", - ) - - -def test_invoke_exception(moderation_model: ModerationModel, model_runtime: MagicMock) -> None: - model_runtime.invoke_moderation.side_effect = Exception("Test error") - - with pytest.raises(InvokeError, match="Test error"): - moderation_model.invoke(model="test_model", credentials={"api_key": "abc"}, text="test text") diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_rerank_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_rerank_model.py deleted file mode 100644 index 9650ed2db7..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_rerank_model.py +++ /dev/null @@ -1,110 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult -from graphon.model_runtime.errors.invoke import InvokeError -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel - - -@pytest.fixture -def provider_schema() -> ProviderEntity: - return ProviderEntity( - provider="test_provider", - label=I18nObject(en_US="test_provider"), - supported_model_types=[ModelType.RERANK], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - -@pytest.fixture -def model_runtime() -> MagicMock: - return MagicMock() - - -@pytest.fixture -def rerank_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> RerankModel: - return RerankModel(provider_schema=provider_schema, model_runtime=model_runtime) - - -def test_model_type_is_rerank_by_default(rerank_model: RerankModel) -> None: - assert rerank_model.model_type == ModelType.RERANK - - -def test_invoke_calls_runtime_and_passes_args(rerank_model: RerankModel, model_runtime: MagicMock) -> None: - expected = RerankResult(model="rerank", docs=[RerankDocument(index=0, text="a", score=0.5)]) - model_runtime.invoke_rerank.return_value = expected - - result = rerank_model.invoke( - model="rerank", - credentials={"k": "v"}, - query="q", - docs=["d1", "d2"], - score_threshold=0.2, - top_n=10, - ) - - assert result == expected - model_runtime.invoke_rerank.assert_called_once_with( - provider="test_provider", - model="rerank", - credentials={"k": "v"}, - query="q", - docs=["d1", "d2"], - score_threshold=0.2, - top_n=10, - ) - - -def test_invoke_transforms_and_raises_on_runtime_error(rerank_model: RerankModel, model_runtime: MagicMock) -> None: - model_runtime.invoke_rerank.side_effect = Exception("runtime down") - - with pytest.raises(InvokeError, match="runtime down"): - rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"]) - - -def test_invoke_multimodal_calls_runtime_and_passes_args(rerank_model: RerankModel, model_runtime: MagicMock) -> None: - expected = RerankResult(model="mm", docs=[RerankDocument(index=0, text="x", score=0.9)]) - model_runtime.invoke_multimodal_rerank.return_value = expected - - query = {"type": "text", "text": "q"} - docs = [{"type": "text", "text": "d1"}] - result = rerank_model.invoke_multimodal_rerank( - model="mm", - credentials={"k": "v"}, - query=query, - docs=docs, - score_threshold=None, - top_n=None, - ) - - assert result == expected - model_runtime.invoke_multimodal_rerank.assert_called_once_with( - provider="test_provider", - model="mm", - credentials={"k": "v"}, - query=query, - docs=docs, - score_threshold=None, - top_n=None, - ) - - -def test_invoke_multimodal_transforms_and_raises_on_runtime_error( - rerank_model: RerankModel, model_runtime: MagicMock -) -> None: - model_runtime.invoke_multimodal_rerank.side_effect = Exception("multimodal runtime down") - - query = {"content": "q", "content_type": "text"} - docs = [{"content": "d1", "content_type": "text"}] - - with pytest.raises(InvokeError, match="multimodal runtime down"): - rerank_model.invoke_multimodal_rerank( - model="mm", - credentials={}, - query=query, - docs=docs, - ) diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_runtime_user_forwarding.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_runtime_user_forwarding.py deleted file mode 100644 index 98bb1eb1b8..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_runtime_user_forwarding.py +++ /dev/null @@ -1,170 +0,0 @@ -from decimal import Decimal -from io import BytesIO -from unittest.mock import MagicMock - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.entities.rerank_entities import RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel - - -def _provider_schema(model_type: ModelType) -> ProviderEntity: - return ProviderEntity( - provider="langgenius/openai/openai", - provider_name="openai", - label=I18nObject(en_US="OpenAI"), - supported_model_types=[model_type], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - -def _embedding_usage() -> EmbeddingUsage: - return EmbeddingUsage( - tokens=1, - total_tokens=1, - unit_price=Decimal(0), - price_unit=Decimal(0), - total_price=Decimal(0), - currency="USD", - latency=0.0, - ) - - -def test_large_language_model_invokes_runtime_without_user_id() -> None: - runtime = MagicMock() - runtime.invoke_llm.return_value = LLMResult( - model="gpt-4o-mini", - prompt_messages=[], - message=AssistantPromptMessage(content="ok"), - usage=LLMUsage.empty_usage(), - ) - model = LargeLanguageModel(provider_schema=_provider_schema(ModelType.LLM), model_runtime=runtime) - - model.invoke( - model="gpt-4o-mini", - credentials={"api_key": "secret"}, - prompt_messages=[UserPromptMessage(content="hi")], - stream=False, - ) - - assert "user_id" not in runtime.invoke_llm.call_args.kwargs - - -def test_text_embedding_model_invokes_runtime_without_user_id_for_text_requests() -> None: - runtime = MagicMock() - runtime.invoke_text_embedding.return_value = EmbeddingResult( - model="text-embedding-3-small", - embeddings=[[0.1]], - usage=_embedding_usage(), - ) - model = TextEmbeddingModel(provider_schema=_provider_schema(ModelType.TEXT_EMBEDDING), model_runtime=runtime) - - model.invoke( - model="text-embedding-3-small", - credentials={"api_key": "secret"}, - texts=["hello"], - ) - - assert "user_id" not in runtime.invoke_text_embedding.call_args.kwargs - - -def test_text_embedding_model_invokes_runtime_without_user_id_for_multimodal_requests() -> None: - runtime = MagicMock() - runtime.invoke_multimodal_embedding.return_value = EmbeddingResult( - model="text-embedding-3-small", - embeddings=[[0.1]], - usage=_embedding_usage(), - ) - model = TextEmbeddingModel(provider_schema=_provider_schema(ModelType.TEXT_EMBEDDING), model_runtime=runtime) - - model.invoke( - model="text-embedding-3-small", - credentials={"api_key": "secret"}, - multimodel_documents=[{"content": "hello", "content_type": "text"}], - ) - - assert "user_id" not in runtime.invoke_multimodal_embedding.call_args.kwargs - - -def test_rerank_model_invokes_runtime_without_user_id_for_text_requests() -> None: - runtime = MagicMock() - runtime.invoke_rerank.return_value = RerankResult(model="rerank", docs=[]) - model = RerankModel(provider_schema=_provider_schema(ModelType.RERANK), model_runtime=runtime) - - model.invoke( - model="rerank", - credentials={"api_key": "secret"}, - query="q", - docs=["d1"], - ) - - assert "user_id" not in runtime.invoke_rerank.call_args.kwargs - - -def test_rerank_model_invokes_runtime_without_user_id_for_multimodal_requests() -> None: - runtime = MagicMock() - runtime.invoke_multimodal_rerank.return_value = RerankResult(model="rerank", docs=[]) - model = RerankModel(provider_schema=_provider_schema(ModelType.RERANK), model_runtime=runtime) - - model.invoke_multimodal_rerank( - model="rerank", - credentials={"api_key": "secret"}, - query={"content": "q", "content_type": "text"}, - docs=[{"content": "d1", "content_type": "text"}], - ) - - assert "user_id" not in runtime.invoke_multimodal_rerank.call_args.kwargs - - -def test_tts_model_invokes_runtime_without_user_id() -> None: - runtime = MagicMock() - runtime.invoke_tts.return_value = [b"chunk"] - model = TTSModel(provider_schema=_provider_schema(ModelType.TTS), model_runtime=runtime) - - list( - model.invoke( - model="tts-1", - credentials={"api_key": "secret"}, - content_text="hello", - voice="alloy", - ) - ) - - assert "user_id" not in runtime.invoke_tts.call_args.kwargs - - -def test_speech_to_text_model_invokes_runtime_without_user_id() -> None: - runtime = MagicMock() - runtime.invoke_speech_to_text.return_value = "transcript" - model = Speech2TextModel(provider_schema=_provider_schema(ModelType.SPEECH2TEXT), model_runtime=runtime) - - model.invoke( - model="whisper-1", - credentials={"api_key": "secret"}, - file=BytesIO(b"audio"), - ) - - assert "user_id" not in runtime.invoke_speech_to_text.call_args.kwargs - - -def test_moderation_model_invokes_runtime_without_user_id() -> None: - runtime = MagicMock() - runtime.invoke_moderation.return_value = True - model = ModerationModel(provider_schema=_provider_schema(ModelType.MODERATION), model_runtime=runtime) - - model.invoke( - model="omni-moderation-latest", - credentials={"api_key": "secret"}, - text="unsafe?", - ) - - assert "user_id" not in runtime.invoke_moderation.call_args.kwargs diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_speech2text_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_speech2text_model.py deleted file mode 100644 index b03923bbc2..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_speech2text_model.py +++ /dev/null @@ -1,56 +0,0 @@ -from io import BytesIO -from unittest.mock import MagicMock - -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.errors.invoke import InvokeError -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel - - -@pytest.fixture -def provider_schema() -> ProviderEntity: - return ProviderEntity( - provider="test_provider", - label=I18nObject(en_US="test_provider"), - supported_model_types=[ModelType.SPEECH2TEXT], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - -@pytest.fixture -def model_runtime() -> MagicMock: - return MagicMock() - - -@pytest.fixture -def speech2text_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> Speech2TextModel: - return Speech2TextModel(provider_schema=provider_schema, model_runtime=model_runtime) - - -def test_model_type(speech2text_model: Speech2TextModel) -> None: - assert speech2text_model.model_type == ModelType.SPEECH2TEXT - - -def test_invoke_success(speech2text_model: Speech2TextModel, model_runtime: MagicMock) -> None: - file = BytesIO(b"audio data") - model_runtime.invoke_speech_to_text.return_value = "transcribed text" - - result = speech2text_model.invoke(model="test_model", credentials={"api_key": "abc"}, file=file) - - assert result == "transcribed text" - model_runtime.invoke_speech_to_text.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - file=file, - ) - - -def test_invoke_exception(speech2text_model: Speech2TextModel, model_runtime: MagicMock) -> None: - model_runtime.invoke_speech_to_text.side_effect = Exception("Test error") - - with pytest.raises(InvokeError, match="Test error"): - speech2text_model.invoke(model="test_model", credentials={"api_key": "abc"}, file=BytesIO(b"audio data")) diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_text_embedding_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_text_embedding_model.py deleted file mode 100644 index 64caf3a315..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_text_embedding_model.py +++ /dev/null @@ -1,146 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult -from graphon.model_runtime.errors.invoke import InvokeError -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel - - -@pytest.fixture -def provider_schema() -> ProviderEntity: - return ProviderEntity( - provider="test_provider", - label=I18nObject(en_US="test_provider"), - supported_model_types=[ModelType.TEXT_EMBEDDING], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - -@pytest.fixture -def model_runtime() -> MagicMock: - return MagicMock() - - -@pytest.fixture -def text_embedding_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> TextEmbeddingModel: - return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=model_runtime) - - -def test_model_type(text_embedding_model: TextEmbeddingModel) -> None: - assert text_embedding_model.model_type == ModelType.TEXT_EMBEDDING - - -def test_invoke_with_texts(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None: - expected_result = MagicMock(spec=EmbeddingResult) - model_runtime.invoke_text_embedding.return_value = expected_result - - result = text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"}, texts=["hello", "world"]) - - assert result == expected_result - model_runtime.invoke_text_embedding.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - texts=["hello", "world"], - input_type=EmbeddingInputType.DOCUMENT, - ) - - -def test_invoke_with_multimodal_documents(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None: - expected_result = MagicMock(spec=EmbeddingResult) - model_runtime.invoke_multimodal_embedding.return_value = expected_result - - result = text_embedding_model.invoke( - model="test_model", - credentials={"api_key": "abc"}, - multimodel_documents=[{"type": "text", "text": "hello"}], - ) - - assert result == expected_result - model_runtime.invoke_multimodal_embedding.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - documents=[{"type": "text", "text": "hello"}], - input_type=EmbeddingInputType.DOCUMENT, - ) - - -def test_invoke_no_input(text_embedding_model: TextEmbeddingModel) -> None: - with pytest.raises(ValueError, match="No texts or files provided"): - text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"}) - - -def test_invoke_prefers_texts_over_multimodal_documents( - text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock -) -> None: - expected_result = MagicMock(spec=EmbeddingResult) - model_runtime.invoke_text_embedding.return_value = expected_result - - result = text_embedding_model.invoke( - model="test_model", - credentials={"api_key": "abc"}, - texts=["hello"], - multimodel_documents=[{"type": "text", "text": "world"}], - ) - - assert result == expected_result - model_runtime.invoke_text_embedding.assert_called_once() - model_runtime.invoke_multimodal_embedding.assert_not_called() - - -def test_invoke_exception(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None: - model_runtime.invoke_text_embedding.side_effect = Exception("Test error") - - with pytest.raises(InvokeError, match="Test error"): - text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"}, texts=["hello"]) - - -def test_get_num_tokens(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None: - model_runtime.get_text_embedding_num_tokens.return_value = [1, 1] - - result = text_embedding_model.get_num_tokens( - model="test_model", credentials={"api_key": "abc"}, texts=["hello", "world"] - ) - - assert result == [1, 1] - model_runtime.get_text_embedding_num_tokens.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - texts=["hello", "world"], - ) - - -def test_get_context_size(text_embedding_model: TextEmbeddingModel) -> None: - mock_schema = MagicMock() - mock_schema.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 2048} - - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): - assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 2048 - - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None): - assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 1000 - - mock_schema.model_properties = {} - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): - assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 1000 - - -def test_get_max_chunks(text_embedding_model: TextEmbeddingModel) -> None: - mock_schema = MagicMock() - mock_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10} - - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): - assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 10 - - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None): - assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 1 - - mock_schema.model_properties = {} - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): - assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 1 diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_tts_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_tts_model.py deleted file mode 100644 index d15efb69c3..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_tts_model.py +++ /dev/null @@ -1,83 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.errors.invoke import InvokeError -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel - - -@pytest.fixture -def provider_schema() -> ProviderEntity: - return ProviderEntity( - provider="test_provider", - label=I18nObject(en_US="test_provider"), - supported_model_types=[ModelType.TTS], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - -@pytest.fixture -def model_runtime() -> MagicMock: - return MagicMock() - - -@pytest.fixture -def tts_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> TTSModel: - return TTSModel(provider_schema=provider_schema, model_runtime=model_runtime) - - -def test_model_type(tts_model: TTSModel) -> None: - assert tts_model.model_type == ModelType.TTS - - -def test_invoke_success(tts_model: TTSModel, model_runtime: MagicMock) -> None: - model_runtime.invoke_tts.return_value = [b"audio_chunk"] - - result = tts_model.invoke( - model="test_model", - credentials={"api_key": "abc"}, - content_text="Hello world", - voice="alloy", - ) - - assert list(result) == [b"audio_chunk"] - model_runtime.invoke_tts.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - content_text="Hello world", - voice="alloy", - ) - - -def test_invoke_exception(tts_model: TTSModel, model_runtime: MagicMock) -> None: - model_runtime.invoke_tts.side_effect = Exception("Test error") - - with pytest.raises(InvokeError, match="Test error"): - tts_model.invoke( - model="test_model", - credentials={"api_key": "abc"}, - content_text="Hello world", - voice="alloy", - ) - - -def test_get_tts_model_voices(tts_model: TTSModel, model_runtime: MagicMock) -> None: - model_runtime.get_tts_model_voices.return_value = [{"name": "Voice1"}] - - result = tts_model.get_tts_model_voices( - model="test_model", - credentials={"api_key": "abc"}, - language="en-US", - ) - - assert result == [{"name": "Voice1"}] - model_runtime.get_tts_model_voices.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - language="en-US", - ) diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py deleted file mode 100644 index d4d3eeb18c..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py +++ /dev/null @@ -1,96 +0,0 @@ -from unittest.mock import MagicMock, patch - -import graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer as gpt2_tokenizer_module -from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer - - -class TestGPT2Tokenizer: - def setup_method(self): - # Reset the global tokenizer before each test to ensure we test initialization - gpt2_tokenizer_module._tokenizer = None - - def test_get_encoder_tiktoken(self): - """ - Test that get_encoder successfully uses tiktoken when available. - """ - mock_encoding = MagicMock() - # Mock tiktoken to be sure it's used - with patch("tiktoken.get_encoding", return_value=mock_encoding) as mock_get_encoding: - encoder = GPT2Tokenizer.get_encoder() - assert encoder == mock_encoding - mock_get_encoding.assert_called_once_with("gpt2") - - # Verify singleton behavior within the same test - encoder2 = GPT2Tokenizer.get_encoder() - assert encoder2 is encoder - assert mock_get_encoding.call_count == 1 - - def test_get_encoder_tiktoken_fallback(self): - """ - Test that get_encoder falls back to transformers when tiktoken fails. - """ - # patch tiktoken.get_encoding to raise an exception - with patch("tiktoken.get_encoding", side_effect=Exception("Tiktoken failure")): - # patch transformers.GPT2Tokenizer - with patch("transformers.GPT2Tokenizer.from_pretrained") as mock_from_pretrained: - mock_transformer_tokenizer = MagicMock() - mock_from_pretrained.return_value = mock_transformer_tokenizer - - with patch( - "graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer.logger" - ) as mock_logger: - encoder = GPT2Tokenizer.get_encoder() - - assert encoder == mock_transformer_tokenizer - mock_from_pretrained.assert_called_once() - mock_logger.info.assert_called_once_with("Fallback to Transformers' GPT-2 tokenizer from tiktoken") - - def test_get_num_tokens(self): - """ - Test get_num_tokens returns the correct count. - """ - mock_encoder = MagicMock() - mock_encoder.encode.return_value = [1, 2, 3, 4, 5] - - with patch.object(GPT2Tokenizer, "get_encoder", return_value=mock_encoder): - tokens_count = GPT2Tokenizer.get_num_tokens("test text") - assert tokens_count == 5 - mock_encoder.encode.assert_called_once_with("test text") - - def test_get_num_tokens_by_gpt2_direct(self): - """ - Test _get_num_tokens_by_gpt2 directly. - """ - mock_encoder = MagicMock() - mock_encoder.encode.return_value = [1, 2] - - with patch.object(GPT2Tokenizer, "get_encoder", return_value=mock_encoder): - tokens_count = GPT2Tokenizer._get_num_tokens_by_gpt2("hello") - assert tokens_count == 2 - mock_encoder.encode.assert_called_once_with("hello") - - def test_get_encoder_already_initialized(self): - """ - Test that if _tokenizer is already set, it returns it immediately. - """ - mock_existing_tokenizer = MagicMock() - gpt2_tokenizer_module._tokenizer = mock_existing_tokenizer - - # Tiktoken should not be called if already initialized - with patch("tiktoken.get_encoding") as mock_get_encoding: - encoder = GPT2Tokenizer.get_encoder() - assert encoder == mock_existing_tokenizer - mock_get_encoding.assert_not_called() - - def test_get_encoder_thread_safety(self): - """ - Simple test to ensure the lock is used. - """ - mock_encoding = MagicMock() - with patch("tiktoken.get_encoding", return_value=mock_encoding): - # We patch the lock in the module - with patch("graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer._lock") as mock_lock: - encoder = GPT2Tokenizer.get_encoder() - assert encoder == mock_encoding - mock_lock.__enter__.assert_called_once() - mock_lock.__exit__.assert_called_once() diff --git a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_common_validator.py b/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_common_validator.py deleted file mode 100644 index 60ded4b90a..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_common_validator.py +++ /dev/null @@ -1,201 +0,0 @@ -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.provider_entities import ( - CredentialFormSchema, - FormOption, - FormShowOnObject, - FormType, -) -from graphon.model_runtime.schema_validators.common_validator import CommonValidator - - -class TestCommonValidator: - def test_validate_credential_form_schema_required_missing(self): - validator = CommonValidator() - schema = CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True - ) - with pytest.raises(ValueError, match="Variable api_key is required"): - validator._validate_credential_form_schema(schema, {}) - - def test_validate_credential_form_schema_not_required_missing_with_default(self): - validator = CommonValidator() - schema = CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key"), - type=FormType.TEXT_INPUT, - required=False, - default="default_value", - ) - assert validator._validate_credential_form_schema(schema, {}) == "default_value" - - def test_validate_credential_form_schema_not_required_missing_no_default(self): - validator = CommonValidator() - schema = CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=False - ) - assert validator._validate_credential_form_schema(schema, {}) is None - - def test_validate_credential_form_schema_max_length_exceeded(self): - validator = CommonValidator() - schema = CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, max_length=5 - ) - with pytest.raises(ValueError, match="Variable api_key length should not be greater than 5"): - validator._validate_credential_form_schema(schema, {"api_key": "123456"}) - - def test_validate_credential_form_schema_not_string(self): - validator = CommonValidator() - schema = CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT) - with pytest.raises(ValueError, match="Variable api_key should be string"): - validator._validate_credential_form_schema(schema, {"api_key": 123}) - - def test_validate_credential_form_schema_select_invalid_option(self): - validator = CommonValidator() - schema = CredentialFormSchema( - variable="mode", - label=I18nObject(en_US="Mode"), - type=FormType.SELECT, - options=[ - FormOption(label=I18nObject(en_US="Fast"), value="fast"), - FormOption(label=I18nObject(en_US="Slow"), value="slow"), - ], - ) - with pytest.raises(ValueError, match="Variable mode is not in options"): - validator._validate_credential_form_schema(schema, {"mode": "medium"}) - - def test_validate_credential_form_schema_select_valid_option(self): - validator = CommonValidator() - schema = CredentialFormSchema( - variable="mode", - label=I18nObject(en_US="Mode"), - type=FormType.SELECT, - options=[ - FormOption(label=I18nObject(en_US="Fast"), value="fast"), - FormOption(label=I18nObject(en_US="Slow"), value="slow"), - ], - ) - assert validator._validate_credential_form_schema(schema, {"mode": "fast"}) == "fast" - - def test_validate_credential_form_schema_switch_invalid(self): - validator = CommonValidator() - schema = CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH) - with pytest.raises(ValueError, match="Variable enabled should be true or false"): - validator._validate_credential_form_schema(schema, {"enabled": "maybe"}) - - def test_validate_credential_form_schema_switch_valid(self): - validator = CommonValidator() - schema = CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH) - assert validator._validate_credential_form_schema(schema, {"enabled": "true"}) is True - assert validator._validate_credential_form_schema(schema, {"enabled": "FALSE"}) is False - - def test_validate_and_filter_credential_form_schemas_with_show_on(self): - validator = CommonValidator() - schemas = [ - CredentialFormSchema( - variable="auth_type", - label=I18nObject(en_US="Auth Type"), - type=FormType.SELECT, - options=[ - FormOption(label=I18nObject(en_US="API Key"), value="api_key"), - FormOption(label=I18nObject(en_US="OAuth"), value="oauth"), - ], - ), - CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key"), - type=FormType.TEXT_INPUT, - show_on=[FormShowOnObject(variable="auth_type", value="api_key")], - ), - CredentialFormSchema( - variable="client_id", - label=I18nObject(en_US="Client ID"), - type=FormType.TEXT_INPUT, - show_on=[FormShowOnObject(variable="auth_type", value="oauth")], - ), - ] - - # Case 1: auth_type = api_key - credentials = {"auth_type": "api_key", "api_key": "my_secret"} - result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) - assert "auth_type" in result - assert "api_key" in result - assert "client_id" not in result - assert result["api_key"] == "my_secret" - - # Case 2: auth_type = oauth - credentials = {"auth_type": "oauth", "client_id": "my_client"} - result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) - # Note: 'auth_type' contains 'oauth'. 'result' contains keys that pass validation. - # Since 'oauth' is not an empty string, it is in result. - assert "auth_type" in result - assert "api_key" not in result - assert "client_id" in result - assert result["client_id"] == "my_client" - - def test_validate_and_filter_show_on_missing_variable(self): - validator = CommonValidator() - schemas = [ - CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key"), - type=FormType.TEXT_INPUT, - show_on=[FormShowOnObject(variable="auth_type", value="api_key")], - ) - ] - # auth_type is missing in credentials, so api_key should be filtered out - result = validator._validate_and_filter_credential_form_schemas(schemas, {}) - assert result == {} - - def test_validate_and_filter_show_on_mismatch_value(self): - validator = CommonValidator() - schemas = [ - CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key"), - type=FormType.TEXT_INPUT, - show_on=[FormShowOnObject(variable="auth_type", value="api_key")], - ) - ] - # auth_type is oauth, which doesn't match show_on - result = validator._validate_and_filter_credential_form_schemas(schemas, {"auth_type": "oauth"}) - assert result == {} - - def test_validate_and_filter_multiple_show_on(self): - validator = CommonValidator() - schemas = [ - CredentialFormSchema( - variable="target", - label=I18nObject(en_US="Target"), - type=FormType.TEXT_INPUT, - show_on=[FormShowOnObject(variable="v1", value="a"), FormShowOnObject(variable="v2", value="b")], - ) - ] - # Both match - assert "target" in validator._validate_and_filter_credential_form_schemas( - schemas, {"v1": "a", "v2": "b", "target": "val"} - ) - # One mismatch - assert "target" not in validator._validate_and_filter_credential_form_schemas( - schemas, {"v1": "a", "v2": "c", "target": "val"} - ) - # One missing - assert "target" not in validator._validate_and_filter_credential_form_schemas( - schemas, {"v1": "a", "target": "val"} - ) - - def test_validate_and_filter_skips_falsy_results(self): - validator = CommonValidator() - schemas = [ - CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH), - CredentialFormSchema( - variable="empty_str", label=I18nObject(en_US="Empty"), type=FormType.TEXT_INPUT, required=False - ), - ] - # Result of false switch is False. if result: is false. Not added. - # Result of empty string is "", if result: is false. Not added. - credentials = {"enabled": "false", "empty_str": ""} - result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) - assert "enabled" not in result - assert "empty_str" not in result diff --git a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_model_credential_schema_validator.py b/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_model_credential_schema_validator.py deleted file mode 100644 index 3932844b91..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_model_credential_schema_validator.py +++ /dev/null @@ -1,233 +0,0 @@ -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - CredentialFormSchema, - FieldModelSchema, - FormOption, - FormShowOnObject, - FormType, - ModelCredentialSchema, -) -from graphon.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator - - -def test_validate_and_filter_with_none_schema(): - validator = ModelCredentialSchemaValidator(ModelType.LLM, None) - with pytest.raises(ValueError, match="Model credential schema is None"): - validator.validate_and_filter({}) - - -def test_validate_and_filter_success(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key", zh_Hans="API Key"), - type=FormType.SECRET_INPUT, - required=True, - ), - CredentialFormSchema( - variable="optional_field", - label=I18nObject(en_US="Optional", zh_Hans="可选"), - type=FormType.TEXT_INPUT, - required=False, - default="default_val", - ), - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - credentials = {"api_key": "sk-123456"} - result = validator.validate_and_filter(credentials) - - assert result["api_key"] == "sk-123456" - assert result["optional_field"] == "default_val" - assert credentials["__model_type"] == ModelType.LLM.value - - -def test_validate_and_filter_with_show_on(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="mode", label=I18nObject(en_US="Mode", zh_Hans="模式"), type=FormType.TEXT_INPUT, required=True - ), - CredentialFormSchema( - variable="conditional_field", - label=I18nObject(en_US="Conditional", zh_Hans="条件"), - type=FormType.TEXT_INPUT, - required=True, - show_on=[FormShowOnObject(variable="mode", value="advanced")], - ), - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - # mode is 'simple', conditional_field should be filtered out - credentials = {"mode": "simple", "conditional_field": "secret"} - result = validator.validate_and_filter(credentials) - assert "conditional_field" not in result - assert result["mode"] == "simple" - - # mode is 'advanced', conditional_field should be kept - credentials = {"mode": "advanced", "conditional_field": "secret"} - result = validator.validate_and_filter(credentials) - assert result["conditional_field"] == "secret" - assert result["mode"] == "advanced" - - # show_on variable missing in credentials - credentials = {"conditional_field": "secret"} # mode missing - with pytest.raises(ValueError, match="Variable mode is required"): # because mode is required in schema - validator.validate_and_filter(credentials) - - -def test_validate_and_filter_show_on_missing_trigger_var(): - # specifically test all_show_on_match = False when variable not in credentials - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="optional_trigger", - label=I18nObject(en_US="Optional Trigger", zh_Hans="可选触发"), - type=FormType.TEXT_INPUT, - required=False, - ), - CredentialFormSchema( - variable="conditional_field", - label=I18nObject(en_US="Conditional", zh_Hans="条件"), - type=FormType.TEXT_INPUT, - required=False, - show_on=[FormShowOnObject(variable="optional_trigger", value="active")], - ), - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - # optional_trigger missing, conditional_field should be skipped - result = validator.validate_and_filter({"conditional_field": "val"}) - assert "conditional_field" not in result - - -def test_common_validator_logic_required(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key", zh_Hans="API Key"), - type=FormType.SECRET_INPUT, - required=True, - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - with pytest.raises(ValueError, match="Variable api_key is required"): - validator.validate_and_filter({}) - - with pytest.raises(ValueError, match="Variable api_key is required"): - validator.validate_and_filter({"api_key": ""}) - - -def test_common_validator_logic_max_length(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="key", - label=I18nObject(en_US="Key", zh_Hans="Key"), - type=FormType.TEXT_INPUT, - required=True, - max_length=5, - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - with pytest.raises(ValueError, match="Variable key length should not be greater than 5"): - validator.validate_and_filter({"key": "123456"}) - - -def test_common_validator_logic_invalid_type(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="key", label=I18nObject(en_US="Key", zh_Hans="Key"), type=FormType.TEXT_INPUT, required=True - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - with pytest.raises(ValueError, match="Variable key should be string"): - validator.validate_and_filter({"key": 123}) - - -def test_common_validator_logic_switch(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="enabled", - label=I18nObject(en_US="Enabled", zh_Hans="启用"), - type=FormType.SWITCH, - required=True, - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - result = validator.validate_and_filter({"enabled": "true"}) - assert result["enabled"] is True - - result = validator.validate_and_filter({"enabled": "false"}) - assert "enabled" not in result - - with pytest.raises(ValueError, match="Variable enabled should be true or false"): - validator.validate_and_filter({"enabled": "not_a_bool"}) - - -def test_common_validator_logic_options(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="choice", - label=I18nObject(en_US="Choice", zh_Hans="选择"), - type=FormType.SELECT, - required=True, - options=[ - FormOption(label=I18nObject(en_US="A", zh_Hans="A"), value="a"), - FormOption(label=I18nObject(en_US="B", zh_Hans="B"), value="b"), - ], - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - result = validator.validate_and_filter({"choice": "a"}) - assert result["choice"] == "a" - - with pytest.raises(ValueError, match="Variable choice is not in options"): - validator.validate_and_filter({"choice": "c"}) - - -def test_validate_and_filter_optional_no_default(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="optional", - label=I18nObject(en_US="Optional", zh_Hans="可选"), - type=FormType.TEXT_INPUT, - required=False, - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - result = validator.validate_and_filter({}) - assert "optional" not in result diff --git a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_provider_credential_schema_validator.py b/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_provider_credential_schema_validator.py deleted file mode 100644 index f7a2a5b623..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_provider_credential_schema_validator.py +++ /dev/null @@ -1,72 +0,0 @@ -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderCredentialSchema -from graphon.model_runtime.schema_validators.provider_credential_schema_validator import ( - ProviderCredentialSchemaValidator, -) - - -class TestProviderCredentialSchemaValidator: - def test_validate_and_filter_success(self): - # Setup schema - schema = ProviderCredentialSchema( - credential_form_schemas=[ - CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True - ), - CredentialFormSchema( - variable="endpoint", - label=I18nObject(en_US="Endpoint"), - type=FormType.TEXT_INPUT, - required=False, - default="https://api.example.com", - ), - ] - ) - validator = ProviderCredentialSchemaValidator(schema) - - # Test valid credentials - credentials = {"api_key": "my-secret-key"} - result = validator.validate_and_filter(credentials) - - assert result == {"api_key": "my-secret-key", "endpoint": "https://api.example.com"} - - def test_validate_and_filter_missing_required(self): - # Setup schema - schema = ProviderCredentialSchema( - credential_form_schemas=[ - CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True - ) - ] - ) - validator = ProviderCredentialSchemaValidator(schema) - - # Test missing required credentials - with pytest.raises(ValueError, match="Variable api_key is required"): - validator.validate_and_filter({}) - - def test_validate_and_filter_extra_fields_filtered(self): - # Setup schema - schema = ProviderCredentialSchema( - credential_form_schemas=[ - CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True - ) - ] - ) - validator = ProviderCredentialSchemaValidator(schema) - - # Test credentials with extra fields - credentials = {"api_key": "my-secret-key", "extra_field": "should-be-filtered"} - result = validator.validate_and_filter(credentials) - - assert "api_key" in result - assert "extra_field" not in result - assert result == {"api_key": "my-secret-key"} - - def test_init(self): - schema = ProviderCredentialSchema(credential_form_schemas=[]) - validator = ProviderCredentialSchemaValidator(schema) - assert validator.provider_credential_schema == schema diff --git a/api/tests/unit_tests/graphon/model_runtime/utils/test_encoders.py b/api/tests/unit_tests/graphon/model_runtime/utils/test_encoders.py deleted file mode 100644 index 8edc143fae..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/utils/test_encoders.py +++ /dev/null @@ -1,231 +0,0 @@ -import dataclasses -import datetime -from collections import deque -from decimal import Decimal -from enum import Enum -from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network -from pathlib import Path, PurePath -from re import compile -from typing import Any -from unittest.mock import MagicMock -from uuid import UUID - -import pytest -from pydantic import BaseModel, ConfigDict -from pydantic.networks import AnyUrl, NameEmail -from pydantic.types import SecretBytes, SecretStr -from pydantic_core import Url -from pydantic_extra_types.color import Color - -from graphon.model_runtime.utils.encoders import ( - _model_dump, - decimal_encoder, - generate_encoders_by_class_tuples, - isoformat, - jsonable_encoder, -) - - -class MockEnum(Enum): - A = "a" - B = "b" - - -class MockPydanticModel(BaseModel): - model_config = ConfigDict(populate_by_name=True) - name: str - age: int - - -@dataclasses.dataclass -class MockDataclass: - name: str - value: Any - - -class MockWithDict: - def __init__(self, data): - self.data = data - - def __iter__(self): - return iter(self.data.items()) - - -class MockWithVars: - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - - -class TestEncoders: - def test_model_dump(self): - model = MockPydanticModel(name="test", age=20) - result = _model_dump(model) - assert result == {"name": "test", "age": 20} - - def test_isoformat(self): - d = datetime.date(2023, 1, 1) - assert isoformat(d) == "2023-01-01" - t = datetime.time(12, 0, 0) - assert isoformat(t) == "12:00:00" - - def test_decimal_encoder(self): - assert decimal_encoder(Decimal("1.0")) == 1.0 - assert decimal_encoder(Decimal(1)) == 1 - assert decimal_encoder(Decimal("1.5")) == 1.5 - assert decimal_encoder(Decimal(0)) == 0 - assert decimal_encoder(Decimal(-1)) == -1 - - def test_generate_encoders_by_class_tuples(self): - type_map = {int: str, float: str, str: int} - result = generate_encoders_by_class_tuples(type_map) - assert result[str] == (int, float) - assert result[int] == (str,) - - def test_jsonable_encoder_basic_types(self): - assert jsonable_encoder("string") == "string" - assert jsonable_encoder(123) == 123 - assert jsonable_encoder(1.23) == 1.23 - assert jsonable_encoder(None) is None - - def test_jsonable_encoder_pydantic(self): - model = MockPydanticModel(name="test", age=20) - assert jsonable_encoder(model) == {"name": "test", "age": 20} - - def test_jsonable_encoder_pydantic_root(self): - # Manually create a mock that behaves like a model with __root__ - # because Pydantic v2 handles root differently, but the code checks for "__root__" - model = MagicMock(spec=BaseModel) - # _model_dump(obj, mode="json", ...) -> model.model_dump(mode="json", ...) - model.model_dump.return_value = {"__root__": [1, 2, 3]} - assert jsonable_encoder(model) == [1, 2, 3] - - def test_jsonable_encoder_dataclass(self): - obj = MockDataclass(name="test", value=1) - assert jsonable_encoder(obj) == {"name": "test", "value": 1} - # Test dataclass type (should not be treated as instance) - # It should fall back to vars() or dict() or at least not crash - with pytest.raises(ValueError): - jsonable_encoder(MockDataclass) - - def test_jsonable_encoder_enum(self): - assert jsonable_encoder(MockEnum.A) == "a" - - def test_jsonable_encoder_path(self): - assert jsonable_encoder(Path("/tmp/test")) == "/tmp/test" - assert jsonable_encoder(PurePath("/tmp/test")) == "/tmp/test" - - def test_jsonable_encoder_decimal(self): - # In jsonable_encoder, Decimal is formatted as string via format(obj, "f") - assert jsonable_encoder(Decimal("1.23")) == "1.23" - assert jsonable_encoder(Decimal("1.000")) == "1.000" - - def test_jsonable_encoder_dict(self): - d = {"a": 1, "b": [2, 3], "_private": "hidden"} - assert jsonable_encoder(d) == {"a": 1, "b": [2, 3], "_private": "hidden"} - assert jsonable_encoder(d, excluded_key_prefixes=("_",)) == {"a": 1, "b": [2, 3]} - - d_with_none = {"a": 1, "b": None} - assert jsonable_encoder(d_with_none, exclude_none=True) == {"a": 1} - assert jsonable_encoder(d_with_none, exclude_none=False) == {"a": 1, "b": None} - - def test_jsonable_encoder_collections(self): - assert jsonable_encoder([1, 2]) == [1, 2] - assert jsonable_encoder((1, 2)) == [1, 2] - assert jsonable_encoder({1, 2}) == [1, 2] - assert jsonable_encoder(frozenset([1, 2])) == [1, 2] - assert jsonable_encoder(deque([1, 2])) == [1, 2] - - def gen(): - yield 1 - yield 2 - - assert jsonable_encoder(gen()) == [1, 2] - - def test_jsonable_encoder_custom_encoder(self): - custom = {int: lambda x: str(x + 1)} - assert jsonable_encoder(1, custom_encoder=custom) == "2" - - # Test subclass matching for custom encoder - class SubInt(int): - pass - - assert jsonable_encoder(SubInt(1), custom_encoder=custom) == "2" - - def test_jsonable_encoder_special_types(self): - # These hit ENCODERS_BY_TYPE or encoders_by_class_tuples - assert jsonable_encoder(b"bytes") == "bytes" - assert jsonable_encoder(Color("red")) == "red" - - dt = datetime.datetime(2023, 1, 1, 12, 0, 0) - assert jsonable_encoder(dt) == dt.isoformat() - - date = datetime.date(2023, 1, 1) - assert jsonable_encoder(date) == date.isoformat() - - time = datetime.time(12, 0, 0) - assert jsonable_encoder(time) == time.isoformat() - - td = datetime.timedelta(seconds=60) - assert jsonable_encoder(td) == 60.0 - - assert jsonable_encoder(IPv4Address("127.0.0.1")) == "127.0.0.1" - assert jsonable_encoder(IPv4Interface("127.0.0.1/24")) == "127.0.0.1/24" - assert jsonable_encoder(IPv4Network("127.0.0.0/24")) == "127.0.0.0/24" - assert jsonable_encoder(IPv6Address("::1")) == "::1" - assert jsonable_encoder(IPv6Interface("::1/128")) == "::1/128" - assert jsonable_encoder(IPv6Network("::/128")) == "::/128" - - assert jsonable_encoder(NameEmail(name="test", email="test@example.com")) == "test " - - assert jsonable_encoder(compile("abc")) == "abc" - - # Secret types - # Check what they actually return in this environment - res_bytes = jsonable_encoder(SecretBytes(b"secret")) - assert "**********" in res_bytes - - res_str = jsonable_encoder(SecretStr("secret")) - assert res_str == "**********" - - u = UUID("12345678-1234-5678-1234-567812345678") - assert jsonable_encoder(u) == str(u) - - url = AnyUrl("https://example.com") - assert jsonable_encoder(url) == "https://example.com/" - - purl = Url("https://example.com") - assert jsonable_encoder(purl) == "https://example.com/" - - def test_jsonable_encoder_fallback(self): - # dict(obj) success - obj_dict = MockWithDict({"a": 1}) - assert jsonable_encoder(obj_dict) == {"a": 1} - - # vars(obj) success - obj_vars = MockWithVars(x=10, y=20) - assert jsonable_encoder(obj_vars) == {"x": 10, "y": 20} - - # error fallback - class ReallyUnserializable: - __slots__ = ["__weakref__"] # No __dict__ - - def __iter__(self): - raise TypeError("not iterable") - - with pytest.raises(ValueError) as exc: - jsonable_encoder(ReallyUnserializable()) - assert "not iterable" in str(exc.value) - - def test_jsonable_encoder_nested(self): - data = { - "model": MockPydanticModel(name="test", age=20), - "list": [Decimal("1.1"), {MockEnum.A: Path("/tmp")}], - "set": {1, 2}, - } - expected = { - "model": {"name": "test", "age": 20}, - "list": ["1.1", {"a": "/tmp"}], - "set": [1, 2], - } - assert jsonable_encoder(data) == expected diff --git a/api/tests/unit_tests/graphon/node_events/test_base.py b/api/tests/unit_tests/graphon/node_events/test_base.py deleted file mode 100644 index 4ff1270265..0000000000 --- a/api/tests/unit_tests/graphon/node_events/test_base.py +++ /dev/null @@ -1,19 +0,0 @@ -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.node_events.base import NodeRunResult - - -def test_node_run_result_accepts_trigger_info_metadata() -> None: - result = NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - metadata={ - WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: { - "provider_id": "provider-id", - "event_name": "event-name", - } - }, - ) - - assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == { - "provider_id": "provider-id", - "event_name": "event-name", - } diff --git a/api/tests/unit_tests/graphon/utils/test_json_in_md_parser.py b/api/tests/unit_tests/graphon/utils/test_json_in_md_parser.py deleted file mode 100644 index a8c86d288c..0000000000 --- a/api/tests/unit_tests/graphon/utils/test_json_in_md_parser.py +++ /dev/null @@ -1,75 +0,0 @@ -import pytest - -from graphon.utils.json_in_md_parser import ( - OutputParserError, - parse_and_check_json_markdown, - parse_json_markdown, -) - - -def test_parse_json_markdown_extracts_fenced_json_object() -> None: - src = """ - ```json - {"a": 1, "b": "x"} - ``` - """ - - assert parse_json_markdown(src) == {"a": 1, "b": "x"} - - -def test_parse_json_markdown_extracts_raw_json_array() -> None: - assert parse_json_markdown('[{"a": 1}]') == {"a": 1} - - -def test_parse_json_markdown_raises_when_no_json_block_exists() -> None: - with pytest.raises(ValueError, match="could not find json block"): - parse_json_markdown("plain text only") - - -def test_parse_and_check_json_markdown_unwraps_single_dict_list() -> None: - parsed = parse_and_check_json_markdown( - """ - ```json - [{"present": 1, "other": 2}] - ``` - """, - ["present"], - ) - - assert parsed == {"present": 1, "other": 2} - - -def test_parse_and_check_json_markdown_rejects_invalid_json() -> None: - with pytest.raises(OutputParserError, match="got invalid json object"): - parse_and_check_json_markdown( - """ - ```json - {invalid json} - ``` - """, - [], - ) - - -def test_parse_and_check_json_markdown_rejects_invalid_return_shapes() -> None: - with pytest.raises(OutputParserError, match="got invalid return object"): - parse_and_check_json_markdown( - """ - ```json - [1, 2] - ``` - """, - ["present"], - ) - - -def test_parse_and_check_json_markdown_requires_expected_keys() -> None: - with pytest.raises(OutputParserError, match="expected key `missing`"): - parse_and_check_json_markdown( - """ - ```json - {"present": 1} - ``` - """, - ["present", "missing"], - ) diff --git a/api/tests/unit_tests/libs/_human_input/support.py b/api/tests/unit_tests/libs/_human_input/support.py index e6cc23161e..13577b7ca5 100644 --- a/api/tests/unit_tests/libs/_human_input/support.py +++ b/api/tests/unit_tests/libs/_human_input/support.py @@ -6,6 +6,7 @@ from typing import Any from graphon.nodes.human_input.entities import FormInput from graphon.nodes.human_input.enums import TimeoutUnit + from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/libs/_human_input/test_form_service.py b/api/tests/unit_tests/libs/_human_input/test_form_service.py index fa2c02020b..f1ce1a2c1c 100644 --- a/api/tests/unit_tests/libs/_human_input/test_form_service.py +++ b/api/tests/unit_tests/libs/_human_input/test_form_service.py @@ -5,7 +5,6 @@ Unit tests for FormService. from datetime import timedelta import pytest - from graphon.nodes.human_input.entities import ( FormInput, UserAction, @@ -14,6 +13,7 @@ from graphon.nodes.human_input.enums import ( FormInputType, TimeoutUnit, ) + from libs.datetime_utils import naive_utc_now from .support import ( diff --git a/api/tests/unit_tests/libs/_human_input/test_models.py b/api/tests/unit_tests/libs/_human_input/test_models.py index 866ee61b3e..0babfbb315 100644 --- a/api/tests/unit_tests/libs/_human_input/test_models.py +++ b/api/tests/unit_tests/libs/_human_input/test_models.py @@ -5,7 +5,6 @@ Unit tests for human input form models. from datetime import datetime, timedelta import pytest - from graphon.nodes.human_input.entities import ( FormInput, UserAction, @@ -14,6 +13,7 @@ from graphon.nodes.human_input.enums import ( FormInputType, TimeoutUnit, ) + from libs.datetime_utils import naive_utc_now from .support import FormSubmissionData, FormSubmissionRequest, HumanInputForm diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py index bb3a6db1a1..86163f1554 100644 --- a/api/tests/unit_tests/models/test_conversation_variable.py +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -1,7 +1,8 @@ from uuid import uuid4 -from factories import variable_factory from graphon.variables import SegmentType + +from factories import variable_factory from models import ConversationVariable diff --git a/api/tests/unit_tests/models/test_model.py b/api/tests/unit_tests/models/test_model.py index e21f0e4fbd..a5909f60a8 100644 --- a/api/tests/unit_tests/models/test_model.py +++ b/api/tests/unit_tests/models/test_model.py @@ -2,9 +2,9 @@ import importlib import types import pytest +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod from core.workflow.file_reference import build_file_reference -from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod from models.model import Conversation, Message diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index 550441539a..e7c0479757 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -3,14 +3,14 @@ import json from unittest import mock from uuid import uuid4 +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable +from graphon.variables.segments import IntegerSegment, Segment + from constants import HIDDEN_VALUE from core.helper import encrypter from core.workflow.file_reference import build_file_reference from factories.variable_factory import build_segment -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File -from graphon.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable -from graphon.variables.segments import IntegerSegment, Segment from models.workflow import ( Workflow, WorkflowDraftVariable, diff --git a/api/tests/unit_tests/models/test_workflow_models.py b/api/tests/unit_tests/models/test_workflow_models.py index eb9fef7587..507e1c8c3a 100644 --- a/api/tests/unit_tests/models/test_workflow_models.py +++ b/api/tests/unit_tests/models/test_workflow_models.py @@ -13,12 +13,12 @@ from datetime import UTC, datetime from uuid import uuid4 import pytest - from graphon.enums import ( BuiltinNodeTypes, WorkflowExecutionStatus, WorkflowNodeExecutionStatus, ) + from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import ( Workflow, diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py index ccc9c93815..10850970d8 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -9,11 +9,6 @@ from decimal import Decimal from unittest.mock import MagicMock, PropertyMock import pytest -from pytest_mock import MockerFixture -from sqlalchemy.orm import Session, sessionmaker - -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.repositories.factory import OrderConfig from graphon.entities import ( WorkflowNodeExecution, ) @@ -23,6 +18,11 @@ from graphon.enums import ( WorkflowNodeExecutionStatus, ) from graphon.model_runtime.utils.encoders import jsonable_encoder +from pytest_mock import MockerFixture +from sqlalchemy.orm import Session, sessionmaker + +from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig from models.account import Account, Tenant from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py index e8c094b75d..2322be9e80 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py @@ -6,13 +6,13 @@ from datetime import datetime from typing import Any from unittest.mock import MagicMock, Mock +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from graphon.entities.workflow_node_execution import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/services/dataset_service_test_helpers.py b/api/tests/unit_tests/services/dataset_service_test_helpers.py index c95b60fad0..ef73bc0e01 100644 --- a/api/tests/unit_tests/services/dataset_service_test_helpers.py +++ b/api/tests/unit_tests/services/dataset_service_test_helpers.py @@ -10,6 +10,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, create_autospec, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from werkzeug.exceptions import Forbidden, NotFound from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError @@ -17,7 +18,6 @@ from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod from enums.cloud_plan import CloudPlan -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from models import Account, TenantAccountRole from models.dataset import ( ChildChunk, diff --git a/api/tests/unit_tests/services/document_service_validation.py b/api/tests/unit_tests/services/document_service_validation.py index 3358c8b44d..7c36e9d960 100644 --- a/api/tests/unit_tests/services/document_service_validation.py +++ b/api/tests/unit_tests/services/document_service_validation.py @@ -109,10 +109,10 @@ This test suite follows a comprehensive testing strategy that covers: from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelType from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, DatasetProcessRule, Document from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import ( diff --git a/api/tests/unit_tests/services/test_app_dsl_service.py b/api/tests/unit_tests/services/test_app_dsl_service.py index afea8ec92a..179518a5fa 100644 --- a/api/tests/unit_tests/services/test_app_dsl_service.py +++ b/api/tests/unit_tests/services/test_app_dsl_service.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock import pytest import yaml +from graphon.enums import BuiltinNodeTypes from core.trigger.constants import ( TRIGGER_PLUGIN_NODE_TYPE, TRIGGER_SCHEDULE_NODE_TYPE, TRIGGER_WEBHOOK_NODE_TYPE, ) -from graphon.enums import BuiltinNodeTypes from models import Account, AppMode from models.model import IconType from services import app_dsl_service diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py index da93239600..3df7d500cf 100644 --- a/api/tests/unit_tests/services/test_datasource_provider_service.py +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -1,10 +1,10 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.provider_entities import FormType from sqlalchemy.orm import Session from core.plugin.entities.plugin_daemon import CredentialType -from graphon.model_runtime.entities.provider_entities import FormType from models.account import Account from models.model import EndUser from models.oauth import DatasourceProvider diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py index 55af564821..9be475d043 100644 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -3,18 +3,18 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest - -import services.human_input_service as human_input_service_module -from core.repositories.human_input_repository import ( - HumanInputFormRecord, - HumanInputFormSubmissionRepository, -) from graphon.nodes.human_input.entities import ( FormDefinition, FormInput, UserAction, ) from graphon.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus + +import services.human_input_service as human_input_service_module +from core.repositories.human_input_repository import ( + HumanInputFormRecord, + HumanInputFormSubmissionRepository, +) from libs.datetime_utils import naive_utc_now from models.human_input import RecipientType from services.human_input_service import ( diff --git a/api/tests/unit_tests/services/test_model_load_balancing_service.py b/api/tests/unit_tests/services/test_model_load_balancing_service.py index 1e898ada11..b43e79dff5 100644 --- a/api/tests/unit_tests/services/test_model_load_balancing_service.py +++ b/api/tests/unit_tests/services/test_model_load_balancing_service.py @@ -6,9 +6,6 @@ from typing import Any, cast from unittest.mock import MagicMock import pytest -from pytest_mock import MockerFixture - -from constants import HIDDEN_VALUE from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.provider_entities import ( @@ -18,6 +15,9 @@ from graphon.model_runtime.entities.provider_entities import ( ModelCredentialSchema, ProviderCredentialSchema, ) +from pytest_mock import MockerFixture + +from constants import HIDDEN_VALUE from models.provider import LoadBalancingModelConfig from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py index 97f3bd6f01..1bd979b9ec 100644 --- a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -1,11 +1,11 @@ import types import pytest - -from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.provider_entities import ConfigurateMethod + +from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration from models.provider import ProviderType from services.model_provider_service import ModelProviderService diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py index 2fe6161785..9c23135225 100644 --- a/api/tests/unit_tests/services/test_variable_truncator.py +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -16,9 +16,7 @@ from typing import Any from uuid import uuid4 import pytest - -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File +from graphon.file import File, FileTransferMethod, FileType from graphon.variables.segments import ( ArrayFileSegment, ArrayNumberSegment, @@ -30,6 +28,7 @@ from graphon.variables.segments import ( ObjectSegment, StringSegment, ) + from services.variable_truncator import ( DummyVariableTruncator, MaxDepthExceededError, diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index 239cc83518..a62c9f4555 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -13,10 +13,10 @@ from datetime import datetime from unittest.mock import MagicMock, create_autospec, patch import pytest +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker -from graphon.enums import WorkflowExecutionStatus from models.workflow import WorkflowPause from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index da606c8329..cd71981bcf 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -15,7 +15,6 @@ from typing import Any, cast from unittest.mock import ANY, MagicMock, patch import pytest - from graphon.entities import WorkflowNodeExecution from graphon.enums import ( BuiltinNodeTypes, @@ -29,6 +28,7 @@ from graphon.model_runtime.entities.model_entities import ModelType from graphon.node_events import NodeRunResult from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig from graphon.variables.input_entities import VariableEntityType + from libs.datetime_utils import naive_utc_now from models.human_input import RecipientType from models.model import App, AppMode diff --git a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py index 2db83576b0..8525672da8 100644 --- a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py +++ b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py @@ -4,13 +4,12 @@ import json from unittest.mock import Mock, patch import pytest +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segments import ObjectSegment, StringSegment +from graphon.variables.types import SegmentType from sqlalchemy import Engine from core.workflow.file_reference import build_file_reference -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File -from graphon.variables.segments import ObjectSegment, StringSegment -from graphon.variables.types import SegmentType from models.model import UploadFile from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile from services.workflow_draft_variable_service import DraftVarLoader diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index 6200c9f859..e7e72793a3 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -4,6 +4,10 @@ import uuid from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.enums import BuiltinNodeTypes +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from sqlalchemy import Engine from sqlalchemy.orm import Session @@ -13,11 +17,6 @@ from core.workflow.variable_prefixes import ( ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) -from graphon.enums import BuiltinNodeTypes -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType from libs.uuid_utils import uuidv7 from models.account import Account from models.enums import DraftVariableType diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py index ce66b78b64..077a7c27a2 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py @@ -8,13 +8,13 @@ from datetime import UTC, datetime from threading import Event import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool from models.enums import CreatorUserRole from models.model import AppMode from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py index d7192994b2..98d057e41f 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py @@ -3,6 +3,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from sqlalchemy.orm import sessionmaker from core.workflow.human_input_compat import ( @@ -12,9 +15,6 @@ from core.workflow.human_input_compat import ( ExternalRecipient, MemberRecipient, ) -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.human_input.entities import HumanInputNodeData from services import workflow_service as workflow_service_module from services.workflow_service import WorkflowService diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 6b04a1bc09..b9d097350b 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -3,11 +3,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest - from graphon.entities.graph_config import NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes from graphon.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction from graphon.nodes.human_input.enums import FormInputType + from models.model import App from models.workflow import Workflow from services import workflow_service as workflow_service_module diff --git a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py index 591da56f49..7119217e94 100644 --- a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py +++ b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py @@ -5,8 +5,8 @@ from types import SimpleNamespace from typing import Any import pytest - from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus + from tasks import human_input_timeout_tasks as task_module diff --git a/api/tests/unit_tests/tools/test_mcp_tool.py b/api/tests/unit_tests/tools/test_mcp_tool.py index f31bf80046..68359ba078 100644 --- a/api/tests/unit_tests/tools/test_mcp_tool.py +++ b/api/tests/unit_tests/tools/test_mcp_tool.py @@ -3,6 +3,7 @@ from decimal import Decimal from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMUsage from core.mcp.types import ( AudioContent, @@ -17,7 +18,6 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.mcp_tool.tool import MCPTool -from graphon.model_runtime.entities.llm_entities import LLMUsage def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool: diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index c166a946d9..ffa6833524 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -2,9 +2,6 @@ from decimal import Decimal from unittest.mock import MagicMock, patch import pytest - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, @@ -21,6 +18,9 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output + def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage: """Create a mock LLMUsage with all required fields""" diff --git a/api/tests/workflow_test_utils.py b/api/tests/workflow_test_utils.py index a29df0bb6b..d33ac2c710 100644 --- a/api/tests/workflow_test_utils.py +++ b/api/tests/workflow_test_utils.py @@ -1,12 +1,13 @@ from collections.abc import Mapping from typing import Any -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context -from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool -from graphon.entities.graph_init_params import GraphInitParams +from graphon.entities import GraphInitParams from graphon.runtime import VariablePool from graphon.variables.variables import Variable +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool + def build_test_run_context( *, diff --git a/api/uv.lock b/api/uv.lock index ed2b76ac3c..e0bd0de84d 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1489,12 +1489,12 @@ dependencies = [ { name = "google-auth-httplib2" }, { name = "google-cloud-aiplatform" }, { name = "googleapis-common-protos" }, + { name = "graphon" }, { name = "gunicorn" }, { name = "httpx", extra = ["socks"] }, { name = "httpx-sse" }, { name = "jieba" }, { name = "json-repair" }, - { name = "jsonschema" }, { name = "langfuse" }, { name = "langsmith" }, { name = "litellm" }, @@ -1526,7 +1526,6 @@ dependencies = [ { name = "psycopg2-binary" }, { name = "pycryptodome" }, { name = "pydantic" }, - { name = "pydantic-extra-types" }, { name = "pydantic-settings" }, { name = "pyjwt" }, { name = "pypandoc" }, @@ -1547,7 +1546,6 @@ dependencies = [ { name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] }, { name = "weave" }, { name = "weaviate-client" }, - { name = "webvtt-py" }, { name = "yarl" }, ] @@ -1590,7 +1588,6 @@ dev = [ { name = "types-greenlet" }, { name = "types-html5lib" }, { name = "types-jmespath" }, - { name = "types-jsonschema" }, { name = "types-markdown" }, { name = "types-oauthlib" }, { name = "types-objgraph" }, @@ -1692,12 +1689,12 @@ requires-dist = [ { name = "google-auth-httplib2", specifier = "==0.3.0" }, { name = "google-cloud-aiplatform", specifier = ">=1.123.0" }, { name = "googleapis-common-protos", specifier = ">=1.65.0" }, + { name = "graphon", specifier = ">=0.1.2" }, { name = "gunicorn", specifier = "~=25.1.0" }, { name = "httpx", extras = ["socks"], specifier = "~=0.28.0" }, { name = "httpx-sse", specifier = "~=0.4.0" }, { name = "jieba", specifier = "==0.42.1" }, { name = "json-repair", specifier = ">=0.55.1" }, - { name = "jsonschema", specifier = ">=4.25.1" }, { name = "langfuse", specifier = "~=2.51.3" }, { name = "langsmith", specifier = "~=0.7.16" }, { name = "litellm", specifier = "==1.82.6" }, @@ -1729,7 +1726,6 @@ requires-dist = [ { name = "psycopg2-binary", specifier = "~=2.9.6" }, { name = "pycryptodome", specifier = "==3.23.0" }, { name = "pydantic", specifier = "~=2.12.5" }, - { name = "pydantic-extra-types", specifier = "~=2.11.0" }, { name = "pydantic-settings", specifier = "~=2.13.1" }, { name = "pyjwt", specifier = "~=2.12.0" }, { name = "pypandoc", specifier = "~=1.13" }, @@ -1750,7 +1746,6 @@ requires-dist = [ { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.21.5" }, { name = "weave", specifier = ">=0.52.16" }, { name = "weaviate-client", specifier = "==4.20.4" }, - { name = "webvtt-py", specifier = "~=0.5.1" }, { name = "yarl", specifier = "~=1.23.0" }, ] @@ -1793,7 +1788,6 @@ dev = [ { name = "types-greenlet", specifier = "~=3.3.0" }, { name = "types-html5lib", specifier = "~=1.1.11" }, { name = "types-jmespath", specifier = ">=1.0.2.20240106" }, - { name = "types-jsonschema", specifier = "~=4.26.0" }, { name = "types-markdown", specifier = "~=3.10.2" }, { name = "types-oauthlib", specifier = "~=3.3.0" }, { name = "types-objgraph", specifier = "~=3.6.0" }, @@ -2652,6 +2646,34 @@ requests = [ { name = "requests-toolbelt" }, ] +[[package]] +name = "graphon" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "charset-normalizer" }, + { name = "httpx" }, + { name = "json-repair" }, + { name = "jsonschema" }, + { name = "orjson" }, + { name = "pandas", extra = ["excel"] }, + { name = "pydantic" }, + { name = "pydantic-extra-types" }, + { name = "pypandoc" }, + { name = "pypdfium2" }, + { name = "python-docx" }, + { name = "pyyaml" }, + { name = "tiktoken" }, + { name = "transformers" }, + { name = "typing-extensions" }, + { name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] }, + { name = "webvtt-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/fc/0a5342a1c29bc367c2254c170ef130a84a60d8cd1c9cc84a7a85e96c1042/graphon-0.1.2.tar.gz", hash = "sha256:a2210629f93258ad2e7cbe85b5d4c6826814f6c679aa2a23ca100511363b9240", size = 214744, upload-time = "2026-03-27T20:09:53.802Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/46/65b5e366ec2d7017b6d6448e2635b3772d86840a6f7297277471b1bfbfbd/graphon-0.1.2-py3-none-any.whl", hash = "sha256:79f0c7796de7b8642d070730bb8bdaf1c68ccdfcecac38e0b2282e0543f0a6db", size = 314398, upload-time = "2026-03-27T20:09:52.524Z" }, +] + [[package]] name = "graphql-core" version = "3.2.7" @@ -6850,18 +6872,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/61/91/915c4a6e6e9bd2bca3ec0c21c1771b175c59e204b85e57f3f572370fe753/types_jmespath-1.1.0.20260124-py3-none-any.whl", hash = "sha256:ec387666d446b15624215aa9cbd2867ffd885b6c74246d357c65e830c7a138b3", size = 11509, upload-time = "2026-01-24T03:18:45.536Z" }, ] -[[package]] -name = "types-jsonschema" -version = "4.26.0.20260202" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "referencing" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a1/07/68f63e715eb327ed2f5292e29e8be99785db0f72c7664d2c63bd4dbdc29d/types_jsonschema-4.26.0.20260202.tar.gz", hash = "sha256:29831baa4308865a9aec547a61797a06fc152b0dac8dddd531e002f32265cb07", size = 16168, upload-time = "2026-02-02T04:11:22.585Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/06/962d4f364f779d7389cd31a1bb581907b057f52f0ace2c119a8dd8409db6/types_jsonschema-4.26.0.20260202-py3-none-any.whl", hash = "sha256:41c95343abc4de9264e333a55e95dfb4d401e463856d0164eec9cb182e8746da", size = 15914, upload-time = "2026-02-02T04:11:21.61Z" }, -] - [[package]] name = "types-markdown" version = "3.10.2.20260211"