mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
feat(api): Human Input Node (backend part) (#31646)
The backend part of the human in the loop (HITL) feature and relevant architecture / workflow engine changes. Signed-off-by: yihong0618 <zouzou0208@gmail.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: 盐粒 Yanli <yanli@dify.ai> Co-authored-by: CrabSAMA <40541269+CrabSAMA@users.noreply.github.com> Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: yihong <zouzou0208@gmail.com> Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
1
api/tests/unit_tests/libs/_human_input/__init__.py
Normal file
1
api/tests/unit_tests/libs/_human_input/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Treat this directory as a package so support modules can be imported relatively.
|
||||
249
api/tests/unit_tests/libs/_human_input/support.py
Normal file
249
api/tests/unit_tests/libs/_human_input/support.py
Normal file
@ -0,0 +1,249 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.nodes.human_input.entities import FormInput
|
||||
from core.workflow.nodes.human_input.enums import TimeoutUnit
|
||||
|
||||
|
||||
# Exceptions
|
||||
class HumanInputError(Exception):
|
||||
error_code: str = "unknown"
|
||||
|
||||
def __init__(self, message: str = "", error_code: str | None = None):
|
||||
super().__init__(message)
|
||||
self.message = message or self.__class__.__name__
|
||||
if error_code:
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
class FormNotFoundError(HumanInputError):
|
||||
error_code = "form_not_found"
|
||||
|
||||
|
||||
class FormExpiredError(HumanInputError):
|
||||
error_code = "human_input_form_expired"
|
||||
|
||||
|
||||
class FormAlreadySubmittedError(HumanInputError):
|
||||
error_code = "human_input_form_submitted"
|
||||
|
||||
|
||||
class InvalidFormDataError(HumanInputError):
|
||||
error_code = "invalid_form_data"
|
||||
|
||||
|
||||
# Models
|
||||
@dataclass
|
||||
class HumanInputForm:
|
||||
form_id: str
|
||||
workflow_run_id: str
|
||||
node_id: str
|
||||
tenant_id: str
|
||||
app_id: str | None
|
||||
form_content: str
|
||||
inputs: list[FormInput]
|
||||
user_actions: list[dict[str, Any]]
|
||||
timeout: int
|
||||
timeout_unit: TimeoutUnit
|
||||
form_token: str | None = None
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
expires_at: datetime | None = None
|
||||
submitted_at: datetime | None = None
|
||||
submitted_data: dict[str, Any] | None = None
|
||||
submitted_action: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.expires_at is None:
|
||||
self.calculate_expiration()
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
return self.expires_at is not None and datetime.utcnow() > self.expires_at
|
||||
|
||||
@property
|
||||
def is_submitted(self) -> bool:
|
||||
return self.submitted_at is not None
|
||||
|
||||
def mark_submitted(self, inputs: dict[str, Any], action: str) -> None:
|
||||
self.submitted_data = inputs
|
||||
self.submitted_action = action
|
||||
self.submitted_at = datetime.utcnow()
|
||||
|
||||
def submit(self, inputs: dict[str, Any], action: str) -> None:
|
||||
self.mark_submitted(inputs, action)
|
||||
|
||||
def calculate_expiration(self) -> None:
|
||||
start = self.created_at
|
||||
if self.timeout_unit == TimeoutUnit.HOUR:
|
||||
self.expires_at = start + timedelta(hours=self.timeout)
|
||||
elif self.timeout_unit == TimeoutUnit.DAY:
|
||||
self.expires_at = start + timedelta(days=self.timeout)
|
||||
else:
|
||||
raise ValueError(f"Unsupported timeout unit {self.timeout_unit}")
|
||||
|
||||
def to_response_dict(self, *, include_site_info: bool) -> dict[str, Any]:
|
||||
inputs_response = [
|
||||
{
|
||||
"type": form_input.type.name.lower().replace("_", "-"),
|
||||
"output_variable_name": form_input.output_variable_name,
|
||||
}
|
||||
for form_input in self.inputs
|
||||
]
|
||||
response = {
|
||||
"form_content": self.form_content,
|
||||
"inputs": inputs_response,
|
||||
"user_actions": self.user_actions,
|
||||
}
|
||||
if include_site_info:
|
||||
response["site"] = {"app_id": self.app_id, "title": "Workflow Form"}
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class FormSubmissionData:
|
||||
form_id: str
|
||||
inputs: dict[str, Any]
|
||||
action: str
|
||||
submitted_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
@classmethod
|
||||
def from_request(cls, form_id: str, request: FormSubmissionRequest) -> FormSubmissionData: # type: ignore
|
||||
return cls(form_id=form_id, inputs=request.inputs, action=request.action)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FormSubmissionRequest:
|
||||
inputs: dict[str, Any]
|
||||
action: str
|
||||
|
||||
|
||||
# Repository
|
||||
class InMemoryFormRepository:
|
||||
"""
|
||||
Simple in-memory repository used by unit tests.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._forms: dict[str, HumanInputForm] = {}
|
||||
|
||||
@property
|
||||
def forms(self) -> dict[str, HumanInputForm]:
|
||||
return self._forms
|
||||
|
||||
def save(self, form: HumanInputForm) -> None:
|
||||
self._forms[form.form_id] = form
|
||||
|
||||
def get_by_id(self, form_id: str) -> HumanInputForm | None:
|
||||
return self._forms.get(form_id)
|
||||
|
||||
def get_by_token(self, token: str) -> HumanInputForm | None:
|
||||
for form in self._forms.values():
|
||||
if form.form_token == token:
|
||||
return form
|
||||
return None
|
||||
|
||||
def delete(self, form_id: str) -> None:
|
||||
self._forms.pop(form_id, None)
|
||||
|
||||
|
||||
# Service
|
||||
class FormService:
|
||||
"""Service layer for managing human input forms in tests."""
|
||||
|
||||
def __init__(self, repository: InMemoryFormRepository):
|
||||
self.repository = repository
|
||||
|
||||
def create_form(
|
||||
self,
|
||||
*,
|
||||
form_id: str,
|
||||
workflow_run_id: str,
|
||||
node_id: str,
|
||||
tenant_id: str,
|
||||
app_id: str | None,
|
||||
form_content: str,
|
||||
inputs,
|
||||
user_actions,
|
||||
timeout: int,
|
||||
timeout_unit: TimeoutUnit,
|
||||
form_token: str | None = None,
|
||||
) -> HumanInputForm:
|
||||
form = HumanInputForm(
|
||||
form_id=form_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
node_id=node_id,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
form_content=form_content,
|
||||
inputs=list(inputs),
|
||||
user_actions=[{"id": action.id, "title": action.title} for action in user_actions],
|
||||
timeout=timeout,
|
||||
timeout_unit=timeout_unit,
|
||||
form_token=form_token,
|
||||
)
|
||||
form.calculate_expiration()
|
||||
self.repository.save(form)
|
||||
return form
|
||||
|
||||
def get_form_by_id(self, form_id: str) -> HumanInputForm:
|
||||
form = self.repository.get_by_id(form_id)
|
||||
if form is None:
|
||||
raise FormNotFoundError()
|
||||
return form
|
||||
|
||||
def get_form_by_token(self, token: str) -> HumanInputForm:
|
||||
form = self.repository.get_by_token(token)
|
||||
if form is None:
|
||||
raise FormNotFoundError()
|
||||
return form
|
||||
|
||||
def get_form_definition(self, form_id: str, *, is_token: bool) -> dict:
|
||||
form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id)
|
||||
if form.is_expired:
|
||||
raise FormExpiredError()
|
||||
if form.is_submitted:
|
||||
raise FormAlreadySubmittedError()
|
||||
|
||||
definition = {
|
||||
"form_content": form.form_content,
|
||||
"inputs": form.inputs,
|
||||
"user_actions": form.user_actions,
|
||||
}
|
||||
if is_token:
|
||||
definition["site"] = {"title": "Workflow Form"}
|
||||
return definition
|
||||
|
||||
def submit_form(self, form_id: str, submission_data: FormSubmissionData, *, is_token: bool) -> None:
|
||||
form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id)
|
||||
if form.is_expired:
|
||||
raise FormExpiredError()
|
||||
if form.is_submitted:
|
||||
raise FormAlreadySubmittedError()
|
||||
|
||||
self._validate_submission(form=form, submission_data=submission_data)
|
||||
form.mark_submitted(inputs=submission_data.inputs, action=submission_data.action)
|
||||
self.repository.save(form)
|
||||
|
||||
def cleanup_expired_forms(self) -> int:
|
||||
expired_ids = [form_id for form_id, form in list(self.repository.forms.items()) if form.is_expired]
|
||||
for form_id in expired_ids:
|
||||
self.repository.delete(form_id)
|
||||
return len(expired_ids)
|
||||
|
||||
def _validate_submission(self, form: HumanInputForm, submission_data: FormSubmissionData) -> None:
|
||||
defined_actions = {action["id"] for action in form.user_actions}
|
||||
if submission_data.action not in defined_actions:
|
||||
raise InvalidFormDataError(f"Invalid action: {submission_data.action}")
|
||||
|
||||
missing_inputs = []
|
||||
for form_input in form.inputs:
|
||||
if form_input.output_variable_name not in submission_data.inputs:
|
||||
missing_inputs.append(form_input.output_variable_name)
|
||||
|
||||
if missing_inputs:
|
||||
raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}")
|
||||
|
||||
# Extra inputs are allowed; no further validation required.
|
||||
326
api/tests/unit_tests/libs/_human_input/test_form_service.py
Normal file
326
api/tests/unit_tests/libs/_human_input/test_form_service.py
Normal file
@ -0,0 +1,326 @@
|
||||
"""
|
||||
Unit tests for FormService.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
FormInput,
|
||||
UserAction,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import (
|
||||
FormInputType,
|
||||
TimeoutUnit,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .support import (
|
||||
FormAlreadySubmittedError,
|
||||
FormExpiredError,
|
||||
FormNotFoundError,
|
||||
FormService,
|
||||
FormSubmissionData,
|
||||
InMemoryFormRepository,
|
||||
InvalidFormDataError,
|
||||
)
|
||||
|
||||
|
||||
class TestFormService:
|
||||
"""Test FormService functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def repository(self):
|
||||
"""Create in-memory repository for testing."""
|
||||
return InMemoryFormRepository()
|
||||
|
||||
@pytest.fixture
|
||||
def form_service(self, repository):
|
||||
"""Create FormService with in-memory repository."""
|
||||
return FormService(repository)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_form_data(self):
|
||||
"""Create sample form data."""
|
||||
return {
|
||||
"form_id": "form-123",
|
||||
"workflow_run_id": "run-456",
|
||||
"node_id": "node-789",
|
||||
"tenant_id": "tenant-abc",
|
||||
"app_id": "app-def",
|
||||
"form_content": "# Test Form\n\nInput: {{#$output.input#}}",
|
||||
"inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="input", default=None)],
|
||||
"user_actions": [UserAction(id="submit", title="Submit")],
|
||||
"timeout": 1,
|
||||
"timeout_unit": TimeoutUnit.HOUR,
|
||||
"form_token": "token-xyz",
|
||||
}
|
||||
|
||||
def test_create_form(self, form_service, sample_form_data):
|
||||
"""Test form creation."""
|
||||
form = form_service.create_form(**sample_form_data)
|
||||
|
||||
assert form.form_id == "form-123"
|
||||
assert form.workflow_run_id == "run-456"
|
||||
assert form.node_id == "node-789"
|
||||
assert form.tenant_id == "tenant-abc"
|
||||
assert form.app_id == "app-def"
|
||||
assert form.form_token == "token-xyz"
|
||||
assert form.timeout == 1
|
||||
assert form.timeout_unit == TimeoutUnit.HOUR
|
||||
assert form.expires_at is not None
|
||||
assert not form.is_expired
|
||||
assert not form.is_submitted
|
||||
|
||||
def test_get_form_by_id(self, form_service, sample_form_data):
|
||||
"""Test getting form by ID."""
|
||||
# Create form first
|
||||
created_form = form_service.create_form(**sample_form_data)
|
||||
|
||||
# Retrieve form
|
||||
retrieved_form = form_service.get_form_by_id("form-123")
|
||||
|
||||
assert retrieved_form.form_id == created_form.form_id
|
||||
assert retrieved_form.workflow_run_id == created_form.workflow_run_id
|
||||
|
||||
def test_get_form_by_id_not_found(self, form_service):
|
||||
"""Test getting non-existent form by ID."""
|
||||
with pytest.raises(FormNotFoundError) as exc_info:
|
||||
form_service.get_form_by_id("non-existent-form")
|
||||
|
||||
assert exc_info.value.error_code == "form_not_found"
|
||||
|
||||
def test_get_form_by_token(self, form_service, sample_form_data):
|
||||
"""Test getting form by token."""
|
||||
# Create form first
|
||||
created_form = form_service.create_form(**sample_form_data)
|
||||
|
||||
# Retrieve form by token
|
||||
retrieved_form = form_service.get_form_by_token("token-xyz")
|
||||
|
||||
assert retrieved_form.form_id == created_form.form_id
|
||||
assert retrieved_form.form_token == "token-xyz"
|
||||
|
||||
def test_get_form_by_token_not_found(self, form_service):
|
||||
"""Test getting non-existent form by token."""
|
||||
with pytest.raises(FormNotFoundError) as exc_info:
|
||||
form_service.get_form_by_token("non-existent-token")
|
||||
|
||||
assert exc_info.value.error_code == "form_not_found"
|
||||
|
||||
def test_get_form_definition_by_id(self, form_service, sample_form_data):
|
||||
"""Test getting form definition by ID."""
|
||||
# Create form first
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
# Get form definition
|
||||
definition = form_service.get_form_definition("form-123", is_token=False)
|
||||
|
||||
assert "form_content" in definition
|
||||
assert "inputs" in definition
|
||||
assert definition["form_content"] == "# Test Form\n\nInput: {{#$output.input#}}"
|
||||
assert len(definition["inputs"]) == 1
|
||||
assert "site" not in definition # Should not include site info for ID-based access
|
||||
|
||||
def test_get_form_definition_by_token(self, form_service, sample_form_data):
|
||||
"""Test getting form definition by token."""
|
||||
# Create form first
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
# Get form definition
|
||||
definition = form_service.get_form_definition("token-xyz", is_token=True)
|
||||
|
||||
assert "form_content" in definition
|
||||
assert "inputs" in definition
|
||||
assert "site" in definition # Should include site info for token-based access
|
||||
|
||||
def test_get_form_definition_expired_form(self, form_service, sample_form_data):
|
||||
"""Test getting definition for expired form."""
|
||||
# Create form with past expiry
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
# Manually expire the form by modifying expiry time
|
||||
form = form_service.get_form_by_id("form-123")
|
||||
form.expires_at = datetime.utcnow() - timedelta(hours=1)
|
||||
form_service.repository.save(form)
|
||||
|
||||
# Should raise FormExpiredError
|
||||
with pytest.raises(FormExpiredError) as exc_info:
|
||||
form_service.get_form_definition("form-123", is_token=False)
|
||||
|
||||
assert exc_info.value.error_code == "human_input_form_expired"
|
||||
|
||||
def test_get_form_definition_submitted_form(self, form_service, sample_form_data):
|
||||
"""Test getting definition for already submitted form."""
|
||||
# Create form first
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
# Submit the form
|
||||
submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit")
|
||||
form_service.submit_form("form-123", submission_data, is_token=False)
|
||||
|
||||
# Should raise FormAlreadySubmittedError
|
||||
with pytest.raises(FormAlreadySubmittedError) as exc_info:
|
||||
form_service.get_form_definition("form-123", is_token=False)
|
||||
|
||||
assert exc_info.value.error_code == "human_input_form_submitted"
|
||||
|
||||
def test_submit_form_success(self, form_service, sample_form_data):
|
||||
"""Test successful form submission."""
|
||||
# Create form first
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
# Submit form
|
||||
submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit")
|
||||
|
||||
# Should not raise any exception
|
||||
form_service.submit_form("form-123", submission_data, is_token=False)
|
||||
|
||||
# Verify form is marked as submitted
|
||||
form = form_service.get_form_by_id("form-123")
|
||||
assert form.is_submitted
|
||||
assert form.submitted_data == {"input": "test value"}
|
||||
assert form.submitted_action == "submit"
|
||||
assert form.submitted_at is not None
|
||||
|
||||
def test_submit_form_missing_inputs(self, form_service, sample_form_data):
|
||||
"""Test form submission with missing inputs."""
|
||||
# Create form first
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
# Submit form with missing required input
|
||||
submission_data = FormSubmissionData(
|
||||
form_id="form-123",
|
||||
inputs={}, # Missing required "input" field
|
||||
action="submit",
|
||||
)
|
||||
|
||||
with pytest.raises(InvalidFormDataError) as exc_info:
|
||||
form_service.submit_form("form-123", submission_data, is_token=False)
|
||||
|
||||
assert "Missing required inputs" in exc_info.value.message
|
||||
assert "input" in exc_info.value.message
|
||||
|
||||
def test_submit_form_invalid_action(self, form_service, sample_form_data):
|
||||
"""Test form submission with invalid action."""
|
||||
# Create form first
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
# Submit form with invalid action
|
||||
submission_data = FormSubmissionData(
|
||||
form_id="form-123",
|
||||
inputs={"input": "test value"},
|
||||
action="invalid_action", # Not in the allowed actions
|
||||
)
|
||||
|
||||
with pytest.raises(InvalidFormDataError) as exc_info:
|
||||
form_service.submit_form("form-123", submission_data, is_token=False)
|
||||
|
||||
assert "Invalid action" in exc_info.value.message
|
||||
assert "invalid_action" in exc_info.value.message
|
||||
|
||||
def test_submit_form_expired(self, form_service, sample_form_data):
|
||||
"""Test submitting expired form."""
|
||||
# Create form first
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
# Manually expire the form
|
||||
form = form_service.get_form_by_id("form-123")
|
||||
form.expires_at = datetime.utcnow() - timedelta(hours=1)
|
||||
form_service.repository.save(form)
|
||||
|
||||
# Try to submit expired form
|
||||
submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit")
|
||||
|
||||
with pytest.raises(FormExpiredError) as exc_info:
|
||||
form_service.submit_form("form-123", submission_data, is_token=False)
|
||||
|
||||
assert exc_info.value.error_code == "human_input_form_expired"
|
||||
|
||||
def test_submit_form_already_submitted(self, form_service, sample_form_data):
|
||||
"""Test submitting form that's already submitted."""
|
||||
# Create and submit form first
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "first submission"}, action="submit")
|
||||
form_service.submit_form("form-123", submission_data, is_token=False)
|
||||
|
||||
# Try to submit again
|
||||
second_submission = FormSubmissionData(
|
||||
form_id="form-123", inputs={"input": "second submission"}, action="submit"
|
||||
)
|
||||
|
||||
with pytest.raises(FormAlreadySubmittedError) as exc_info:
|
||||
form_service.submit_form("form-123", second_submission, is_token=False)
|
||||
|
||||
assert exc_info.value.error_code == "human_input_form_submitted"
|
||||
|
||||
def test_cleanup_expired_forms(self, form_service, sample_form_data):
|
||||
"""Test cleanup of expired forms."""
|
||||
# Create multiple forms
|
||||
for i in range(3):
|
||||
data = sample_form_data.copy()
|
||||
data["form_id"] = f"form-{i}"
|
||||
data["form_token"] = f"token-{i}"
|
||||
form_service.create_form(**data)
|
||||
|
||||
# Manually expire some forms
|
||||
for i in range(2): # Expire first 2 forms
|
||||
form = form_service.get_form_by_id(f"form-{i}")
|
||||
form.expires_at = naive_utc_now() - timedelta(hours=1)
|
||||
form_service.repository.save(form)
|
||||
|
||||
# Clean up expired forms
|
||||
cleaned_count = form_service.cleanup_expired_forms()
|
||||
|
||||
assert cleaned_count == 2
|
||||
|
||||
# Verify expired forms are gone
|
||||
with pytest.raises(FormNotFoundError):
|
||||
form_service.get_form_by_id("form-0")
|
||||
|
||||
with pytest.raises(FormNotFoundError):
|
||||
form_service.get_form_by_id("form-1")
|
||||
|
||||
# Verify non-expired form still exists
|
||||
form = form_service.get_form_by_id("form-2")
|
||||
assert form.form_id == "form-2"
|
||||
|
||||
|
||||
class TestFormValidation:
|
||||
"""Test form validation logic."""
|
||||
|
||||
def test_validate_submission_with_extra_inputs(self):
|
||||
"""Test validation allows extra inputs that aren't defined in form."""
|
||||
repository = InMemoryFormRepository()
|
||||
form_service = FormService(repository)
|
||||
|
||||
# Create form with one input
|
||||
form_data = {
|
||||
"form_id": "form-123",
|
||||
"workflow_run_id": "run-456",
|
||||
"node_id": "node-789",
|
||||
"tenant_id": "tenant-abc",
|
||||
"app_id": "app-def",
|
||||
"form_content": "Test form",
|
||||
"inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="required_input", default=None)],
|
||||
"user_actions": [UserAction(id="submit", title="Submit")],
|
||||
"timeout": 1,
|
||||
"timeout_unit": TimeoutUnit.HOUR,
|
||||
}
|
||||
|
||||
form_service.create_form(**form_data)
|
||||
|
||||
# Submit with extra input (should be allowed)
|
||||
submission_data = FormSubmissionData(
|
||||
form_id="form-123",
|
||||
inputs={
|
||||
"required_input": "value1",
|
||||
"extra_input": "value2", # Extra input not defined in form
|
||||
},
|
||||
action="submit",
|
||||
)
|
||||
|
||||
# Should not raise any exception
|
||||
form_service.submit_form("form-123", submission_data, is_token=False)
|
||||
232
api/tests/unit_tests/libs/_human_input/test_models.py
Normal file
232
api/tests/unit_tests/libs/_human_input/test_models.py
Normal file
@ -0,0 +1,232 @@
|
||||
"""
|
||||
Unit tests for human input form models.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
FormInput,
|
||||
UserAction,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import (
|
||||
FormInputType,
|
||||
TimeoutUnit,
|
||||
)
|
||||
|
||||
from .support import FormSubmissionData, FormSubmissionRequest, HumanInputForm
|
||||
|
||||
|
||||
class TestHumanInputForm:
|
||||
"""Test HumanInputForm model."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_form_data(self):
|
||||
"""Create sample form data."""
|
||||
return {
|
||||
"form_id": "form-123",
|
||||
"workflow_run_id": "run-456",
|
||||
"node_id": "node-789",
|
||||
"tenant_id": "tenant-abc",
|
||||
"app_id": "app-def",
|
||||
"form_content": "# Test Form\n\nInput: {{#$output.input#}}",
|
||||
"inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="input", default=None)],
|
||||
"user_actions": [UserAction(id="submit", title="Submit")],
|
||||
"timeout": 2,
|
||||
"timeout_unit": TimeoutUnit.HOUR,
|
||||
"form_token": "token-xyz",
|
||||
}
|
||||
|
||||
def test_form_creation(self, sample_form_data):
|
||||
"""Test form creation."""
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
|
||||
assert form.form_id == "form-123"
|
||||
assert form.workflow_run_id == "run-456"
|
||||
assert form.node_id == "node-789"
|
||||
assert form.tenant_id == "tenant-abc"
|
||||
assert form.app_id == "app-def"
|
||||
assert form.form_token == "token-xyz"
|
||||
assert form.timeout == 2
|
||||
assert form.timeout_unit == TimeoutUnit.HOUR
|
||||
assert form.created_at is not None
|
||||
assert form.expires_at is not None
|
||||
assert form.submitted_at is None
|
||||
assert form.submitted_data is None
|
||||
assert form.submitted_action is None
|
||||
|
||||
def test_form_expiry_calculation_hours(self, sample_form_data):
|
||||
"""Test form expiry calculation for hours."""
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
|
||||
# Should expire 2 hours after creation
|
||||
expected_expiry = form.created_at + timedelta(hours=2)
|
||||
assert abs((form.expires_at - expected_expiry).total_seconds()) < 1 # Within 1 second
|
||||
|
||||
def test_form_expiry_calculation_days(self, sample_form_data):
|
||||
"""Test form expiry calculation for days."""
|
||||
sample_form_data["timeout"] = 3
|
||||
sample_form_data["timeout_unit"] = TimeoutUnit.DAY
|
||||
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
|
||||
# Should expire 3 days after creation
|
||||
expected_expiry = form.created_at + timedelta(days=3)
|
||||
assert abs((form.expires_at - expected_expiry).total_seconds()) < 1 # Within 1 second
|
||||
|
||||
def test_form_expiry_property_not_expired(self, sample_form_data):
|
||||
"""Test is_expired property for non-expired form."""
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
assert not form.is_expired
|
||||
|
||||
def test_form_expiry_property_expired(self, sample_form_data):
|
||||
"""Test is_expired property for expired form."""
|
||||
# Create form with past expiry
|
||||
past_time = datetime.utcnow() - timedelta(hours=1)
|
||||
sample_form_data["created_at"] = past_time
|
||||
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
# Manually set expiry to past time
|
||||
form.expires_at = past_time
|
||||
|
||||
assert form.is_expired
|
||||
|
||||
def test_form_submission_property_not_submitted(self, sample_form_data):
|
||||
"""Test is_submitted property for non-submitted form."""
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
assert not form.is_submitted
|
||||
|
||||
def test_form_submission_property_submitted(self, sample_form_data):
|
||||
"""Test is_submitted property for submitted form."""
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
form.submit({"input": "test value"}, "submit")
|
||||
|
||||
assert form.is_submitted
|
||||
assert form.submitted_at is not None
|
||||
assert form.submitted_data == {"input": "test value"}
|
||||
assert form.submitted_action == "submit"
|
||||
|
||||
def test_form_submit_method(self, sample_form_data):
|
||||
"""Test form submit method."""
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
|
||||
submission_time_before = datetime.utcnow()
|
||||
form.submit({"input": "test value"}, "submit")
|
||||
submission_time_after = datetime.utcnow()
|
||||
|
||||
assert form.is_submitted
|
||||
assert form.submitted_data == {"input": "test value"}
|
||||
assert form.submitted_action == "submit"
|
||||
assert submission_time_before <= form.submitted_at <= submission_time_after
|
||||
|
||||
def test_form_to_response_dict_without_site_info(self, sample_form_data):
|
||||
"""Test converting form to response dict without site info."""
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
|
||||
response = form.to_response_dict(include_site_info=False)
|
||||
|
||||
assert "form_content" in response
|
||||
assert "inputs" in response
|
||||
assert "site" not in response
|
||||
assert response["form_content"] == "# Test Form\n\nInput: {{#$output.input#}}"
|
||||
assert len(response["inputs"]) == 1
|
||||
assert response["inputs"][0]["type"] == "text-input"
|
||||
assert response["inputs"][0]["output_variable_name"] == "input"
|
||||
|
||||
def test_form_to_response_dict_with_site_info(self, sample_form_data):
|
||||
"""Test converting form to response dict with site info."""
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
|
||||
response = form.to_response_dict(include_site_info=True)
|
||||
|
||||
assert "form_content" in response
|
||||
assert "inputs" in response
|
||||
assert "site" in response
|
||||
assert response["site"]["app_id"] == "app-def"
|
||||
assert response["site"]["title"] == "Workflow Form"
|
||||
|
||||
def test_form_without_web_app_token(self, sample_form_data):
|
||||
"""Test form creation without web app token."""
|
||||
sample_form_data["form_token"] = None
|
||||
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
|
||||
assert form.form_token is None
|
||||
assert form.form_id == "form-123" # Other fields should still work
|
||||
|
||||
def test_form_with_explicit_timestamps(self):
|
||||
"""Test form creation with explicit timestamps."""
|
||||
created_time = datetime(2024, 1, 15, 10, 30, 0)
|
||||
expires_time = datetime(2024, 1, 15, 12, 30, 0)
|
||||
|
||||
form = HumanInputForm(
|
||||
form_id="form-123",
|
||||
workflow_run_id="run-456",
|
||||
node_id="node-789",
|
||||
tenant_id="tenant-abc",
|
||||
app_id="app-def",
|
||||
form_content="Test content",
|
||||
inputs=[],
|
||||
user_actions=[],
|
||||
timeout=2,
|
||||
timeout_unit=TimeoutUnit.HOUR,
|
||||
created_at=created_time,
|
||||
expires_at=expires_time,
|
||||
)
|
||||
|
||||
assert form.created_at == created_time
|
||||
assert form.expires_at == expires_time
|
||||
|
||||
|
||||
class TestFormSubmissionData:
|
||||
"""Test FormSubmissionData model."""
|
||||
|
||||
def test_submission_data_creation(self):
|
||||
"""Test submission data creation."""
|
||||
submission_data = FormSubmissionData(
|
||||
form_id="form-123", inputs={"field1": "value1", "field2": "value2"}, action="submit"
|
||||
)
|
||||
|
||||
assert submission_data.form_id == "form-123"
|
||||
assert submission_data.inputs == {"field1": "value1", "field2": "value2"}
|
||||
assert submission_data.action == "submit"
|
||||
assert submission_data.submitted_at is not None
|
||||
|
||||
def test_submission_data_from_request(self):
|
||||
"""Test creating submission data from API request."""
|
||||
request = FormSubmissionRequest(inputs={"input": "test value"}, action="confirm")
|
||||
|
||||
submission_data = FormSubmissionData.from_request("form-456", request)
|
||||
|
||||
assert submission_data.form_id == "form-456"
|
||||
assert submission_data.inputs == {"input": "test value"}
|
||||
assert submission_data.action == "confirm"
|
||||
assert submission_data.submitted_at is not None
|
||||
|
||||
def test_submission_data_with_empty_inputs(self):
|
||||
"""Test submission data with empty inputs."""
|
||||
submission_data = FormSubmissionData(form_id="form-123", inputs={}, action="cancel")
|
||||
|
||||
assert submission_data.inputs == {}
|
||||
assert submission_data.action == "cancel"
|
||||
|
||||
def test_submission_data_timestamps(self):
|
||||
"""Test submission data timestamp handling."""
|
||||
before_time = datetime.utcnow()
|
||||
|
||||
submission_data = FormSubmissionData(form_id="form-123", inputs={"test": "value"}, action="submit")
|
||||
|
||||
after_time = datetime.utcnow()
|
||||
|
||||
assert before_time <= submission_data.submitted_at <= after_time
|
||||
|
||||
def test_submission_data_with_explicit_timestamp(self):
|
||||
"""Test submission data with explicit timestamp."""
|
||||
specific_time = datetime(2024, 1, 15, 14, 30, 0)
|
||||
|
||||
submission_data = FormSubmissionData(
|
||||
form_id="form-123", inputs={"test": "value"}, action="submit", submitted_at=specific_time
|
||||
)
|
||||
|
||||
assert submission_data.submitted_at == specific_time
|
||||
@ -1,6 +1,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.helper import escape_like_pattern, extract_tenant_id
|
||||
from libs.helper import OptionalTimestampField, escape_like_pattern, extract_tenant_id
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
|
||||
@ -65,6 +67,19 @@ class TestExtractTenantId:
|
||||
extract_tenant_id(dict_user)
|
||||
|
||||
|
||||
class TestOptionalTimestampField:
|
||||
def test_format_returns_none_for_none(self):
|
||||
field = OptionalTimestampField()
|
||||
|
||||
assert field.format(None) is None
|
||||
|
||||
def test_format_returns_unix_timestamp_for_datetime(self):
|
||||
field = OptionalTimestampField()
|
||||
value = datetime(2024, 1, 2, 3, 4, 5)
|
||||
|
||||
assert field.format(value) == int(value.timestamp())
|
||||
|
||||
|
||||
class TestEscapeLikePattern:
|
||||
"""Test cases for the escape_like_pattern utility function."""
|
||||
|
||||
|
||||
68
api/tests/unit_tests/libs/test_rate_limiter.py
Normal file
68
api/tests/unit_tests/libs/test_rate_limiter.py
Normal file
@ -0,0 +1,68 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from libs import helper as helper_module
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
def __init__(self) -> None:
|
||||
self._zsets: dict[str, dict[str, float]] = {}
|
||||
self._expiry: dict[str, int] = {}
|
||||
|
||||
def zadd(self, key: str, mapping: dict[str, float]) -> int:
|
||||
zset = self._zsets.setdefault(key, {})
|
||||
for member, score in mapping.items():
|
||||
zset[str(member)] = float(score)
|
||||
return len(mapping)
|
||||
|
||||
def zremrangebyscore(self, key: str, min_score: str | float, max_score: str | float) -> int:
|
||||
zset = self._zsets.get(key, {})
|
||||
min_value = float("-inf") if min_score == "-inf" else float(min_score)
|
||||
max_value = float("inf") if max_score == "+inf" else float(max_score)
|
||||
to_delete = [member for member, score in zset.items() if min_value <= score <= max_value]
|
||||
for member in to_delete:
|
||||
del zset[member]
|
||||
return len(to_delete)
|
||||
|
||||
def zcard(self, key: str) -> int:
|
||||
return len(self._zsets.get(key, {}))
|
||||
|
||||
def expire(self, key: str, ttl: int) -> bool:
|
||||
self._expiry[key] = ttl
|
||||
return True
|
||||
|
||||
|
||||
def test_rate_limiter_counts_attempts_within_same_second(monkeypatch):
|
||||
fake_redis = _FakeRedis()
|
||||
monkeypatch.setattr(helper_module.time, "time", lambda: 1000)
|
||||
|
||||
limiter = helper_module.RateLimiter(
|
||||
prefix="test_rate_limit",
|
||||
max_attempts=2,
|
||||
time_window=60,
|
||||
redis_client=fake_redis,
|
||||
)
|
||||
|
||||
limiter.increment_rate_limit("203.0.113.10")
|
||||
limiter.increment_rate_limit("203.0.113.10")
|
||||
|
||||
assert limiter.is_rate_limited("203.0.113.10") is True
|
||||
|
||||
|
||||
def test_rate_limiter_uses_injected_redis(monkeypatch):
|
||||
redis_client = MagicMock()
|
||||
redis_client.zcard.return_value = 1
|
||||
monkeypatch.setattr(helper_module.time, "time", lambda: 1000)
|
||||
|
||||
limiter = helper_module.RateLimiter(
|
||||
prefix="test_rate_limit",
|
||||
max_attempts=1,
|
||||
time_window=60,
|
||||
redis_client=redis_client,
|
||||
)
|
||||
|
||||
limiter.increment_rate_limit("203.0.113.10")
|
||||
limiter.is_rate_limited("203.0.113.10")
|
||||
|
||||
assert redis_client.zadd.called is True
|
||||
assert redis_client.zremrangebyscore.called is True
|
||||
assert redis_client.zcard.called is True
|
||||
Reference in New Issue
Block a user