mirror of https://github.com/langgenius/dify.git (synced 2026-05-06 10:28:10 +08:00)

feat: add llm first token timeout config
@@ -109,6 +109,7 @@ class ModelInstance:
         stream: Literal[True] = True,
         user: str | None = None,
         callbacks: list[Callback] | None = None,
+        first_token_timeout: float | None = None,
     ) -> Generator: ...

     @overload
@@ -121,6 +122,7 @@ class ModelInstance:
         stream: Literal[False] = False,
         user: str | None = None,
         callbacks: list[Callback] | None = None,
+        first_token_timeout: float | None = None,
     ) -> LLMResult: ...

     @overload
@@ -133,6 +135,7 @@ class ModelInstance:
         stream: bool = True,
         user: str | None = None,
         callbacks: list[Callback] | None = None,
+        first_token_timeout: float | None = None,
     ) -> Union[LLMResult, Generator]: ...

     def invoke_llm(
@@ -144,6 +147,7 @@ class ModelInstance:
         stream: bool = True,
         user: str | None = None,
         callbacks: list[Callback] | None = None,
+        first_token_timeout: float | None = None,
     ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
@@ -155,26 +159,33 @@ class ModelInstance:
         :param stream: is stream response
         :param user: unique user id
         :param callbacks: callbacks
+        :param first_token_timeout: timeout in seconds for receiving first token (streaming only)
         :return: full response or stream response chunk generator result
         """
         if not isinstance(self.model_type_instance, LargeLanguageModel):
             raise Exception("Model type instance is not LargeLanguageModel")
-        return cast(
-            Union[LLMResult, Generator],
-            self._round_robin_invoke(
-                function=self.model_type_instance.invoke,
-                model=self.model,
-                credentials=self.credentials,
-                prompt_messages=prompt_messages,
-                model_parameters=model_parameters,
-                tools=tools,
-                stop=stop,
-                stream=stream,
-                user=user,
-                callbacks=callbacks,
-            ),
-        )
+        result = self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
+            model=self.model,
+            credentials=self.credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=model_parameters,
+            tools=tools,
+            stop=stop,
+            stream=stream,
+            user=user,
+            callbacks=callbacks,
+        )
+
+        # Apply first token timeout wrapper for streaming responses
+        if stream and first_token_timeout and first_token_timeout > 0 and isinstance(result, Generator):
+            from core.workflow.utils.generator_timeout import with_first_token_timeout
+
+            result = with_first_token_timeout(result, first_token_timeout)
+
+        return cast(Union[LLMResult, Generator], result)

     def get_llm_num_tokens(
         self, prompt_messages: Sequence[PromptMessage], tools: Sequence[PromptMessageTool] | None = None
     ) -> int:
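Note: the change is backward compatible; callers that omit first_token_timeout behave exactly as before, and the wrapper only applies to streaming results. A minimal caller-side sketch, assuming an already-configured ModelInstance and a prompt_messages list (both illustrative, not part of this diff):

from core.workflow.utils.generator_timeout import FirstTokenTimeoutError

try:
    chunks = model_instance.invoke_llm(
        prompt_messages=prompt_messages,
        model_parameters={"temperature": 0.7},
        stream=True,
        user="user-123",
        first_token_timeout=5.0,  # seconds; only enforced for streaming responses
    )
    for chunk in chunks:
        print(chunk.delta.message.content, end="")
except FirstTokenTimeoutError:
    # No token arrived within 5 seconds; the caller decides whether to retry or fail.
    ...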
@@ -23,10 +23,22 @@ class RetryConfig(BaseModel):
     retry_interval: int = 0  # retry interval in milliseconds
     retry_enabled: bool = False  # whether retry is enabled

+    # First token timeout for LLM nodes (milliseconds), 0 means no timeout
+    first_token_timeout: int = 0
+
+    @property
+    def first_token_timeout_seconds(self) -> float:
+        return self.first_token_timeout / 1000
+
     @property
     def retry_interval_seconds(self) -> float:
         return self.retry_interval / 1000

+    @property
+    def has_first_token_timeout(self) -> bool:
+        """Check if first token timeout should be applied (retry enabled and timeout > 0)."""
+        return self.retry_enabled and self.first_token_timeout > 0
+

 class VariableSelector(BaseModel):
     """
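Note: a quick sketch of how the new field and properties interact (values arbitrary). The timeout is stored in milliseconds, exposed in seconds, and gated by has_first_token_timeout, which requires both retry_enabled and a positive value:

from core.workflow.nodes.base.entities import RetryConfig

config = RetryConfig(retry_enabled=True, first_token_timeout=3000)  # milliseconds
assert config.first_token_timeout_seconds == 3.0
assert config.has_first_token_timeout is True

# With retry disabled, or a zero timeout, the feature stays off.
assert RetryConfig(retry_enabled=False, first_token_timeout=3000).has_first_token_timeout is False
assert RetryConfig(retry_enabled=True, first_token_timeout=0).has_first_token_timeout is False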
@@ -43,3 +43,11 @@ class FileTypeNotSupportError(LLMNodeError):
 class UnsupportedPromptContentTypeError(LLMNodeError):
     def __init__(self, *, type_name: str):
         super().__init__(f"Prompt content type {type_name} is not supported.")
+
+
+class LLMFirstTokenTimeoutError(LLMNodeError):
+    """Raised when LLM request fails to receive first token within configured timeout."""
+
+    def __init__(self, timeout_ms: int):
+        self.timeout_ms = timeout_ms
+        super().__init__(f"LLM request timed out after {timeout_ms}ms without receiving first token")
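Note: the hunks shown here do not include the call site that converts the generator-level FirstTokenTimeoutError into this node-level error. A purely hypothetical sketch of such a mapping (the helper name is invented for illustration):

from core.workflow.nodes.llm.exc import LLMFirstTokenTimeoutError
from core.workflow.utils.generator_timeout import FirstTokenTimeoutError


def _reraise_as_node_error(generator, timeout_ms: int):
    # Hypothetical helper (not in this diff): surface an LLMNodeError subclass
    # so the node's retry machinery can catch it instead of the generic error.
    try:
        yield from generator
    except FirstTokenTimeoutError:
        raise LLMFirstTokenTimeoutError(timeout_ms=timeout_ms)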
@@ -237,6 +237,13 @@ class LLMNode(Node[LLMNodeData]):
         )

         # handle invoke result
+        # Get first token timeout from retry config if enabled (convert ms to seconds)
+        first_token_timeout = (
+            self.node_data.retry_config.first_token_timeout_seconds
+            if self.node_data.retry_config.has_first_token_timeout
+            else None
+        )
+
         generator = LLMNode.invoke_llm(
             node_data_model=self.node_data.model,
             model_instance=model_instance,
@@ -250,6 +257,7 @@ class LLMNode(Node[LLMNodeData]):
             node_id=self._node_id,
             node_type=self.node_type,
             reasoning_format=self.node_data.reasoning_format,
+            first_token_timeout=first_token_timeout,
         )

         structured_output: LLMStructuredOutput | None = None
@@ -367,6 +375,7 @@ class LLMNode(Node[LLMNodeData]):
         node_id: str,
         node_type: NodeType,
         reasoning_format: Literal["separated", "tagged"] = "tagged",
+        first_token_timeout: float | None = None,
     ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
         model_schema = model_instance.model_type_instance.get_model_schema(
             node_data_model.name, model_instance.credentials
@@ -400,6 +409,7 @@ class LLMNode(Node[LLMNodeData]):
             stop=list(stop or []),
             stream=True,
             user=user_id,
+            first_token_timeout=first_token_timeout,
         )

         return LLMNode.handle_invoke_result(
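Note: end to end, the node carries the timeout as plain retry_config data in milliseconds, converts it to seconds, and forwards it only when the gate is open (None means no limit). A condensed sketch of that flow; the dict literal mirrors the serialized form used in the tests further below:

from core.workflow.nodes.base.entities import RetryConfig

retry_config = RetryConfig.model_validate(
    {
        "max_retries": 3,
        "retry_interval": 1000,       # ms between retries
        "retry_enabled": True,
        "first_token_timeout": 3000,  # ms to wait for the first streamed token; 0 disables
    }
)
first_token_timeout = (
    retry_config.first_token_timeout_seconds if retry_config.has_first_token_timeout else None
)  # -> 3.0 seconds, forwarded through LLMNode.invoke_llm to ModelInstance.invoke_llm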
api/core/workflow/utils/generator_timeout.py (new file, 54 lines)
@@ -0,0 +1,54 @@
"""
Generator timeout utilities for workflow nodes.

Provides timeout wrappers for streaming generators, primarily used for
LLM response streaming where we need to enforce time-to-first-token limits.
"""

import time
from collections.abc import Generator
from typing import TypeVar

T = TypeVar("T")


class FirstTokenTimeoutError(Exception):
    """Raised when a generator fails to yield its first item within the configured timeout."""

    def __init__(self, timeout_ms: int):
        self.timeout_ms = timeout_ms
        super().__init__(f"Generator timed out after {timeout_ms}ms without yielding first item")


def with_first_token_timeout(
    generator: Generator[T, None, None],
    timeout_seconds: float,
) -> Generator[T, None, None]:
    """
    Wrap a generator with first token timeout monitoring.

    Only monitors the time until the FIRST item is yielded.
    Once the first item arrives, timeout monitoring stops and
    subsequent items are yielded without timeout checks.

    Args:
        generator: The source generator to wrap
        timeout_seconds: Maximum time to wait for first item (in seconds)

    Yields:
        Items from the source generator

    Raises:
        FirstTokenTimeoutError: If first item doesn't arrive within timeout
    """
    start_time = time.monotonic()
    first_token_received = False

    for item in generator:
        if not first_token_received:
            current_time = time.monotonic()
            if current_time - start_time > timeout_seconds:
                raise FirstTokenTimeoutError(int(timeout_seconds * 1000))
            first_token_received = True

        yield item
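Note: used on its own, the wrapper is transparent until the deadline is crossed. A minimal usage sketch (the slow producer is illustrative):

import time

from core.workflow.utils.generator_timeout import FirstTokenTimeoutError, with_first_token_timeout


def slow_tokens():
    time.sleep(2)  # first item takes longer than the allowed 1 second
    yield "first"
    yield "second"


try:
    for token in with_first_token_timeout(slow_tokens(), timeout_seconds=1.0):
        print(token)
except FirstTokenTimeoutError as exc:
    print(exc)  # "Generator timed out after 1000ms without yielding first item"

One design consequence visible in the implementation: the elapsed-time check only runs when the wrapped generator actually yields, so a producer that blocks forever is never interrupted; the error surfaces once the late first item finally arrives. The boundary tests below rely on exactly this check-on-yield behavior.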
@@ -0,0 +1,416 @@
"""Tests for LLM Node first token timeout retry functionality."""

import time
from collections.abc import Generator
from unittest import mock

import pytest

from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.nodes.base.entities import RetryConfig
from core.workflow.nodes.llm.exc import LLMFirstTokenTimeoutError
from core.workflow.utils.generator_timeout import FirstTokenTimeoutError, with_first_token_timeout

class TestRetryConfigFirstTokenTimeout:
    """Test cases for RetryConfig first token timeout fields."""

    def test_default_values(self):
        """Test that first token timeout fields have correct default values."""
        config = RetryConfig()

        assert config.first_token_timeout == 0
        assert config.has_first_token_timeout is False

    def test_has_first_token_timeout_when_retry_enabled_and_positive(self):
        """Test has_first_token_timeout returns True when retry enabled with positive timeout."""
        config = RetryConfig(
            retry_enabled=True,
            first_token_timeout=3000,  # 3000ms = 3s
        )

        assert config.has_first_token_timeout is True
        assert config.first_token_timeout_seconds == 3.0

    def test_has_first_token_timeout_when_retry_disabled(self):
        """Test has_first_token_timeout returns False when retry is disabled."""
        config = RetryConfig(
            retry_enabled=False,
            first_token_timeout=60,
        )

        assert config.has_first_token_timeout is False

    def test_has_first_token_timeout_when_zero_timeout(self):
        """Test has_first_token_timeout returns False when timeout is 0."""
        config = RetryConfig(
            retry_enabled=True,
            first_token_timeout=0,
        )

        assert config.has_first_token_timeout is False

    def test_backward_compatibility(self):
        """Test that existing workflows without first_token_timeout work correctly."""
        old_config_data = {
            "max_retries": 3,
            "retry_interval": 1000,
            "retry_enabled": True,
        }

        config = RetryConfig.model_validate(old_config_data)

        assert config.max_retries == 3
        assert config.retry_interval == 1000
        assert config.retry_enabled is True
        assert config.first_token_timeout == 0
        # has_first_token_timeout is False because timeout is 0
        assert config.has_first_token_timeout is False

    def test_full_config_serialization(self):
        """Test that full config can be serialized and deserialized."""
        config = RetryConfig(
            max_retries=5,
            retry_interval=2000,
            retry_enabled=True,
            first_token_timeout=120,
        )

        config_dict = config.model_dump()
        restored_config = RetryConfig.model_validate(config_dict)

        assert restored_config.max_retries == 5
        assert restored_config.retry_interval == 2000
        assert restored_config.retry_enabled is True
        assert restored_config.first_token_timeout == 120
        assert restored_config.has_first_token_timeout is True

class TestLLMFirstTokenTimeoutError:
    """Test cases for LLMFirstTokenTimeoutError exception."""

    def test_error_message_format(self):
        """Test that error message contains timeout value in milliseconds."""
        error = LLMFirstTokenTimeoutError(timeout_ms=3000)

        assert "3000ms" in str(error)
        assert "first token" in str(error).lower()

    def test_inherits_from_llm_node_error(self):
        """Test that LLMFirstTokenTimeoutError inherits from LLMNodeError."""
        from core.workflow.nodes.llm.exc import LLMNodeError

        error = LLMFirstTokenTimeoutError(timeout_ms=3000)

        assert isinstance(error, LLMNodeError)
        assert isinstance(error, ValueError)

class TestWithFirstTokenTimeout:
    """Test cases for with_first_token_timeout function."""

    @staticmethod
    def _create_mock_chunk(text: str = "test") -> LLMResultChunk:
        """Helper to create a mock LLMResultChunk."""
        return LLMResultChunk(
            model="test-model",
            prompt_messages=[],
            delta=LLMResultChunkDelta(
                index=0,
                message=AssistantPromptMessage(content=text),
            ),
        )

    def test_first_token_arrives_within_timeout(self):
        """Test that chunks are yielded normally when first token arrives in time."""

        def mock_generator() -> Generator[LLMResultChunk, None, None]:
            yield self._create_mock_chunk("Hello")
            yield self._create_mock_chunk(" world")

        wrapped = with_first_token_timeout(mock_generator(), timeout_seconds=10)
        chunks = list(wrapped)

        assert len(chunks) == 2

    def test_first_token_timeout_raises_error(self, monkeypatch):
        """Test that timeout error is raised when first token doesn't arrive in time."""
        call_count = 0

        def mock_monotonic():
            nonlocal call_count
            call_count += 1
            # First call: start_time = 0
            # Second call (when checking): current_time = 11 (exceeds 10 second timeout)
            if call_count == 1:
                return 0.0
            return 11.0

        monkeypatch.setattr(time, "monotonic", mock_monotonic)

        def slow_generator() -> Generator[LLMResultChunk, None, None]:
            # This chunk arrives "after timeout"
            yield self._create_mock_chunk("Late token")

        wrapped = with_first_token_timeout(slow_generator(), timeout_seconds=10)

        with pytest.raises(FirstTokenTimeoutError) as exc_info:
            list(wrapped)

        # Error message shows milliseconds (10 seconds = 10000ms)
        assert "10000ms" in str(exc_info.value)

    def test_no_timeout_check_after_first_token(self, monkeypatch):
        """Test that subsequent chunks are not subject to timeout after first token received."""
        call_count = 0

        def mock_monotonic():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return 0.0  # start_time
            elif call_count == 2:
                return 5.0  # first token arrives at 5s (within 10s timeout)
            else:
                # Subsequent calls simulate long delays for remaining chunks
                # These should NOT trigger timeout because first token already received
                return 100.0 + call_count

        monkeypatch.setattr(time, "monotonic", mock_monotonic)

        def generator_with_slow_subsequent_chunks() -> Generator[LLMResultChunk, None, None]:
            yield self._create_mock_chunk("First")
            yield self._create_mock_chunk("Second")
            yield self._create_mock_chunk("Third")

        wrapped = with_first_token_timeout(
            generator_with_slow_subsequent_chunks(),
            timeout_seconds=10,
        )

        # Should not raise, even though "time" passes beyond timeout after first token
        chunks = list(wrapped)
        assert len(chunks) == 3

    def test_empty_generator_no_error(self):
        """Test that empty generator doesn't raise timeout error (no chunks to check)."""

        def empty_generator() -> Generator[LLMResultChunk, None, None]:
            return
            yield  # unreachable, but makes this a generator

        wrapped = with_first_token_timeout(empty_generator(), timeout_seconds=10)
        chunks = list(wrapped)

        assert chunks == []

    def test_exact_timeout_boundary(self, monkeypatch):
        """Test behavior at exact timeout boundary (should not raise when equal)."""
        call_count = 0

        def mock_monotonic():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return 0.0
            # Exactly at boundary: current_time - start_time = 10, timeout_seconds = 10
            # Since we check > not >=, this should NOT raise
            return 10.0

        monkeypatch.setattr(time, "monotonic", mock_monotonic)

        def generator() -> Generator[LLMResultChunk, None, None]:
            yield self._create_mock_chunk("Token at boundary")

        wrapped = with_first_token_timeout(generator(), timeout_seconds=10)

        # Should not raise because 10 is not > 10
        chunks = list(wrapped)
        assert len(chunks) == 1

    def test_just_over_timeout_boundary(self, monkeypatch):
        """Test behavior just over timeout boundary (should raise)."""
        call_count = 0

        def mock_monotonic():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return 0.0
            # Just over boundary
            return 10.001

        monkeypatch.setattr(time, "monotonic", mock_monotonic)

        def generator() -> Generator[LLMResultChunk, None, None]:
            yield self._create_mock_chunk("Late token")

        wrapped = with_first_token_timeout(generator(), timeout_seconds=10)

        with pytest.raises(FirstTokenTimeoutError):
            list(wrapped)

class TestLLMNodeInvokeLLMWithTimeout:
    """Test cases for LLMNode.invoke_llm with first_token_timeout parameter."""

    def test_invoke_llm_without_timeout(self):
        """Test invoke_llm works normally when first_token_timeout is None."""
        from core.workflow.nodes.llm.node import LLMNode

        with mock.patch.object(LLMNode, "handle_invoke_result") as mock_handle:
            mock_handle.return_value = iter([])

            # Mock model_instance.invoke_llm to return empty generator
            mock_model_instance = mock.MagicMock()
            mock_model_instance.invoke_llm.return_value = iter([])
            mock_model_instance.model_type_instance.get_model_schema.return_value = mock.MagicMock()

            mock_node_data_model = mock.MagicMock()
            mock_node_data_model.completion_params = {}

            result = LLMNode.invoke_llm(
                node_data_model=mock_node_data_model,
                model_instance=mock_model_instance,
                prompt_messages=[],
                user_id="test-user",
                structured_output_enabled=False,
                structured_output=None,
                file_saver=mock.MagicMock(),
                file_outputs=[],
                node_id="test-node",
                node_type=mock.MagicMock(),
                first_token_timeout=None,  # No timeout
            )

            list(result)  # Consume generator
            mock_handle.assert_called_once()

    def test_invoke_llm_with_timeout_passes_to_model_instance(self):
        """Test invoke_llm passes first_token_timeout to model_instance.invoke_llm."""
        from core.workflow.nodes.llm.node import LLMNode

        with mock.patch.object(LLMNode, "handle_invoke_result") as mock_handle:
            mock_handle.return_value = iter([])

            mock_model_instance = mock.MagicMock()
            mock_model_instance.invoke_llm.return_value = iter([])
            mock_model_instance.model_type_instance.get_model_schema.return_value = mock.MagicMock()

            mock_node_data_model = mock.MagicMock()
            mock_node_data_model.completion_params = {}

            result = LLMNode.invoke_llm(
                node_data_model=mock_node_data_model,
                model_instance=mock_model_instance,
                prompt_messages=[],
                user_id="test-user",
                structured_output_enabled=False,
                structured_output=None,
                file_saver=mock.MagicMock(),
                file_outputs=[],
                node_id="test-node",
                node_type=mock.MagicMock(),
                first_token_timeout=60,  # With timeout
            )

            list(result)  # Consume generator

            # Verify model_instance.invoke_llm was called with first_token_timeout
            mock_model_instance.invoke_llm.assert_called_once()
            call_kwargs = mock_model_instance.invoke_llm.call_args.kwargs
            assert call_kwargs.get("first_token_timeout") == 60

    def test_invoke_llm_with_zero_timeout_passes_zero(self):
        """Test invoke_llm passes zero timeout to model_instance."""
        from core.workflow.nodes.llm.node import LLMNode

        with mock.patch.object(LLMNode, "handle_invoke_result") as mock_handle:
            mock_handle.return_value = iter([])

            mock_model_instance = mock.MagicMock()
            mock_model_instance.invoke_llm.return_value = iter([])
            mock_model_instance.model_type_instance.get_model_schema.return_value = mock.MagicMock()

            mock_node_data_model = mock.MagicMock()
            mock_node_data_model.completion_params = {}

            result = LLMNode.invoke_llm(
                node_data_model=mock_node_data_model,
                model_instance=mock_model_instance,
                prompt_messages=[],
                user_id="test-user",
                structured_output_enabled=False,
                structured_output=None,
                file_saver=mock.MagicMock(),
                file_outputs=[],
                node_id="test-node",
                node_type=mock.MagicMock(),
                first_token_timeout=0,  # Zero timeout
            )

            list(result)  # Consume generator

            # Verify model_instance.invoke_llm was called with zero timeout
            mock_model_instance.invoke_llm.assert_called_once()
            call_kwargs = mock_model_instance.invoke_llm.call_args.kwargs
            assert call_kwargs.get("first_token_timeout") == 0

class TestRetryConfigIntegration:
    """Integration tests for RetryConfig with LLM node data."""

    def test_retry_config_in_node_data(self):
        """Test RetryConfig can be properly configured in LLMNodeData."""
        from core.model_runtime.entities.llm_entities import LLMMode
        from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig

        node_data = LLMNodeData(
            title="Test LLM",
            model=ModelConfig(
                provider="openai",
                name="gpt-4",
                mode=LLMMode.CHAT,
                completion_params={},
            ),
            prompt_template=[],
            context=ContextConfig(enabled=False),
            structured_output_enabled=False,
            retry_config=RetryConfig(
                max_retries=3,
                retry_interval=1000,
                retry_enabled=True,
                first_token_timeout=3000,  # 3000ms = 3s
            ),
        )

        assert node_data.retry_config.max_retries == 3
        assert node_data.retry_config.retry_enabled is True
        assert node_data.retry_config.first_token_timeout == 3000
        assert node_data.retry_config.first_token_timeout_seconds == 3.0
        assert node_data.retry_config.has_first_token_timeout is True

    def test_default_retry_config_in_node_data(self):
        """Test default RetryConfig in LLMNodeData."""
        from core.model_runtime.entities.llm_entities import LLMMode
        from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig

        node_data = LLMNodeData(
            title="Test LLM",
            model=ModelConfig(
                provider="openai",
                name="gpt-4",
                mode=LLMMode.CHAT,
                completion_params={},
            ),
            prompt_template=[],
            context=ContextConfig(enabled=False),
            structured_output_enabled=False,
        )

        # Should have default RetryConfig
        assert node_data.retry_config.max_retries == 0
        assert node_data.retry_config.retry_enabled is False
        assert node_data.retry_config.first_token_timeout == 0
        assert node_data.retry_config.has_first_token_timeout is False