mirror of
https://github.com/langgenius/dify.git
synced 2026-04-25 13:16:16 +08:00
Merge remote-tracking branch 'origin/main' into feat/trigger
This commit is contained in:
@ -25,7 +25,7 @@ class FirecrawlApp:
|
||||
}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = self._post_request(f"{self.base_url}/v1/scrape", json_data, headers)
|
||||
response = self._post_request(f"{self.base_url}/v2/scrape", json_data, headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
data = response_data["data"]
|
||||
@ -42,7 +42,7 @@ class FirecrawlApp:
|
||||
json_data = {"url": url}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers)
|
||||
response = self._post_request(f"{self.base_url}/v2/crawl", json_data, headers)
|
||||
if response.status_code == 200:
|
||||
# There's also another two fields in the response: "success" (bool) and "url" (str)
|
||||
job_id = response.json().get("id")
|
||||
@ -51,9 +51,25 @@ class FirecrawlApp:
|
||||
self._handle_error(response, "start crawl job")
|
||||
return "" # unreachable
|
||||
|
||||
def map(self, url: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/map
|
||||
headers = self._prepare_headers()
|
||||
json_data: dict[str, Any] = {"url": url, "integration": "dify"}
|
||||
if params:
|
||||
# Pass through provided params, including optional "sitemap": "only" | "include" | "skip"
|
||||
json_data.update(params)
|
||||
response = self._post_request(f"{self.base_url}/v2/map", json_data, headers)
|
||||
if response.status_code == 200:
|
||||
return cast(dict[str, Any], response.json())
|
||||
elif response.status_code in {402, 409, 500, 429, 408}:
|
||||
self._handle_error(response, "start map job")
|
||||
return {}
|
||||
else:
|
||||
raise Exception(f"Failed to start map job. Status code: {response.status_code}")
|
||||
|
||||
def check_crawl_status(self, job_id) -> dict[str, Any]:
|
||||
headers = self._prepare_headers()
|
||||
response = self._get_request(f"{self.base_url}/v1/crawl/{job_id}", headers)
|
||||
response = self._get_request(f"{self.base_url}/v2/crawl/{job_id}", headers)
|
||||
if response.status_code == 200:
|
||||
crawl_status_response = response.json()
|
||||
if crawl_status_response.get("status") == "completed":
|
||||
@ -135,12 +151,16 @@ class FirecrawlApp:
|
||||
"lang": "en",
|
||||
"country": "us",
|
||||
"timeout": 60000,
|
||||
"ignoreInvalidURLs": False,
|
||||
"ignoreInvalidURLs": True,
|
||||
"scrapeOptions": {},
|
||||
"sources": [
|
||||
{"type": "web"},
|
||||
],
|
||||
"integration": "dify",
|
||||
}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = self._post_request(f"{self.base_url}/v1/search", json_data, headers)
|
||||
response = self._post_request(f"{self.base_url}/v2/search", json_data, headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
if not response_data.get("success"):
|
||||
|
||||
@ -41,6 +41,7 @@ class RedisChannel:
|
||||
self._redis = redis_client
|
||||
self._key = channel_key
|
||||
self._command_ttl = command_ttl
|
||||
self._pending_key = f"{channel_key}:pending"
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
"""
|
||||
@ -49,6 +50,9 @@ class RedisChannel:
|
||||
Returns:
|
||||
List of pending commands (drains the Redis list)
|
||||
"""
|
||||
if not self._has_pending_commands():
|
||||
return []
|
||||
|
||||
commands: list[GraphEngineCommand] = []
|
||||
|
||||
# Use pipeline for atomic operations
|
||||
@ -85,6 +89,7 @@ class RedisChannel:
|
||||
with self._redis.pipeline() as pipe:
|
||||
pipe.rpush(self._key, command_json)
|
||||
pipe.expire(self._key, self._command_ttl)
|
||||
pipe.set(self._pending_key, "1", ex=self._command_ttl)
|
||||
pipe.execute()
|
||||
|
||||
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
|
||||
@ -112,3 +117,17 @@ class RedisChannel:
|
||||
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
def _has_pending_commands(self) -> bool:
|
||||
"""
|
||||
Check and consume the pending marker to avoid unnecessary list reads.
|
||||
|
||||
Returns:
|
||||
True if commands should be fetched from Redis.
|
||||
"""
|
||||
with self._redis.pipeline() as pipe:
|
||||
pipe.get(self._pending_key)
|
||||
pipe.delete(self._pending_key)
|
||||
pending_value, _ = pipe.execute()
|
||||
|
||||
return pending_value is not None
|
||||
|
||||
@ -7,6 +7,7 @@ from collections.abc import Mapping
|
||||
from functools import singledispatchmethod
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
@ -125,6 +126,7 @@ class EventHandler:
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
is_initial_attempt = node_execution.retry_count == 0
|
||||
node_execution.mark_started(event.id)
|
||||
self._graph_runtime_state.increment_node_run_steps()
|
||||
|
||||
# Track in response coordinator for stream ordering
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
@ -163,6 +165,8 @@ class EventHandler:
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
# Store outputs in variable pool
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
@ -212,6 +216,8 @@ class EventHandler:
|
||||
node_execution.mark_failed(event.error)
|
||||
self._graph_execution.record_node_failure()
|
||||
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
result = self._error_handler.handle_node_failure(event)
|
||||
|
||||
if result:
|
||||
@ -235,6 +241,8 @@ class EventHandler:
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
# Persist outputs produced by the exception strategy (e.g. default values)
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
@ -286,6 +294,19 @@ class EventHandler:
|
||||
self._state_manager.enqueue_node(event.node_id)
|
||||
self._state_manager.start_execution(event.node_id)
|
||||
|
||||
def _accumulate_node_usage(self, usage: LLMUsage) -> None:
|
||||
"""Accumulate token usage into the shared runtime state."""
|
||||
if usage.total_tokens <= 0:
|
||||
return
|
||||
|
||||
self._graph_runtime_state.add_tokens(usage.total_tokens)
|
||||
|
||||
current_usage = self._graph_runtime_state.llm_usage
|
||||
if current_usage.total_tokens == 0:
|
||||
self._graph_runtime_state.llm_usage = usage
|
||||
else:
|
||||
self._graph_runtime_state.llm_usage = current_usage.plus(usage)
|
||||
|
||||
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
|
||||
"""
|
||||
Store node outputs in the variable pool.
|
||||
|
||||
@ -8,7 +8,12 @@ import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.graph_events.base import GraphNodeEventBase
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from ..event_management import EventManager
|
||||
from .execution_coordinator import ExecutionCoordinator
|
||||
@ -72,13 +77,16 @@ class Dispatcher:
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=10.0)
|
||||
|
||||
_COMMAND_TRIGGER_EVENTS = (
|
||||
NodeRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunExceptionEvent,
|
||||
)
|
||||
|
||||
def _dispatcher_loop(self) -> None:
|
||||
"""Main dispatcher loop."""
|
||||
try:
|
||||
while not self._stop_event.is_set():
|
||||
# Check for commands
|
||||
self._execution_coordinator.check_commands()
|
||||
|
||||
# Check for scaling
|
||||
self._execution_coordinator.check_scaling()
|
||||
|
||||
@ -87,6 +95,8 @@ class Dispatcher:
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
# Route to the event handler
|
||||
self._event_handler.dispatch(event)
|
||||
if self._should_check_commands(event):
|
||||
self._execution_coordinator.check_commands()
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
# Check if execution is complete
|
||||
@ -102,3 +112,7 @@ class Dispatcher:
|
||||
# Signal the event emitter that execution is complete
|
||||
if self._event_emitter:
|
||||
self._event_emitter.mark_complete()
|
||||
|
||||
def _should_check_commands(self, event: GraphNodeEventBase) -> bool:
|
||||
"""Return True if the event represents a node completion."""
|
||||
return isinstance(event, self._COMMAND_TRIGGER_EVENTS)
|
||||
|
||||
@ -23,6 +23,7 @@ class CrawlOptions:
|
||||
only_main_content: bool = False
|
||||
includes: str | None = None
|
||||
excludes: str | None = None
|
||||
prompt: str | None = None
|
||||
max_depth: int | None = None
|
||||
use_sitemap: bool = True
|
||||
|
||||
@ -70,6 +71,7 @@ class WebsiteCrawlApiRequest:
|
||||
only_main_content=self.options.get("only_main_content", False),
|
||||
includes=self.options.get("includes"),
|
||||
excludes=self.options.get("excludes"),
|
||||
prompt=self.options.get("prompt"),
|
||||
max_depth=self.options.get("max_depth"),
|
||||
use_sitemap=self.options.get("use_sitemap", True),
|
||||
)
|
||||
@ -174,6 +176,7 @@ class WebsiteService:
|
||||
def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
|
||||
params: dict[str, Any]
|
||||
if not request.options.crawl_sub_pages:
|
||||
params = {
|
||||
"includePaths": [],
|
||||
@ -188,8 +191,10 @@ class WebsiteService:
|
||||
"limit": request.options.limit,
|
||||
"scrapeOptions": {"onlyMainContent": request.options.only_main_content},
|
||||
}
|
||||
if request.options.max_depth:
|
||||
params["maxDepth"] = request.options.max_depth
|
||||
|
||||
# Add optional prompt for Firecrawl v2 crawl-params compatibility
|
||||
if request.options.prompt:
|
||||
params["prompt"] = request.options.prompt
|
||||
|
||||
job_id = firecrawl_app.crawl_url(request.url, params)
|
||||
website_crawl_time_cache_key = f"website_crawl_{job_id}"
|
||||
|
||||
@ -35,11 +35,15 @@ class TestRedisChannel:
|
||||
"""Test sending a command to Redis."""
|
||||
mock_redis = MagicMock()
|
||||
mock_pipe = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
|
||||
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
|
||||
context = MagicMock()
|
||||
context.__enter__.return_value = mock_pipe
|
||||
context.__exit__.return_value = None
|
||||
mock_redis.pipeline.return_value = context
|
||||
|
||||
channel = RedisChannel(mock_redis, "test:key", 3600)
|
||||
|
||||
pending_key = "test:key:pending"
|
||||
|
||||
# Create a test command
|
||||
command = GraphEngineCommand(command_type=CommandType.ABORT)
|
||||
|
||||
@ -55,6 +59,7 @@ class TestRedisChannel:
|
||||
|
||||
# Verify expire was set
|
||||
mock_pipe.expire.assert_called_once_with("test:key", 3600)
|
||||
mock_pipe.set.assert_called_once_with(pending_key, "1", ex=3600)
|
||||
|
||||
# Verify execute was called
|
||||
mock_pipe.execute.assert_called_once()
|
||||
@ -62,33 +67,48 @@ class TestRedisChannel:
|
||||
def test_fetch_commands_empty(self):
|
||||
"""Test fetching commands when Redis list is empty."""
|
||||
mock_redis = MagicMock()
|
||||
mock_pipe = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
|
||||
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
|
||||
pending_pipe = MagicMock()
|
||||
fetch_pipe = MagicMock()
|
||||
pending_context = MagicMock()
|
||||
fetch_context = MagicMock()
|
||||
pending_context.__enter__.return_value = pending_pipe
|
||||
pending_context.__exit__.return_value = None
|
||||
fetch_context.__enter__.return_value = fetch_pipe
|
||||
fetch_context.__exit__.return_value = None
|
||||
mock_redis.pipeline.side_effect = [pending_context]
|
||||
|
||||
# Simulate empty list
|
||||
mock_pipe.execute.return_value = [[], 1] # Empty list, delete successful
|
||||
# No pending marker
|
||||
pending_pipe.execute.return_value = [None, 0]
|
||||
mock_redis.llen.return_value = 0
|
||||
|
||||
channel = RedisChannel(mock_redis, "test:key")
|
||||
commands = channel.fetch_commands()
|
||||
|
||||
assert commands == []
|
||||
mock_pipe.lrange.assert_called_once_with("test:key", 0, -1)
|
||||
mock_pipe.delete.assert_called_once_with("test:key")
|
||||
mock_redis.pipeline.assert_called_once()
|
||||
fetch_pipe.lrange.assert_not_called()
|
||||
fetch_pipe.delete.assert_not_called()
|
||||
|
||||
def test_fetch_commands_with_abort_command(self):
|
||||
"""Test fetching abort commands from Redis."""
|
||||
mock_redis = MagicMock()
|
||||
mock_pipe = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
|
||||
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
|
||||
pending_pipe = MagicMock()
|
||||
fetch_pipe = MagicMock()
|
||||
pending_context = MagicMock()
|
||||
fetch_context = MagicMock()
|
||||
pending_context.__enter__.return_value = pending_pipe
|
||||
pending_context.__exit__.return_value = None
|
||||
fetch_context.__enter__.return_value = fetch_pipe
|
||||
fetch_context.__exit__.return_value = None
|
||||
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
|
||||
|
||||
# Create abort command data
|
||||
abort_command = AbortCommand()
|
||||
command_json = json.dumps(abort_command.model_dump())
|
||||
|
||||
# Simulate Redis returning one command
|
||||
mock_pipe.execute.return_value = [[command_json.encode()], 1]
|
||||
pending_pipe.execute.return_value = [b"1", 1]
|
||||
fetch_pipe.execute.return_value = [[command_json.encode()], 1]
|
||||
|
||||
channel = RedisChannel(mock_redis, "test:key")
|
||||
commands = channel.fetch_commands()
|
||||
@ -100,9 +120,15 @@ class TestRedisChannel:
|
||||
def test_fetch_commands_multiple(self):
|
||||
"""Test fetching multiple commands from Redis."""
|
||||
mock_redis = MagicMock()
|
||||
mock_pipe = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
|
||||
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
|
||||
pending_pipe = MagicMock()
|
||||
fetch_pipe = MagicMock()
|
||||
pending_context = MagicMock()
|
||||
fetch_context = MagicMock()
|
||||
pending_context.__enter__.return_value = pending_pipe
|
||||
pending_context.__exit__.return_value = None
|
||||
fetch_context.__enter__.return_value = fetch_pipe
|
||||
fetch_context.__exit__.return_value = None
|
||||
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
|
||||
|
||||
# Create multiple commands
|
||||
command1 = GraphEngineCommand(command_type=CommandType.ABORT)
|
||||
@ -112,7 +138,8 @@ class TestRedisChannel:
|
||||
command2_json = json.dumps(command2.model_dump())
|
||||
|
||||
# Simulate Redis returning multiple commands
|
||||
mock_pipe.execute.return_value = [[command1_json.encode(), command2_json.encode()], 1]
|
||||
pending_pipe.execute.return_value = [b"1", 1]
|
||||
fetch_pipe.execute.return_value = [[command1_json.encode(), command2_json.encode()], 1]
|
||||
|
||||
channel = RedisChannel(mock_redis, "test:key")
|
||||
commands = channel.fetch_commands()
|
||||
@ -124,9 +151,15 @@ class TestRedisChannel:
|
||||
def test_fetch_commands_skips_invalid_json(self):
|
||||
"""Test that invalid JSON commands are skipped."""
|
||||
mock_redis = MagicMock()
|
||||
mock_pipe = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
|
||||
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
|
||||
pending_pipe = MagicMock()
|
||||
fetch_pipe = MagicMock()
|
||||
pending_context = MagicMock()
|
||||
fetch_context = MagicMock()
|
||||
pending_context.__enter__.return_value = pending_pipe
|
||||
pending_context.__exit__.return_value = None
|
||||
fetch_context.__enter__.return_value = fetch_pipe
|
||||
fetch_context.__exit__.return_value = None
|
||||
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
|
||||
|
||||
# Mix valid and invalid JSON
|
||||
valid_command = AbortCommand()
|
||||
@ -134,7 +167,8 @@ class TestRedisChannel:
|
||||
invalid_json = b"invalid json {"
|
||||
|
||||
# Simulate Redis returning mixed valid/invalid commands
|
||||
mock_pipe.execute.return_value = [[invalid_json, valid_json.encode()], 1]
|
||||
pending_pipe.execute.return_value = [b"1", 1]
|
||||
fetch_pipe.execute.return_value = [[invalid_json, valid_json.encode()], 1]
|
||||
|
||||
channel = RedisChannel(mock_redis, "test:key")
|
||||
commands = channel.fetch_commands()
|
||||
@ -187,13 +221,20 @@ class TestRedisChannel:
|
||||
def test_atomic_fetch_and_clear(self):
|
||||
"""Test that fetch_commands atomically fetches and clears the list."""
|
||||
mock_redis = MagicMock()
|
||||
mock_pipe = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
|
||||
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
|
||||
pending_pipe = MagicMock()
|
||||
fetch_pipe = MagicMock()
|
||||
pending_context = MagicMock()
|
||||
fetch_context = MagicMock()
|
||||
pending_context.__enter__.return_value = pending_pipe
|
||||
pending_context.__exit__.return_value = None
|
||||
fetch_context.__enter__.return_value = fetch_pipe
|
||||
fetch_context.__exit__.return_value = None
|
||||
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
|
||||
|
||||
command = AbortCommand()
|
||||
command_json = json.dumps(command.model_dump())
|
||||
mock_pipe.execute.return_value = [[command_json.encode()], 1]
|
||||
pending_pipe.execute.return_value = [b"1", 1]
|
||||
fetch_pipe.execute.return_value = [[command_json.encode()], 1]
|
||||
|
||||
channel = RedisChannel(mock_redis, "test:key")
|
||||
|
||||
@ -202,7 +243,29 @@ class TestRedisChannel:
|
||||
assert len(commands) == 1
|
||||
|
||||
# Verify both lrange and delete were called in the pipeline
|
||||
assert mock_pipe.lrange.call_count == 1
|
||||
assert mock_pipe.delete.call_count == 1
|
||||
mock_pipe.lrange.assert_called_with("test:key", 0, -1)
|
||||
mock_pipe.delete.assert_called_with("test:key")
|
||||
assert fetch_pipe.lrange.call_count == 1
|
||||
assert fetch_pipe.delete.call_count == 1
|
||||
fetch_pipe.lrange.assert_called_with("test:key", 0, -1)
|
||||
fetch_pipe.delete.assert_called_with("test:key")
|
||||
|
||||
def test_fetch_commands_without_pending_marker_returns_empty(self):
|
||||
"""Ensure we avoid unnecessary list reads when pending flag is missing."""
|
||||
mock_redis = MagicMock()
|
||||
pending_pipe = MagicMock()
|
||||
fetch_pipe = MagicMock()
|
||||
pending_context = MagicMock()
|
||||
fetch_context = MagicMock()
|
||||
pending_context.__enter__.return_value = pending_pipe
|
||||
pending_context.__exit__.return_value = None
|
||||
fetch_context.__enter__.return_value = fetch_pipe
|
||||
fetch_context.__exit__.return_value = None
|
||||
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
|
||||
|
||||
# Pending flag absent
|
||||
pending_pipe.execute.return_value = [None, 0]
|
||||
channel = RedisChannel(mock_redis, "test:key")
|
||||
commands = channel.fetch_commands()
|
||||
|
||||
assert commands == []
|
||||
mock_redis.llen.assert_not_called()
|
||||
assert mock_redis.pipeline.call_count == 1
|
||||
|
||||
@ -0,0 +1,104 @@
|
||||
"""Tests for dispatcher command checking behavior."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import queue
|
||||
from datetime import datetime
|
||||
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.event_management.event_manager import EventManager
|
||||
from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher
|
||||
from core.workflow.graph_events import NodeRunStartedEvent, NodeRunSucceededEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
|
||||
class _StubExecutionCoordinator:
|
||||
"""Stub execution coordinator that tracks command checks."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.command_checks = 0
|
||||
self.scaling_checks = 0
|
||||
self._execution_complete = False
|
||||
self.mark_complete_called = False
|
||||
self.failed = False
|
||||
|
||||
def check_commands(self) -> None:
|
||||
self.command_checks += 1
|
||||
|
||||
def check_scaling(self) -> None:
|
||||
self.scaling_checks += 1
|
||||
|
||||
def is_execution_complete(self) -> bool:
|
||||
return self._execution_complete
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
self.mark_complete_called = True
|
||||
|
||||
def mark_failed(self, error: Exception) -> None: # pragma: no cover - defensive, not triggered in tests
|
||||
self.failed = True
|
||||
|
||||
def set_execution_complete(self) -> None:
|
||||
self._execution_complete = True
|
||||
|
||||
|
||||
class _StubEventHandler:
|
||||
"""Minimal event handler that marks execution complete after handling an event."""
|
||||
|
||||
def __init__(self, coordinator: _StubExecutionCoordinator) -> None:
|
||||
self._coordinator = coordinator
|
||||
self.events = []
|
||||
|
||||
def dispatch(self, event) -> None:
|
||||
self.events.append(event)
|
||||
self._coordinator.set_execution_complete()
|
||||
|
||||
|
||||
def _run_dispatcher_for_event(event) -> int:
|
||||
"""Run the dispatcher loop for a single event and return command check count."""
|
||||
event_queue: queue.Queue = queue.Queue()
|
||||
event_queue.put(event)
|
||||
|
||||
coordinator = _StubExecutionCoordinator()
|
||||
event_handler = _StubEventHandler(coordinator)
|
||||
event_manager = EventManager()
|
||||
|
||||
dispatcher = Dispatcher(
|
||||
event_queue=event_queue,
|
||||
event_handler=event_handler,
|
||||
event_collector=event_manager,
|
||||
execution_coordinator=coordinator,
|
||||
)
|
||||
|
||||
dispatcher._dispatcher_loop()
|
||||
|
||||
return coordinator.command_checks
|
||||
|
||||
|
||||
def _make_started_event() -> NodeRunStartedEvent:
|
||||
return NodeRunStartedEvent(
|
||||
id="start-event",
|
||||
node_id="node-1",
|
||||
node_type=NodeType.CODE,
|
||||
node_title="Test Node",
|
||||
start_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
|
||||
def _make_succeeded_event() -> NodeRunSucceededEvent:
|
||||
return NodeRunSucceededEvent(
|
||||
id="success-event",
|
||||
node_id="node-1",
|
||||
node_type=NodeType.CODE,
|
||||
node_title="Test Node",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
|
||||
)
|
||||
|
||||
|
||||
def test_dispatcher_checks_commands_after_node_completion() -> None:
|
||||
"""Dispatcher should only check commands after node completion events."""
|
||||
started_checks = _run_dispatcher_for_event(_make_started_event())
|
||||
succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event())
|
||||
|
||||
assert started_checks == 0
|
||||
assert succeeded_checks == 1
|
||||
@ -132,15 +132,22 @@ class TestRedisStopIntegration:
|
||||
"""Test RedisChannel correctly fetches and deserializes commands."""
|
||||
# Setup
|
||||
mock_redis = MagicMock()
|
||||
mock_pipeline = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
|
||||
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
|
||||
pending_pipe = MagicMock()
|
||||
fetch_pipe = MagicMock()
|
||||
pending_context = MagicMock()
|
||||
fetch_context = MagicMock()
|
||||
pending_context.__enter__.return_value = pending_pipe
|
||||
pending_context.__exit__.return_value = None
|
||||
fetch_context.__enter__.return_value = fetch_pipe
|
||||
fetch_context.__exit__.return_value = None
|
||||
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
|
||||
|
||||
# Mock command data
|
||||
abort_command_json = json.dumps({"command_type": CommandType.ABORT, "reason": "Test abort", "payload": None})
|
||||
|
||||
# Mock pipeline execute to return commands
|
||||
mock_pipeline.execute.return_value = [
|
||||
pending_pipe.execute.return_value = [b"1", 1]
|
||||
fetch_pipe.execute.return_value = [
|
||||
[abort_command_json.encode()], # lrange result
|
||||
True, # delete result
|
||||
]
|
||||
@ -158,19 +165,29 @@ class TestRedisStopIntegration:
|
||||
assert commands[0].reason == "Test abort"
|
||||
|
||||
# Verify Redis operations
|
||||
mock_pipeline.lrange.assert_called_once_with(channel_key, 0, -1)
|
||||
mock_pipeline.delete.assert_called_once_with(channel_key)
|
||||
pending_pipe.get.assert_called_once_with(f"{channel_key}:pending")
|
||||
pending_pipe.delete.assert_called_once_with(f"{channel_key}:pending")
|
||||
fetch_pipe.lrange.assert_called_once_with(channel_key, 0, -1)
|
||||
fetch_pipe.delete.assert_called_once_with(channel_key)
|
||||
assert mock_redis.pipeline.call_count == 2
|
||||
|
||||
def test_redis_channel_fetch_commands_handles_invalid_json(self):
|
||||
"""Test RedisChannel gracefully handles invalid JSON in commands."""
|
||||
# Setup
|
||||
mock_redis = MagicMock()
|
||||
mock_pipeline = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
|
||||
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
|
||||
pending_pipe = MagicMock()
|
||||
fetch_pipe = MagicMock()
|
||||
pending_context = MagicMock()
|
||||
fetch_context = MagicMock()
|
||||
pending_context.__enter__.return_value = pending_pipe
|
||||
pending_context.__exit__.return_value = None
|
||||
fetch_context.__enter__.return_value = fetch_pipe
|
||||
fetch_context.__exit__.return_value = None
|
||||
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
|
||||
|
||||
# Mock invalid command data
|
||||
mock_pipeline.execute.return_value = [
|
||||
pending_pipe.execute.return_value = [b"1", 1]
|
||||
fetch_pipe.execute.return_value = [
|
||||
[b"invalid json", b'{"command_type": "invalid_type"}'], # lrange result
|
||||
True, # delete result
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user