mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
resolve: conflict
This commit is contained in:
400
api/tests/unit_tests/core/llm_generator/test_graph_builder.py
Normal file
400
api/tests/unit_tests/core/llm_generator/test_graph_builder.py
Normal file
@ -0,0 +1,400 @@
|
||||
"""
|
||||
Unit tests for GraphBuilder.
|
||||
|
||||
Tests the automatic graph construction from node lists with dependency declarations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.generator.utils.graph_builder import (
|
||||
CyclicDependencyError,
|
||||
GraphBuilder,
|
||||
)
|
||||
|
||||
|
||||
class TestGraphBuilderBasic:
|
||||
"""Basic functionality tests."""
|
||||
|
||||
def test_empty_nodes_creates_minimal_workflow(self):
|
||||
"""Empty node list creates start -> end workflow."""
|
||||
result_nodes, result_edges = GraphBuilder.build_graph([])
|
||||
|
||||
assert len(result_nodes) == 2
|
||||
assert result_nodes[0]["type"] == "start"
|
||||
assert result_nodes[1]["type"] == "end"
|
||||
assert len(result_edges) == 1
|
||||
assert result_edges[0]["source"] == "start"
|
||||
assert result_edges[0]["target"] == "end"
|
||||
|
||||
def test_simple_linear_workflow(self):
|
||||
"""Simple linear workflow: start -> fetch -> process -> end."""
|
||||
nodes = [
|
||||
{"id": "fetch", "type": "http-request", "depends_on": []},
|
||||
{"id": "process", "type": "llm", "depends_on": ["fetch"]},
|
||||
]
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
# Should have: start + 2 user nodes + end = 4
|
||||
assert len(result_nodes) == 4
|
||||
assert result_nodes[0]["type"] == "start"
|
||||
assert result_nodes[-1]["type"] == "end"
|
||||
|
||||
# Should have: start->fetch, fetch->process, process->end = 3
|
||||
assert len(result_edges) == 3
|
||||
|
||||
# Verify edge connections
|
||||
edge_pairs = [(e["source"], e["target"]) for e in result_edges]
|
||||
assert ("start", "fetch") in edge_pairs
|
||||
assert ("fetch", "process") in edge_pairs
|
||||
assert ("process", "end") in edge_pairs
|
||||
|
||||
|
||||
class TestParallelWorkflow:
|
||||
"""Tests for parallel node handling."""
|
||||
|
||||
def test_parallel_workflow(self):
|
||||
"""Parallel workflow: multiple nodes from start, merging to one."""
|
||||
nodes = [
|
||||
{"id": "api1", "type": "http-request", "depends_on": []},
|
||||
{"id": "api2", "type": "http-request", "depends_on": []},
|
||||
{"id": "merge", "type": "llm", "depends_on": ["api1", "api2"]},
|
||||
]
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
# start should connect to both api1 and api2
|
||||
start_edges = [e for e in result_edges if e["source"] == "start"]
|
||||
assert len(start_edges) == 2
|
||||
|
||||
start_targets = {e["target"] for e in start_edges}
|
||||
assert start_targets == {"api1", "api2"}
|
||||
|
||||
# Both api1 and api2 should connect to merge
|
||||
merge_incoming = [e for e in result_edges if e["target"] == "merge"]
|
||||
assert len(merge_incoming) == 2
|
||||
|
||||
def test_multiple_terminal_nodes(self):
|
||||
"""Multiple terminal nodes all connect to end."""
|
||||
nodes = [
|
||||
{"id": "branch1", "type": "llm", "depends_on": []},
|
||||
{"id": "branch2", "type": "llm", "depends_on": []},
|
||||
]
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
# Both branches should connect to end
|
||||
end_incoming = [e for e in result_edges if e["target"] == "end"]
|
||||
assert len(end_incoming) == 2
|
||||
|
||||
|
||||
class TestIfElseWorkflow:
|
||||
"""Tests for if-else branching."""
|
||||
|
||||
def test_if_else_workflow(self):
|
||||
"""Conditional branching workflow."""
|
||||
nodes = [
|
||||
{
|
||||
"id": "check",
|
||||
"type": "if-else",
|
||||
"config": {"true_branch": "success", "false_branch": "fallback"},
|
||||
"depends_on": [],
|
||||
},
|
||||
{"id": "success", "type": "llm", "depends_on": []},
|
||||
{"id": "fallback", "type": "code", "depends_on": []},
|
||||
]
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
# Should have true and false branch edges
|
||||
branch_edges = [e for e in result_edges if e["source"] == "check"]
|
||||
assert len(branch_edges) == 2
|
||||
assert any(e.get("sourceHandle") == "true" for e in branch_edges)
|
||||
assert any(e.get("sourceHandle") == "false" for e in branch_edges)
|
||||
|
||||
# Verify targets
|
||||
true_edge = next(e for e in branch_edges if e.get("sourceHandle") == "true")
|
||||
false_edge = next(e for e in branch_edges if e.get("sourceHandle") == "false")
|
||||
assert true_edge["target"] == "success"
|
||||
assert false_edge["target"] == "fallback"
|
||||
|
||||
def test_if_else_missing_branch_no_error(self):
|
||||
"""if-else with only true branch doesn't error (warning only)."""
|
||||
nodes = [
|
||||
{
|
||||
"id": "check",
|
||||
"type": "if-else",
|
||||
"config": {"true_branch": "success"},
|
||||
"depends_on": [],
|
||||
},
|
||||
{"id": "success", "type": "llm", "depends_on": []},
|
||||
]
|
||||
# Should not raise
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
# Should have one branch edge
|
||||
branch_edges = [e for e in result_edges if e["source"] == "check"]
|
||||
assert len(branch_edges) == 1
|
||||
assert branch_edges[0].get("sourceHandle") == "true"
|
||||
|
||||
|
||||
class TestQuestionClassifierWorkflow:
|
||||
"""Tests for question-classifier branching."""
|
||||
|
||||
def test_question_classifier_workflow(self):
|
||||
"""Question classifier with multiple classes."""
|
||||
nodes = [
|
||||
{
|
||||
"id": "classifier",
|
||||
"type": "question-classifier",
|
||||
"config": {
|
||||
"query": ["start", "user_input"],
|
||||
"classes": [
|
||||
{"id": "tech", "name": "技术问题", "target": "tech_handler"},
|
||||
{"id": "sales", "name": "销售咨询", "target": "sales_handler"},
|
||||
{"id": "other", "name": "其他问题", "target": "other_handler"},
|
||||
],
|
||||
},
|
||||
"depends_on": [],
|
||||
},
|
||||
{"id": "tech_handler", "type": "llm", "depends_on": []},
|
||||
{"id": "sales_handler", "type": "llm", "depends_on": []},
|
||||
{"id": "other_handler", "type": "llm", "depends_on": []},
|
||||
]
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
# Should have 3 branch edges from classifier
|
||||
classifier_edges = [e for e in result_edges if e["source"] == "classifier"]
|
||||
assert len(classifier_edges) == 3
|
||||
|
||||
# Each should use class id as sourceHandle
|
||||
assert any(e.get("sourceHandle") == "tech" and e["target"] == "tech_handler" for e in classifier_edges)
|
||||
assert any(e.get("sourceHandle") == "sales" and e["target"] == "sales_handler" for e in classifier_edges)
|
||||
assert any(e.get("sourceHandle") == "other" and e["target"] == "other_handler" for e in classifier_edges)
|
||||
|
||||
def test_question_classifier_missing_target(self):
|
||||
"""Classes without target connect to end."""
|
||||
nodes = [
|
||||
{
|
||||
"id": "classifier",
|
||||
"type": "question-classifier",
|
||||
"config": {
|
||||
"classes": [
|
||||
{"id": "known", "name": "已知问题", "target": "handler"},
|
||||
{"id": "unknown", "name": "未知问题"}, # Missing target
|
||||
],
|
||||
},
|
||||
"depends_on": [],
|
||||
},
|
||||
{"id": "handler", "type": "llm", "depends_on": []},
|
||||
]
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
# Missing target should connect to end
|
||||
classifier_edges = [e for e in result_edges if e["source"] == "classifier"]
|
||||
assert any(e.get("sourceHandle") == "unknown" and e["target"] == "end" for e in classifier_edges)
|
||||
|
||||
|
||||
class TestVariableDependencyInference:
|
||||
"""Tests for automatic dependency inference from variables."""
|
||||
|
||||
def test_variable_dependency_inference(self):
|
||||
"""Dependencies inferred from variable references."""
|
||||
nodes = [
|
||||
{"id": "fetch", "type": "http-request", "depends_on": []},
|
||||
{
|
||||
"id": "process",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"text": "{{#fetch.body#}}"}]},
|
||||
# No explicit depends_on, but references fetch
|
||||
},
|
||||
]
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
# Should automatically infer process depends on fetch
|
||||
assert any(e["source"] == "fetch" and e["target"] == "process" for e in result_edges)
|
||||
|
||||
def test_system_variable_not_inferred(self):
|
||||
"""System variables (sys, start) not inferred as dependencies."""
|
||||
nodes = [
|
||||
{
|
||||
"id": "process",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"text": "{{#sys.query#}} {{#start.input#}}"}]},
|
||||
"depends_on": [],
|
||||
},
|
||||
]
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
# Should connect to start, not create dependency on sys or start
|
||||
edge_sources = {e["source"] for e in result_edges}
|
||||
assert "sys" not in edge_sources
|
||||
assert "start" in edge_sources
|
||||
|
||||
|
||||
class TestCycleDetection:
|
||||
"""Tests for cyclic dependency detection."""
|
||||
|
||||
def test_cyclic_dependency_detected(self):
|
||||
"""Cyclic dependencies raise error."""
|
||||
nodes = [
|
||||
{"id": "a", "type": "llm", "depends_on": ["c"]},
|
||||
{"id": "b", "type": "llm", "depends_on": ["a"]},
|
||||
{"id": "c", "type": "llm", "depends_on": ["b"]},
|
||||
]
|
||||
|
||||
with pytest.raises(CyclicDependencyError):
|
||||
GraphBuilder.build_graph(nodes)
|
||||
|
||||
def test_self_dependency_detected(self):
|
||||
"""Self-dependency raises error."""
|
||||
nodes = [
|
||||
{"id": "a", "type": "llm", "depends_on": ["a"]},
|
||||
]
|
||||
|
||||
with pytest.raises(CyclicDependencyError):
|
||||
GraphBuilder.build_graph(nodes)
|
||||
|
||||
|
||||
class TestErrorRecovery:
|
||||
"""Tests for silent error recovery."""
|
||||
|
||||
def test_invalid_dependency_removed(self):
|
||||
"""Invalid dependencies (non-existent nodes) are silently removed."""
|
||||
nodes = [
|
||||
{"id": "process", "type": "llm", "depends_on": ["nonexistent"]},
|
||||
]
|
||||
# Should not raise, invalid dependency silently removed
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
# Process should connect from start (since invalid dep was removed)
|
||||
assert any(e["source"] == "start" and e["target"] == "process" for e in result_edges)
|
||||
|
||||
def test_depends_on_as_string(self):
|
||||
"""depends_on as string is converted to list."""
|
||||
nodes = [
|
||||
{"id": "fetch", "type": "http-request", "depends_on": []},
|
||||
{"id": "process", "type": "llm", "depends_on": "fetch"}, # String instead of list
|
||||
]
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
# Should work correctly
|
||||
assert any(e["source"] == "fetch" and e["target"] == "process" for e in result_edges)
|
||||
|
||||
|
||||
class TestContainerNodes:
|
||||
"""Tests for container nodes (iteration, loop)."""
|
||||
|
||||
def test_iteration_node_as_regular_node(self):
|
||||
"""Iteration nodes behave as regular single-in-single-out nodes."""
|
||||
nodes = [
|
||||
{"id": "prepare", "type": "code", "depends_on": []},
|
||||
{
|
||||
"id": "loop",
|
||||
"type": "iteration",
|
||||
"config": {"iterator_selector": ["prepare", "items"]},
|
||||
"depends_on": ["prepare"],
|
||||
},
|
||||
{"id": "process_result", "type": "llm", "depends_on": ["loop"]},
|
||||
]
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
# Should have standard edges: start->prepare, prepare->loop, loop->process_result, process_result->end
|
||||
edge_pairs = [(e["source"], e["target"]) for e in result_edges]
|
||||
assert ("start", "prepare") in edge_pairs
|
||||
assert ("prepare", "loop") in edge_pairs
|
||||
assert ("loop", "process_result") in edge_pairs
|
||||
assert ("process_result", "end") in edge_pairs
|
||||
|
||||
def test_loop_node_as_regular_node(self):
|
||||
"""Loop nodes behave as regular single-in-single-out nodes."""
|
||||
nodes = [
|
||||
{"id": "init", "type": "code", "depends_on": []},
|
||||
{
|
||||
"id": "repeat",
|
||||
"type": "loop",
|
||||
"config": {"loop_count": 5},
|
||||
"depends_on": ["init"],
|
||||
},
|
||||
{"id": "finish", "type": "llm", "depends_on": ["repeat"]},
|
||||
]
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
# Standard edge flow
|
||||
edge_pairs = [(e["source"], e["target"]) for e in result_edges]
|
||||
assert ("init", "repeat") in edge_pairs
|
||||
assert ("repeat", "finish") in edge_pairs
|
||||
|
||||
def test_iteration_with_variable_inference(self):
|
||||
"""Iteration node dependencies can be inferred from iterator_selector."""
|
||||
nodes = [
|
||||
{"id": "data_source", "type": "http-request", "depends_on": []},
|
||||
{
|
||||
"id": "process_each",
|
||||
"type": "iteration",
|
||||
"config": {
|
||||
"iterator_selector": ["data_source", "items"],
|
||||
},
|
||||
# No explicit depends_on, but references data_source
|
||||
},
|
||||
]
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
# Should infer dependency from iterator_selector reference
|
||||
# Note: iterator_selector format is different from {{#...#}}, so this tests
|
||||
# that explicit depends_on is properly handled when not provided
|
||||
# In this case, process_each has no depends_on, so it connects to start
|
||||
edge_pairs = [(e["source"], e["target"]) for e in result_edges]
|
||||
# Without explicit depends_on, connects to start
|
||||
assert ("start", "process_each") in edge_pairs or ("data_source", "process_each") in edge_pairs
|
||||
|
||||
def test_loop_node_self_reference_not_cycle(self):
|
||||
"""Loop nodes referencing their own outputs should not create cycle."""
|
||||
nodes = [
|
||||
{"id": "init", "type": "code", "depends_on": []},
|
||||
{
|
||||
"id": "my_loop",
|
||||
"type": "loop",
|
||||
"config": {
|
||||
"loop_count": 5,
|
||||
# Loop node referencing its own output (common pattern)
|
||||
"prompt": "Previous: {{#my_loop.output#}}, continue...",
|
||||
},
|
||||
"depends_on": ["init"],
|
||||
},
|
||||
{"id": "finish", "type": "llm", "depends_on": ["my_loop"]},
|
||||
]
|
||||
# Should NOT raise CyclicDependencyError
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
# Verify the graph is built correctly
|
||||
assert len(result_nodes) == 5 # start + 3 + end
|
||||
edge_pairs = [(e["source"], e["target"]) for e in result_edges]
|
||||
assert ("init", "my_loop") in edge_pairs
|
||||
assert ("my_loop", "finish") in edge_pairs
|
||||
|
||||
|
||||
class TestEdgeStructure:
|
||||
"""Tests for edge structure correctness."""
|
||||
|
||||
def test_edge_has_required_fields(self):
|
||||
"""Edges have all required fields."""
|
||||
nodes = [
|
||||
{"id": "node1", "type": "llm", "depends_on": []},
|
||||
]
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
for edge in result_edges:
|
||||
assert "id" in edge
|
||||
assert "source" in edge
|
||||
assert "target" in edge
|
||||
assert "sourceHandle" in edge
|
||||
assert "targetHandle" in edge
|
||||
|
||||
def test_edge_id_unique(self):
|
||||
"""Each edge has a unique ID."""
|
||||
nodes = [
|
||||
{"id": "a", "type": "llm", "depends_on": []},
|
||||
{"id": "b", "type": "llm", "depends_on": []},
|
||||
{"id": "c", "type": "llm", "depends_on": ["a", "b"]},
|
||||
]
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
edge_ids = [e["id"] for e in result_edges]
|
||||
assert len(edge_ids) == len(set(edge_ids)) # All unique
|
||||
@ -0,0 +1,287 @@
|
||||
"""
|
||||
Unit tests for the Mermaid Generator.
|
||||
|
||||
Tests cover:
|
||||
- Basic workflow rendering
|
||||
- Reserved word handling ('end' → 'end_node')
|
||||
- Question classifier multi-branch edges
|
||||
- If-else branch labels
|
||||
- Edge validation and skipping
|
||||
- Tool node formatting
|
||||
"""
|
||||
|
||||
from core.workflow.generator.utils.mermaid_generator import generate_mermaid
|
||||
|
||||
|
||||
class TestBasicWorkflow:
|
||||
"""Tests for basic workflow Mermaid generation."""
|
||||
|
||||
def test_simple_start_end_workflow(self):
|
||||
"""Test simple Start → End workflow."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "title": "Start"},
|
||||
{"id": "end", "type": "end", "title": "End"},
|
||||
],
|
||||
"edges": [{"source": "start", "target": "end"}],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert "flowchart TD" in result
|
||||
assert 'start["type=start|title=Start"]' in result
|
||||
assert 'end_node["type=end|title=End"]' in result
|
||||
assert "start --> end_node" in result
|
||||
|
||||
def test_start_llm_end_workflow(self):
|
||||
"""Test Start → LLM → End workflow."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "title": "Start"},
|
||||
{"id": "llm", "type": "llm", "title": "Generate"},
|
||||
{"id": "end", "type": "end", "title": "End"},
|
||||
],
|
||||
"edges": [
|
||||
{"source": "start", "target": "llm"},
|
||||
{"source": "llm", "target": "end"},
|
||||
],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert 'llm["type=llm|title=Generate"]' in result
|
||||
assert "start --> llm" in result
|
||||
assert "llm --> end_node" in result
|
||||
|
||||
def test_empty_workflow(self):
|
||||
"""Test empty workflow returns minimal output."""
|
||||
workflow_data = {"nodes": [], "edges": []}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert result == "flowchart TD"
|
||||
|
||||
def test_missing_keys_handled(self):
|
||||
"""Test workflow with missing keys doesn't crash."""
|
||||
workflow_data = {}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert "flowchart TD" in result
|
||||
|
||||
|
||||
class TestReservedWords:
|
||||
"""Tests for reserved word handling in node IDs."""
|
||||
|
||||
def test_end_node_id_is_replaced(self):
|
||||
"""Test 'end' node ID is replaced with 'end_node'."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "end", "type": "end", "title": "End"}],
|
||||
"edges": [],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
# Should use end_node instead of end
|
||||
assert "end_node[" in result
|
||||
assert '"type=end|title=End"' in result
|
||||
|
||||
def test_subgraph_node_id_is_replaced(self):
|
||||
"""Test 'subgraph' node ID is replaced with 'subgraph_node'."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "subgraph", "type": "code", "title": "Process"}],
|
||||
"edges": [],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert "subgraph_node[" in result
|
||||
|
||||
def test_edge_uses_safe_ids(self):
|
||||
"""Test edges correctly reference safe IDs after replacement."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "title": "Start"},
|
||||
{"id": "end", "type": "end", "title": "End"},
|
||||
],
|
||||
"edges": [{"source": "start", "target": "end"}],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
# Edge should use end_node, not end
|
||||
assert "start --> end_node" in result
|
||||
assert "start --> end\n" not in result
|
||||
|
||||
|
||||
class TestBranchEdges:
|
||||
"""Tests for branching node edge labels."""
|
||||
|
||||
def test_question_classifier_source_handles(self):
|
||||
"""Test question-classifier edges with sourceHandle labels."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "classifier", "type": "question-classifier", "title": "Classify"},
|
||||
{"id": "refund", "type": "llm", "title": "Handle Refund"},
|
||||
{"id": "inquiry", "type": "llm", "title": "Handle Inquiry"},
|
||||
],
|
||||
"edges": [
|
||||
{"source": "classifier", "target": "refund", "sourceHandle": "refund"},
|
||||
{"source": "classifier", "target": "inquiry", "sourceHandle": "inquiry"},
|
||||
],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert "classifier -->|refund| refund" in result
|
||||
assert "classifier -->|inquiry| inquiry" in result
|
||||
|
||||
def test_if_else_true_false_handles(self):
|
||||
"""Test if-else edges with true/false labels."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "ifelse", "type": "if-else", "title": "Check"},
|
||||
{"id": "yes_branch", "type": "llm", "title": "Yes"},
|
||||
{"id": "no_branch", "type": "llm", "title": "No"},
|
||||
],
|
||||
"edges": [
|
||||
{"source": "ifelse", "target": "yes_branch", "sourceHandle": "true"},
|
||||
{"source": "ifelse", "target": "no_branch", "sourceHandle": "false"},
|
||||
],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert "ifelse -->|true| yes_branch" in result
|
||||
assert "ifelse -->|false| no_branch" in result
|
||||
|
||||
def test_source_handle_source_is_ignored(self):
|
||||
"""Test sourceHandle='source' doesn't add label."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "llm1", "type": "llm", "title": "LLM 1"},
|
||||
{"id": "llm2", "type": "llm", "title": "LLM 2"},
|
||||
],
|
||||
"edges": [{"source": "llm1", "target": "llm2", "sourceHandle": "source"}],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
# Should be plain arrow without label
|
||||
assert "llm1 --> llm2" in result
|
||||
assert "llm1 -->|source|" not in result
|
||||
|
||||
|
||||
class TestEdgeValidation:
|
||||
"""Tests for edge validation and error handling."""
|
||||
|
||||
def test_edge_with_missing_source_is_skipped(self):
|
||||
"""Test edge with non-existent source node is skipped."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "end", "type": "end", "title": "End"}],
|
||||
"edges": [{"source": "nonexistent", "target": "end"}],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
# Should not contain the invalid edge
|
||||
assert "nonexistent" not in result
|
||||
assert "-->" not in result or "nonexistent" not in result
|
||||
|
||||
def test_edge_with_missing_target_is_skipped(self):
|
||||
"""Test edge with non-existent target node is skipped."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "start", "type": "start", "title": "Start"}],
|
||||
"edges": [{"source": "start", "target": "nonexistent"}],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
# Edge should be skipped
|
||||
assert "start --> nonexistent" not in result
|
||||
|
||||
def test_edge_without_source_or_target_is_skipped(self):
|
||||
"""Test edge missing source or target is skipped."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "start", "type": "start", "title": "Start"}],
|
||||
"edges": [{"source": "start"}, {"target": "start"}, {}],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
# No edges should be rendered
|
||||
assert result.count("-->") == 0
|
||||
|
||||
|
||||
class TestToolNodes:
|
||||
"""Tests for tool node formatting."""
|
||||
|
||||
def test_tool_node_includes_tool_key(self):
|
||||
"""Test tool node includes tool_key in label."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "search",
|
||||
"type": "tool",
|
||||
"title": "Search",
|
||||
"config": {"tool_key": "google/search"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert 'search["type=tool|title=Search|tool=google/search"]' in result
|
||||
|
||||
def test_tool_node_with_tool_name_fallback(self):
|
||||
"""Test tool node uses tool_name as fallback."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "tool1",
|
||||
"type": "tool",
|
||||
"title": "My Tool",
|
||||
"config": {"tool_name": "my_tool"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert "tool=my_tool" in result
|
||||
|
||||
def test_tool_node_missing_tool_key_shows_unknown(self):
|
||||
"""Test tool node without tool_key shows 'unknown'."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "tool1", "type": "tool", "title": "Tool", "config": {}}],
|
||||
"edges": [],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert "tool=unknown" in result
|
||||
|
||||
|
||||
class TestNodeFormatting:
|
||||
"""Tests for node label formatting."""
|
||||
|
||||
def test_quotes_in_title_are_escaped(self):
|
||||
"""Test double quotes in title are replaced with single quotes."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "llm", "type": "llm", "title": 'Say "Hello"'}],
|
||||
"edges": [],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
# Double quotes should be replaced
|
||||
assert "Say 'Hello'" in result
|
||||
assert 'Say "Hello"' not in result
|
||||
|
||||
def test_node_without_id_is_skipped(self):
|
||||
"""Test node without id is skipped."""
|
||||
workflow_data = {
|
||||
"nodes": [{"type": "llm", "title": "No ID"}],
|
||||
"edges": [],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
# Should only have flowchart header
|
||||
lines = [line for line in result.split("\n") if line.strip()]
|
||||
assert len(lines) == 1
|
||||
|
||||
def test_node_default_values(self):
|
||||
"""Test node with missing type/title uses defaults."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "node1"}],
|
||||
"edges": [],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert "type=unknown" in result
|
||||
assert "title=Untitled" in result
|
||||
81
api/tests/unit_tests/core/llm_generator/test_node_repair.py
Normal file
81
api/tests/unit_tests/core/llm_generator/test_node_repair.py
Normal file
@ -0,0 +1,81 @@
|
||||
from core.workflow.generator.utils.node_repair import NodeRepair
|
||||
|
||||
|
||||
class TestNodeRepair:
|
||||
"""Tests for NodeRepair utility."""
|
||||
|
||||
def test_repair_if_else_valid_operators(self):
|
||||
"""Test that valid operators remain unchanged."""
|
||||
nodes = [
|
||||
{
|
||||
"id": "node1",
|
||||
"type": "if-else",
|
||||
"config": {
|
||||
"cases": [
|
||||
{
|
||||
"conditions": [
|
||||
{"comparison_operator": "≥", "value": "1"},
|
||||
{"comparison_operator": "=", "value": "2"},
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
]
|
||||
result = NodeRepair.repair(nodes)
|
||||
assert result.was_repaired is False
|
||||
assert result.nodes == nodes
|
||||
|
||||
def test_repair_if_else_invalid_operators(self):
|
||||
"""Test that invalid operators are normalized."""
|
||||
nodes = [
|
||||
{
|
||||
"id": "node1",
|
||||
"type": "if-else",
|
||||
"config": {
|
||||
"cases": [
|
||||
{
|
||||
"conditions": [
|
||||
{"comparison_operator": ">=", "value": "1"},
|
||||
{"comparison_operator": "<=", "value": "2"},
|
||||
{"comparison_operator": "!=", "value": "3"},
|
||||
{"comparison_operator": "==", "value": "4"},
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
]
|
||||
result = NodeRepair.repair(nodes)
|
||||
assert result.was_repaired is True
|
||||
assert len(result.repairs_made) == 4
|
||||
|
||||
conditions = result.nodes[0]["config"]["cases"][0]["conditions"]
|
||||
assert conditions[0]["comparison_operator"] == "≥"
|
||||
assert conditions[1]["comparison_operator"] == "≤"
|
||||
assert conditions[2]["comparison_operator"] == "≠"
|
||||
assert conditions[3]["comparison_operator"] == "="
|
||||
|
||||
def test_repair_ignores_other_nodes(self):
|
||||
"""Test that other node types are ignored."""
|
||||
nodes = [{"id": "node1", "type": "llm", "config": {"some_field": ">="}}]
|
||||
result = NodeRepair.repair(nodes)
|
||||
assert result.was_repaired is False
|
||||
assert result.nodes[0]["config"]["some_field"] == ">="
|
||||
|
||||
def test_repair_handles_missing_config(self):
|
||||
"""Test robustness against missing fields."""
|
||||
nodes = [
|
||||
{
|
||||
"id": "node1",
|
||||
"type": "if-else",
|
||||
# Missing config
|
||||
},
|
||||
{
|
||||
"id": "node2",
|
||||
"type": "if-else",
|
||||
"config": {}, # Missing cases
|
||||
},
|
||||
]
|
||||
result = NodeRepair.repair(nodes)
|
||||
assert result.was_repaired is False
|
||||
@ -0,0 +1,99 @@
|
||||
"""
|
||||
Tests for node schemas validation.
|
||||
|
||||
Ensures that the node configuration stays in sync with registered node types.
|
||||
"""
|
||||
|
||||
from core.workflow.generator.config.node_schemas import (
|
||||
get_builtin_node_schemas,
|
||||
validate_node_schemas,
|
||||
)
|
||||
|
||||
|
||||
class TestNodeSchemasValidation:
|
||||
"""Tests for node schema validation utilities."""
|
||||
|
||||
def test_validate_node_schemas_returns_no_warnings(self):
|
||||
"""Ensure all registered node types have corresponding schemas."""
|
||||
warnings = validate_node_schemas()
|
||||
# If this test fails, it means a new node type was added but
|
||||
# no schema was defined for it in node_schemas.py
|
||||
assert len(warnings) == 0, (
|
||||
f"Missing schemas for node types: {warnings}. "
|
||||
"Please add schemas for these node types in node_schemas.py "
|
||||
"or add them to _INTERNAL_NODE_TYPES if they don't need schemas."
|
||||
)
|
||||
|
||||
def test_builtin_node_schemas_not_empty(self):
|
||||
"""Ensure BUILTIN_NODE_SCHEMAS contains expected node types."""
|
||||
# get_builtin_node_schemas() includes dynamic schemas
|
||||
all_schemas = get_builtin_node_schemas()
|
||||
assert len(all_schemas) > 0
|
||||
# Core node types should always be present
|
||||
expected_types = ["llm", "code", "http-request", "if-else"]
|
||||
for node_type in expected_types:
|
||||
assert node_type in all_schemas, f"Missing schema for core node type: {node_type}"
|
||||
|
||||
def test_schema_structure(self):
|
||||
"""Ensure each schema has required fields."""
|
||||
all_schemas = get_builtin_node_schemas()
|
||||
for node_type, schema in all_schemas.items():
|
||||
assert "description" in schema, f"Missing 'description' in schema for {node_type}"
|
||||
# 'parameters' is optional but if present should be a dict
|
||||
if "parameters" in schema:
|
||||
assert isinstance(schema["parameters"], dict), (
|
||||
f"'parameters' in schema for {node_type} should be a dict"
|
||||
)
|
||||
|
||||
|
||||
class TestNodeSchemasMerged:
|
||||
"""Tests to verify the merged configuration works correctly."""
|
||||
|
||||
def test_fallback_rules_available(self):
|
||||
"""Ensure FALLBACK_RULES is available from node_schemas."""
|
||||
from core.workflow.generator.config.node_schemas import FALLBACK_RULES
|
||||
|
||||
assert len(FALLBACK_RULES) > 0
|
||||
assert "http-request" in FALLBACK_RULES
|
||||
assert "code" in FALLBACK_RULES
|
||||
assert "llm" in FALLBACK_RULES
|
||||
|
||||
def test_node_type_aliases_available(self):
|
||||
"""Ensure NODE_TYPE_ALIASES is available from node_schemas."""
|
||||
from core.workflow.generator.config.node_schemas import NODE_TYPE_ALIASES
|
||||
|
||||
assert len(NODE_TYPE_ALIASES) > 0
|
||||
assert NODE_TYPE_ALIASES.get("gpt") == "llm"
|
||||
assert NODE_TYPE_ALIASES.get("api") == "http-request"
|
||||
|
||||
def test_field_name_corrections_available(self):
|
||||
"""Ensure FIELD_NAME_CORRECTIONS is available from node_schemas."""
|
||||
from core.workflow.generator.config.node_schemas import (
|
||||
FIELD_NAME_CORRECTIONS,
|
||||
get_corrected_field_name,
|
||||
)
|
||||
|
||||
assert len(FIELD_NAME_CORRECTIONS) > 0
|
||||
# Test the helper function
|
||||
assert get_corrected_field_name("http-request", "text") == "body"
|
||||
assert get_corrected_field_name("llm", "response") == "text"
|
||||
assert get_corrected_field_name("code", "unknown") == "unknown"
|
||||
|
||||
def test_config_init_exports(self):
|
||||
"""Ensure config __init__.py exports all needed symbols."""
|
||||
from core.workflow.generator.config import (
|
||||
BUILTIN_NODE_SCHEMAS,
|
||||
FALLBACK_RULES,
|
||||
FIELD_NAME_CORRECTIONS,
|
||||
NODE_TYPE_ALIASES,
|
||||
get_corrected_field_name,
|
||||
validate_node_schemas,
|
||||
)
|
||||
|
||||
# Just verify imports work
|
||||
assert BUILTIN_NODE_SCHEMAS is not None
|
||||
assert FALLBACK_RULES is not None
|
||||
assert FIELD_NAME_CORRECTIONS is not None
|
||||
assert NODE_TYPE_ALIASES is not None
|
||||
assert callable(get_corrected_field_name)
|
||||
assert callable(validate_node_schemas)
|
||||
172
api/tests/unit_tests/core/llm_generator/test_planner_prompts.py
Normal file
172
api/tests/unit_tests/core/llm_generator/test_planner_prompts.py
Normal file
@ -0,0 +1,172 @@
|
||||
"""
|
||||
Unit tests for the Planner Prompts.
|
||||
|
||||
Tests cover:
|
||||
- Tool formatting for planner context
|
||||
- Edge cases with missing fields
|
||||
- Empty tool lists
|
||||
"""
|
||||
|
||||
from core.workflow.generator.prompts.planner_prompts import format_tools_for_planner
|
||||
|
||||
|
||||
class TestFormatToolsForPlanner:
|
||||
"""Tests for format_tools_for_planner function."""
|
||||
|
||||
def test_empty_tools_returns_default_message(self):
|
||||
"""Test empty tools list returns default message."""
|
||||
result = format_tools_for_planner([])
|
||||
|
||||
assert result == "No external tools available."
|
||||
|
||||
def test_none_tools_returns_default_message(self):
|
||||
"""Test None tools list returns default message."""
|
||||
result = format_tools_for_planner(None)
|
||||
|
||||
assert result == "No external tools available."
|
||||
|
||||
def test_single_tool_formatting(self):
|
||||
"""Test single tool is formatted correctly."""
|
||||
tools = [
|
||||
{
|
||||
"provider_id": "google",
|
||||
"tool_key": "search",
|
||||
"tool_label": "Google Search",
|
||||
"tool_description": "Search the web using Google",
|
||||
}
|
||||
]
|
||||
result = format_tools_for_planner(tools)
|
||||
|
||||
assert "[google/search]" in result
|
||||
assert "Google Search" in result
|
||||
assert "Search the web using Google" in result
|
||||
|
||||
def test_multiple_tools_formatting(self):
|
||||
"""Test multiple tools are formatted correctly."""
|
||||
tools = [
|
||||
{
|
||||
"provider_id": "google",
|
||||
"tool_key": "search",
|
||||
"tool_label": "Search",
|
||||
"tool_description": "Web search",
|
||||
},
|
||||
{
|
||||
"provider_id": "slack",
|
||||
"tool_key": "send_message",
|
||||
"tool_label": "Send Message",
|
||||
"tool_description": "Send a Slack message",
|
||||
},
|
||||
]
|
||||
result = format_tools_for_planner(tools)
|
||||
|
||||
lines = result.strip().split("\n")
|
||||
assert len(lines) == 2
|
||||
assert "[google/search]" in result
|
||||
assert "[slack/send_message]" in result
|
||||
|
||||
def test_tool_without_provider_uses_key_only(self):
|
||||
"""Test tool without provider_id uses tool_key only."""
|
||||
tools = [
|
||||
{
|
||||
"tool_key": "my_tool",
|
||||
"tool_label": "My Tool",
|
||||
"tool_description": "A custom tool",
|
||||
}
|
||||
]
|
||||
result = format_tools_for_planner(tools)
|
||||
|
||||
# Should format as [my_tool] without provider prefix
|
||||
assert "[my_tool]" in result
|
||||
assert "My Tool" in result
|
||||
|
||||
def test_tool_with_tool_name_fallback(self):
|
||||
"""Test tool uses tool_name when tool_key is missing."""
|
||||
tools = [
|
||||
{
|
||||
"tool_name": "fallback_tool",
|
||||
"description": "Fallback description",
|
||||
}
|
||||
]
|
||||
result = format_tools_for_planner(tools)
|
||||
|
||||
assert "fallback_tool" in result
|
||||
assert "Fallback description" in result
|
||||
|
||||
def test_tool_with_missing_description(self):
|
||||
"""Test tool with missing description doesn't crash."""
|
||||
tools = [
|
||||
{
|
||||
"provider_id": "test",
|
||||
"tool_key": "tool1",
|
||||
"tool_label": "Tool 1",
|
||||
}
|
||||
]
|
||||
result = format_tools_for_planner(tools)
|
||||
|
||||
assert "[test/tool1]" in result
|
||||
assert "Tool 1" in result
|
||||
|
||||
def test_tool_with_all_missing_fields(self):
|
||||
"""Test tool with all fields missing uses defaults."""
|
||||
tools = [{}]
|
||||
result = format_tools_for_planner(tools)
|
||||
|
||||
# Should not crash, may produce minimal output
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_tool_uses_provider_fallback(self):
|
||||
"""Test tool uses 'provider' when 'provider_id' is missing."""
|
||||
tools = [
|
||||
{
|
||||
"provider": "openai",
|
||||
"tool_key": "dalle",
|
||||
"tool_label": "DALL-E",
|
||||
"tool_description": "Generate images",
|
||||
}
|
||||
]
|
||||
result = format_tools_for_planner(tools)
|
||||
|
||||
assert "[openai/dalle]" in result
|
||||
|
||||
def test_tool_label_fallback_to_key(self):
|
||||
"""Test tool_label falls back to tool_key when missing."""
|
||||
tools = [
|
||||
{
|
||||
"provider_id": "test",
|
||||
"tool_key": "my_key",
|
||||
"tool_description": "Description here",
|
||||
}
|
||||
]
|
||||
result = format_tools_for_planner(tools)
|
||||
|
||||
# Label should fallback to key
|
||||
assert "my_key" in result
|
||||
assert "Description here" in result
|
||||
|
||||
|
||||
class TestPlannerPromptConstants:
|
||||
"""Tests for planner prompt constant availability."""
|
||||
|
||||
def test_planner_system_prompt_exists(self):
|
||||
"""Test PLANNER_SYSTEM_PROMPT is defined."""
|
||||
from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT
|
||||
|
||||
assert PLANNER_SYSTEM_PROMPT is not None
|
||||
assert len(PLANNER_SYSTEM_PROMPT) > 0
|
||||
assert "{tools_summary}" in PLANNER_SYSTEM_PROMPT
|
||||
|
||||
def test_planner_user_prompt_exists(self):
|
||||
"""Test PLANNER_USER_PROMPT is defined."""
|
||||
from core.workflow.generator.prompts.planner_prompts import PLANNER_USER_PROMPT
|
||||
|
||||
assert PLANNER_USER_PROMPT is not None
|
||||
assert "{instruction}" in PLANNER_USER_PROMPT
|
||||
|
||||
def test_planner_system_prompt_has_required_sections(self):
|
||||
"""Test PLANNER_SYSTEM_PROMPT has required XML sections."""
|
||||
from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT
|
||||
|
||||
assert "<role>" in PLANNER_SYSTEM_PROMPT
|
||||
assert "<task>" in PLANNER_SYSTEM_PROMPT
|
||||
assert "<available_tools>" in PLANNER_SYSTEM_PROMPT
|
||||
assert "<response_format>" in PLANNER_SYSTEM_PROMPT
|
||||
@ -0,0 +1,510 @@
|
||||
"""
|
||||
Unit tests for the Validation Rule Engine.
|
||||
|
||||
Tests cover:
|
||||
- Structure rules (required fields, types, formats)
|
||||
- Semantic rules (variable references, edge connections)
|
||||
- Reference rules (model exists, tool configured, dataset valid)
|
||||
- ValidationEngine integration
|
||||
"""
|
||||
|
||||
from core.workflow.generator.validation import (
|
||||
ValidationContext,
|
||||
ValidationEngine,
|
||||
)
|
||||
from core.workflow.generator.validation.rules import (
|
||||
extract_variable_refs,
|
||||
is_placeholder,
|
||||
)
|
||||
|
||||
|
||||
class TestPlaceholderDetection:
|
||||
"""Tests for placeholder detection utility."""
|
||||
|
||||
def test_detects_please_select(self):
|
||||
assert is_placeholder("PLEASE_SELECT_YOUR_MODEL") is True
|
||||
|
||||
def test_detects_your_prefix(self):
|
||||
assert is_placeholder("YOUR_API_KEY") is True
|
||||
|
||||
def test_detects_todo(self):
|
||||
assert is_placeholder("TODO: fill this in") is True
|
||||
|
||||
def test_detects_placeholder(self):
|
||||
assert is_placeholder("PLACEHOLDER_VALUE") is True
|
||||
|
||||
def test_detects_example_prefix(self):
|
||||
assert is_placeholder("EXAMPLE_URL") is True
|
||||
|
||||
def test_detects_replace_prefix(self):
|
||||
assert is_placeholder("REPLACE_WITH_ACTUAL") is True
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert is_placeholder("please_select") is True
|
||||
assert is_placeholder("Please_Select") is True
|
||||
|
||||
def test_valid_values_not_detected(self):
|
||||
assert is_placeholder("https://api.example.com") is False
|
||||
assert is_placeholder("gpt-4") is False
|
||||
assert is_placeholder("my_variable") is False
|
||||
|
||||
def test_non_string_returns_false(self):
|
||||
assert is_placeholder(123) is False
|
||||
assert is_placeholder(None) is False
|
||||
assert is_placeholder(["list"]) is False
|
||||
|
||||
|
||||
class TestVariableRefExtraction:
|
||||
"""Tests for variable reference extraction."""
|
||||
|
||||
def test_extracts_simple_ref(self):
|
||||
refs = extract_variable_refs("Hello {{#start.query#}}")
|
||||
assert refs == [("start", "query")]
|
||||
|
||||
def test_extracts_multiple_refs(self):
|
||||
refs = extract_variable_refs("{{#node1.output#}} and {{#node2.text#}}")
|
||||
assert refs == [("node1", "output"), ("node2", "text")]
|
||||
|
||||
def test_extracts_nested_field(self):
|
||||
refs = extract_variable_refs("{{#http_request.body#}}")
|
||||
assert refs == [("http_request", "body")]
|
||||
|
||||
def test_no_refs_returns_empty(self):
|
||||
refs = extract_variable_refs("No references here")
|
||||
assert refs == []
|
||||
|
||||
def test_handles_malformed_refs(self):
|
||||
refs = extract_variable_refs("{{#invalid}} and {{incomplete#}}")
|
||||
assert refs == []
|
||||
|
||||
|
||||
class TestValidationContext:
|
||||
"""Tests for ValidationContext."""
|
||||
|
||||
def test_node_map_lookup(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start"},
|
||||
{"id": "llm_1", "type": "llm"},
|
||||
]
|
||||
)
|
||||
assert ctx.get_node("start") == {"id": "start", "type": "start"}
|
||||
assert ctx.get_node("nonexistent") is None
|
||||
|
||||
def test_model_set(self):
|
||||
ctx = ValidationContext(
|
||||
available_models=[
|
||||
{"provider": "openai", "model": "gpt-4"},
|
||||
{"provider": "anthropic", "model": "claude-3"},
|
||||
]
|
||||
)
|
||||
assert ctx.has_model("openai", "gpt-4") is True
|
||||
assert ctx.has_model("anthropic", "claude-3") is True
|
||||
assert ctx.has_model("openai", "gpt-3.5") is False
|
||||
|
||||
def test_tool_set(self):
|
||||
ctx = ValidationContext(
|
||||
available_tools=[
|
||||
{"provider_id": "google", "tool_key": "search", "is_team_authorization": True},
|
||||
{"provider_id": "slack", "tool_key": "send_message", "is_team_authorization": False},
|
||||
]
|
||||
)
|
||||
assert ctx.has_tool("google/search") is True
|
||||
assert ctx.has_tool("search") is True
|
||||
assert ctx.is_tool_configured("google/search") is True
|
||||
assert ctx.is_tool_configured("slack/send_message") is False
|
||||
|
||||
def test_upstream_downstream_nodes(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start"},
|
||||
{"id": "llm", "type": "llm"},
|
||||
{"id": "end", "type": "end"},
|
||||
],
|
||||
edges=[
|
||||
{"source": "start", "target": "llm"},
|
||||
{"source": "llm", "target": "end"},
|
||||
],
|
||||
)
|
||||
assert ctx.get_upstream_nodes("llm") == ["start"]
|
||||
assert ctx.get_downstream_nodes("llm") == ["end"]
|
||||
|
||||
|
||||
class TestStructureRules:
|
||||
"""Tests for structure validation rules."""
|
||||
|
||||
def test_llm_missing_prompt_template(self):
|
||||
ctx = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
assert result.has_errors
|
||||
errors = [e for e in result.all_errors if e.rule_id == "llm.prompt_template.required"]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is True
|
||||
|
||||
def test_llm_with_prompt_template_passes(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"prompt_template": [
|
||||
{"role": "system", "text": "You are helpful"},
|
||||
{"role": "user", "text": "Hello"},
|
||||
]
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
# No prompt_template errors
|
||||
errors = [e for e in result.all_errors if "prompt_template" in e.rule_id]
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_http_request_missing_url(self):
|
||||
ctx = ValidationContext(nodes=[{"id": "http_1", "type": "http-request", "config": {}}])
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "http.url" in e.rule_id]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is True
|
||||
|
||||
def test_http_request_placeholder_url(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "http_1",
|
||||
"type": "http-request",
|
||||
"config": {"url": "PLEASE_SELECT_YOUR_URL", "method": "GET"},
|
||||
}
|
||||
]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "placeholder" in e.rule_id]
|
||||
assert len(errors) == 1
|
||||
|
||||
def test_code_node_missing_fields(self):
|
||||
ctx = ValidationContext(nodes=[{"id": "code_1", "type": "code", "config": {}}])
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
error_rules = {e.rule_id for e in result.all_errors}
|
||||
assert "code.code.required" in error_rules
|
||||
assert "code.language.required" in error_rules
|
||||
|
||||
def test_knowledge_retrieval_missing_dataset(self):
|
||||
ctx = ValidationContext(nodes=[{"id": "kb_1", "type": "knowledge-retrieval", "config": {}}])
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "knowledge.dataset" in e.rule_id]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is False # User must configure
|
||||
|
||||
|
||||
class TestSemanticRules:
|
||||
"""Tests for semantic validation rules."""
|
||||
|
||||
def test_valid_variable_reference(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "Process: {{#start.query#}}"}]},
|
||||
},
|
||||
]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
# No variable reference errors
|
||||
errors = [e for e in result.all_errors if "variable.ref" in e.rule_id]
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_invalid_variable_reference(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "Process: {{#nonexistent.field#}}"}]},
|
||||
},
|
||||
]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "variable.ref" in e.rule_id]
|
||||
assert len(errors) == 1
|
||||
assert "nonexistent" in errors[0].message
|
||||
|
||||
def test_edge_validation(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{"id": "end", "type": "end", "config": {}},
|
||||
],
|
||||
edges=[
|
||||
{"source": "start", "target": "end"},
|
||||
{"source": "nonexistent", "target": "end"},
|
||||
],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "edge" in e.rule_id]
|
||||
assert len(errors) == 1
|
||||
assert "nonexistent" in errors[0].message
|
||||
|
||||
|
||||
class TestReferenceRules:
|
||||
"""Tests for reference validation rules (models, tools)."""
|
||||
|
||||
def test_llm_missing_model_with_available(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
|
||||
}
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if e.rule_id == "model.required"]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is True
|
||||
|
||||
def test_llm_missing_model_no_available(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
|
||||
}
|
||||
],
|
||||
available_models=[], # No models available
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if e.rule_id == "model.no_available"]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is False
|
||||
|
||||
def test_llm_with_valid_model(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"prompt_template": [{"role": "user", "text": "Hi"}],
|
||||
"model": {"provider": "openai", "name": "gpt-4"},
|
||||
},
|
||||
}
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "model" in e.rule_id]
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_llm_with_invalid_model(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"prompt_template": [{"role": "user", "text": "Hi"}],
|
||||
"model": {"provider": "openai", "name": "gpt-99"},
|
||||
},
|
||||
}
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if e.rule_id == "model.not_found"]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is True
|
||||
|
||||
def test_tool_node_not_found(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "tool_1",
|
||||
"type": "tool",
|
||||
"config": {"tool_key": "nonexistent/tool"},
|
||||
}
|
||||
],
|
||||
available_tools=[],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if e.rule_id == "tool.not_found"]
|
||||
assert len(errors) == 1
|
||||
|
||||
def test_tool_node_not_configured(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "tool_1",
|
||||
"type": "tool",
|
||||
"config": {"tool_key": "google/search"},
|
||||
}
|
||||
],
|
||||
available_tools=[{"provider_id": "google", "tool_key": "search", "is_team_authorization": False}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if e.rule_id == "tool.not_configured"]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is False
|
||||
|
||||
|
||||
class TestValidationResult:
|
||||
"""Tests for ValidationResult classification."""
|
||||
|
||||
def test_has_errors(self):
|
||||
ctx = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
assert result.has_errors is True
|
||||
assert result.is_valid is False
|
||||
|
||||
def test_has_fixable_errors(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
|
||||
}
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
assert result.has_fixable_errors is True
|
||||
assert len(result.fixable_errors) > 0
|
||||
|
||||
def test_get_fixable_by_node(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "llm_1", "type": "llm", "config": {}},
|
||||
{"id": "http_1", "type": "http-request", "config": {}},
|
||||
]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
by_node = result.get_fixable_by_node()
|
||||
assert "llm_1" in by_node
|
||||
assert "http_1" in by_node
|
||||
|
||||
def test_to_dict(self):
|
||||
ctx = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
d = result.to_dict()
|
||||
assert "fixable" in d
|
||||
assert "user_required" in d
|
||||
assert "warnings" in d
|
||||
assert "all_warnings" in d
|
||||
assert "stats" in d
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Integration tests for the full validation pipeline."""
|
||||
|
||||
def test_complete_workflow_validation(self):
|
||||
"""Test validation of a complete workflow."""
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "start",
|
||||
"type": "start",
|
||||
"config": {"variables": [{"variable": "query", "type": "text-input"}]},
|
||||
},
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"model": {"provider": "openai", "name": "gpt-4"},
|
||||
"prompt_template": [{"role": "user", "text": "{{#start.query#}}"}],
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "end",
|
||||
"type": "end",
|
||||
"config": {"outputs": [{"variable": "result", "value_selector": ["llm_1", "text"]}]},
|
||||
},
|
||||
],
|
||||
edges=[
|
||||
{"source": "start", "target": "llm_1"},
|
||||
{"source": "llm_1", "target": "end"},
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
# Should have no errors
|
||||
assert result.is_valid is True
|
||||
assert len(result.fixable_errors) == 0
|
||||
assert len(result.user_required_errors) == 0
|
||||
|
||||
def test_workflow_with_multiple_errors(self):
|
||||
"""Test workflow with multiple types of errors."""
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {}, # Missing prompt_template and model
|
||||
},
|
||||
{
|
||||
"id": "kb_1",
|
||||
"type": "knowledge-retrieval",
|
||||
"config": {"dataset_ids": ["PLEASE_SELECT_YOUR_DATASET"]},
|
||||
},
|
||||
{"id": "end", "type": "end", "config": {}},
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
# Should have multiple errors
|
||||
assert result.has_errors is True
|
||||
assert len(result.fixable_errors) >= 2 # model, prompt_template
|
||||
assert len(result.user_required_errors) >= 1 # dataset placeholder
|
||||
|
||||
# Check stats
|
||||
assert result.stats["total_nodes"] == 4
|
||||
assert result.stats["total_errors"] >= 3
|
||||
@ -0,0 +1,197 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.nodes.agent.agent_node import AgentNode
|
||||
|
||||
|
||||
class TestInferToolProviderType:
|
||||
"""Test cases for AgentNode._infer_tool_provider_type method."""
|
||||
|
||||
def test_infer_type_from_config_workflow(self):
|
||||
"""Test inferring workflow provider type from config."""
|
||||
tool_config = {
|
||||
"type": "workflow",
|
||||
"provider_name": "workflow-provider-id",
|
||||
}
|
||||
tenant_id = "test-tenant"
|
||||
|
||||
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||
|
||||
assert result == ToolProviderType.WORKFLOW
|
||||
|
||||
def test_infer_type_from_config_builtin(self):
|
||||
"""Test inferring builtin provider type from config."""
|
||||
tool_config = {
|
||||
"type": "builtin",
|
||||
"provider_name": "builtin-provider-id",
|
||||
}
|
||||
tenant_id = "test-tenant"
|
||||
|
||||
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||
|
||||
assert result == ToolProviderType.BUILT_IN
|
||||
|
||||
def test_infer_type_from_config_api(self):
|
||||
"""Test inferring API provider type from config."""
|
||||
tool_config = {
|
||||
"type": "api",
|
||||
"provider_name": "api-provider-id",
|
||||
}
|
||||
tenant_id = "test-tenant"
|
||||
|
||||
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||
|
||||
assert result == ToolProviderType.API
|
||||
|
||||
def test_infer_type_from_config_mcp(self):
|
||||
"""Test inferring MCP provider type from config."""
|
||||
tool_config = {
|
||||
"type": "mcp",
|
||||
"provider_name": "mcp-provider-id",
|
||||
}
|
||||
tenant_id = "test-tenant"
|
||||
|
||||
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||
|
||||
assert result == ToolProviderType.MCP
|
||||
|
||||
def test_infer_type_invalid_config_value_raises_error(self):
|
||||
"""Test that invalid type value in config raises ValueError."""
|
||||
tool_config = {
|
||||
"type": "invalid-type",
|
||||
"provider_name": "workflow-provider-id",
|
||||
}
|
||||
tenant_id = "test-tenant"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||
|
||||
def test_infer_workflow_type_from_database(self):
|
||||
"""Test inferring workflow provider type from database."""
|
||||
tool_config = {
|
||||
"provider_name": "workflow-provider-id",
|
||||
}
|
||||
tenant_id = "test-tenant"
|
||||
|
||||
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
|
||||
mock_session = MagicMock()
|
||||
mock_create_session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# First query (WorkflowToolProvider) returns a result
|
||||
mock_session.scalar.return_value = True
|
||||
|
||||
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||
|
||||
assert result == ToolProviderType.WORKFLOW
|
||||
# Should only query once (after finding WorkflowToolProvider)
|
||||
assert mock_session.scalar.call_count == 1
|
||||
|
||||
def test_infer_mcp_type_from_database(self):
|
||||
"""Test inferring MCP provider type from database."""
|
||||
tool_config = {
|
||||
"provider_name": "mcp-provider-id",
|
||||
}
|
||||
tenant_id = "test-tenant"
|
||||
|
||||
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
|
||||
mock_session = MagicMock()
|
||||
mock_create_session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# First query (WorkflowToolProvider) returns None
|
||||
# Second query (MCPToolProvider) returns a result
|
||||
mock_session.scalar.side_effect = [None, True]
|
||||
|
||||
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||
|
||||
assert result == ToolProviderType.MCP
|
||||
assert mock_session.scalar.call_count == 2
|
||||
|
||||
def test_infer_api_type_from_database(self):
|
||||
"""Test inferring API provider type from database."""
|
||||
tool_config = {
|
||||
"provider_name": "api-provider-id",
|
||||
}
|
||||
tenant_id = "test-tenant"
|
||||
|
||||
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
|
||||
mock_session = MagicMock()
|
||||
mock_create_session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# First query (WorkflowToolProvider) returns None
|
||||
# Second query (MCPToolProvider) returns None
|
||||
# Third query (ApiToolProvider) returns a result
|
||||
mock_session.scalar.side_effect = [None, None, True]
|
||||
|
||||
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||
|
||||
assert result == ToolProviderType.API
|
||||
assert mock_session.scalar.call_count == 3
|
||||
|
||||
def test_infer_builtin_type_from_database(self):
|
||||
"""Test inferring builtin provider type from database."""
|
||||
tool_config = {
|
||||
"provider_name": "builtin-provider-id",
|
||||
}
|
||||
tenant_id = "test-tenant"
|
||||
|
||||
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
|
||||
mock_session = MagicMock()
|
||||
mock_create_session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# First three queries return None
|
||||
# Fourth query (BuiltinToolProvider) returns a result
|
||||
mock_session.scalar.side_effect = [None, None, None, True]
|
||||
|
||||
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||
|
||||
assert result == ToolProviderType.BUILT_IN
|
||||
assert mock_session.scalar.call_count == 4
|
||||
|
||||
def test_infer_type_default_when_not_found(self):
|
||||
"""Test raising AgentNodeError when provider is not found in database."""
|
||||
tool_config = {
|
||||
"provider_name": "unknown-provider-id",
|
||||
}
|
||||
tenant_id = "test-tenant"
|
||||
|
||||
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
|
||||
mock_session = MagicMock()
|
||||
mock_create_session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# All queries return None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Current implementation raises AgentNodeError when provider not found
|
||||
from core.workflow.nodes.agent.exc import AgentNodeError
|
||||
|
||||
with pytest.raises(AgentNodeError, match="Tool provider with ID 'unknown-provider-id' not found"):
|
||||
AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||
|
||||
def test_infer_type_default_when_no_provider_name(self):
|
||||
"""Test defaulting to BUILT_IN when provider_name is missing."""
|
||||
tool_config = {}
|
||||
tenant_id = "test-tenant"
|
||||
|
||||
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||
|
||||
assert result == ToolProviderType.BUILT_IN
|
||||
|
||||
def test_infer_type_database_exception_propagates(self):
|
||||
"""Test that database exception propagates (current implementation doesn't catch it)."""
|
||||
tool_config = {
|
||||
"provider_name": "provider-id",
|
||||
}
|
||||
tenant_id = "test-tenant"
|
||||
|
||||
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
|
||||
mock_session = MagicMock()
|
||||
mock_create_session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Database query raises exception
|
||||
mock_session.scalar.side_effect = Exception("Database error")
|
||||
|
||||
# Current implementation doesn't catch exceptions, so it propagates
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||
Reference in New Issue
Block a user