resolve: conflict

crazywoola
2026-02-09 15:17:25 +08:00
parent f4d6383019
commit 481c707fab
78 changed files with 3470 additions and 971 deletions

View File

@@ -0,0 +1,400 @@
"""
Unit tests for GraphBuilder.
Tests the automatic graph construction from node lists with dependency declarations.
"""
import pytest
from core.workflow.generator.utils.graph_builder import (
CyclicDependencyError,
GraphBuilder,
)
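# A minimal sketch of the GraphBuilder contract these tests exercise. The shapes below
# are inferred from the assertions in this file, not from the implementation itself:
# build_graph() accepts node dicts with "id", "type", optional "config" and "depends_on",
# and returns (nodes, edges) with synthetic "start"/"end" nodes and edge dicts added.
def _example_build_graph_usage():
    nodes = [
        {"id": "fetch", "type": "http-request", "depends_on": []},
        {"id": "summarize", "type": "llm", "depends_on": ["fetch"]},
    ]
    result_nodes, result_edges = GraphBuilder.build_graph(nodes)
    # Expected flow, per the tests below: start -> fetch -> summarize -> end
    return [(e["source"], e["target"]) for e in result_edges]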
class TestGraphBuilderBasic:
"""Basic functionality tests."""
def test_empty_nodes_creates_minimal_workflow(self):
"""Empty node list creates start -> end workflow."""
result_nodes, result_edges = GraphBuilder.build_graph([])
assert len(result_nodes) == 2
assert result_nodes[0]["type"] == "start"
assert result_nodes[1]["type"] == "end"
assert len(result_edges) == 1
assert result_edges[0]["source"] == "start"
assert result_edges[0]["target"] == "end"
def test_simple_linear_workflow(self):
"""Simple linear workflow: start -> fetch -> process -> end."""
nodes = [
{"id": "fetch", "type": "http-request", "depends_on": []},
{"id": "process", "type": "llm", "depends_on": ["fetch"]},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should have: start + 2 user nodes + end = 4
assert len(result_nodes) == 4
assert result_nodes[0]["type"] == "start"
assert result_nodes[-1]["type"] == "end"
# Should have: start->fetch, fetch->process, process->end = 3
assert len(result_edges) == 3
# Verify edge connections
edge_pairs = [(e["source"], e["target"]) for e in result_edges]
assert ("start", "fetch") in edge_pairs
assert ("fetch", "process") in edge_pairs
assert ("process", "end") in edge_pairs
class TestParallelWorkflow:
"""Tests for parallel node handling."""
def test_parallel_workflow(self):
"""Parallel workflow: multiple nodes from start, merging to one."""
nodes = [
{"id": "api1", "type": "http-request", "depends_on": []},
{"id": "api2", "type": "http-request", "depends_on": []},
{"id": "merge", "type": "llm", "depends_on": ["api1", "api2"]},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# start should connect to both api1 and api2
start_edges = [e for e in result_edges if e["source"] == "start"]
assert len(start_edges) == 2
start_targets = {e["target"] for e in start_edges}
assert start_targets == {"api1", "api2"}
# Both api1 and api2 should connect to merge
merge_incoming = [e for e in result_edges if e["target"] == "merge"]
assert len(merge_incoming) == 2
def test_multiple_terminal_nodes(self):
"""Multiple terminal nodes all connect to end."""
nodes = [
{"id": "branch1", "type": "llm", "depends_on": []},
{"id": "branch2", "type": "llm", "depends_on": []},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Both branches should connect to end
end_incoming = [e for e in result_edges if e["target"] == "end"]
assert len(end_incoming) == 2
class TestIfElseWorkflow:
"""Tests for if-else branching."""
def test_if_else_workflow(self):
"""Conditional branching workflow."""
nodes = [
{
"id": "check",
"type": "if-else",
"config": {"true_branch": "success", "false_branch": "fallback"},
"depends_on": [],
},
{"id": "success", "type": "llm", "depends_on": []},
{"id": "fallback", "type": "code", "depends_on": []},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should have true and false branch edges
branch_edges = [e for e in result_edges if e["source"] == "check"]
assert len(branch_edges) == 2
assert any(e.get("sourceHandle") == "true" for e in branch_edges)
assert any(e.get("sourceHandle") == "false" for e in branch_edges)
# Verify targets
true_edge = next(e for e in branch_edges if e.get("sourceHandle") == "true")
false_edge = next(e for e in branch_edges if e.get("sourceHandle") == "false")
assert true_edge["target"] == "success"
assert false_edge["target"] == "fallback"
def test_if_else_missing_branch_no_error(self):
"""if-else with only true branch doesn't error (warning only)."""
nodes = [
{
"id": "check",
"type": "if-else",
"config": {"true_branch": "success"},
"depends_on": [],
},
{"id": "success", "type": "llm", "depends_on": []},
]
# Should not raise
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should have one branch edge
branch_edges = [e for e in result_edges if e["source"] == "check"]
assert len(branch_edges) == 1
assert branch_edges[0].get("sourceHandle") == "true"
class TestQuestionClassifierWorkflow:
"""Tests for question-classifier branching."""
def test_question_classifier_workflow(self):
"""Question classifier with multiple classes."""
nodes = [
{
"id": "classifier",
"type": "question-classifier",
"config": {
"query": ["start", "user_input"],
"classes": [
{"id": "tech", "name": "技术问题", "target": "tech_handler"},
{"id": "sales", "name": "销售咨询", "target": "sales_handler"},
{"id": "other", "name": "其他问题", "target": "other_handler"},
],
},
"depends_on": [],
},
{"id": "tech_handler", "type": "llm", "depends_on": []},
{"id": "sales_handler", "type": "llm", "depends_on": []},
{"id": "other_handler", "type": "llm", "depends_on": []},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should have 3 branch edges from classifier
classifier_edges = [e for e in result_edges if e["source"] == "classifier"]
assert len(classifier_edges) == 3
# Each should use class id as sourceHandle
assert any(e.get("sourceHandle") == "tech" and e["target"] == "tech_handler" for e in classifier_edges)
assert any(e.get("sourceHandle") == "sales" and e["target"] == "sales_handler" for e in classifier_edges)
assert any(e.get("sourceHandle") == "other" and e["target"] == "other_handler" for e in classifier_edges)
def test_question_classifier_missing_target(self):
"""Classes without target connect to end."""
nodes = [
{
"id": "classifier",
"type": "question-classifier",
"config": {
"classes": [
{"id": "known", "name": "已知问题", "target": "handler"},
{"id": "unknown", "name": "未知问题"}, # Missing target
],
},
"depends_on": [],
},
{"id": "handler", "type": "llm", "depends_on": []},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Missing target should connect to end
classifier_edges = [e for e in result_edges if e["source"] == "classifier"]
assert any(e.get("sourceHandle") == "unknown" and e["target"] == "end" for e in classifier_edges)
class TestVariableDependencyInference:
"""Tests for automatic dependency inference from variables."""
def test_variable_dependency_inference(self):
"""Dependencies inferred from variable references."""
nodes = [
{"id": "fetch", "type": "http-request", "depends_on": []},
{
"id": "process",
"type": "llm",
"config": {"prompt_template": [{"text": "{{#fetch.body#}}"}]},
# No explicit depends_on, but references fetch
},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should automatically infer process depends on fetch
assert any(e["source"] == "fetch" and e["target"] == "process" for e in result_edges)
def test_system_variable_not_inferred(self):
"""System variables (sys, start) not inferred as dependencies."""
nodes = [
{
"id": "process",
"type": "llm",
"config": {"prompt_template": [{"text": "{{#sys.query#}} {{#start.input#}}"}]},
"depends_on": [],
},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should connect to start, not create dependency on sys or start
edge_sources = {e["source"] for e in result_edges}
assert "sys" not in edge_sources
assert "start" in edge_sources
class TestCycleDetection:
"""Tests for cyclic dependency detection."""
def test_cyclic_dependency_detected(self):
"""Cyclic dependencies raise error."""
nodes = [
{"id": "a", "type": "llm", "depends_on": ["c"]},
{"id": "b", "type": "llm", "depends_on": ["a"]},
{"id": "c", "type": "llm", "depends_on": ["b"]},
]
with pytest.raises(CyclicDependencyError):
GraphBuilder.build_graph(nodes)
def test_self_dependency_detected(self):
"""Self-dependency raises error."""
nodes = [
{"id": "a", "type": "llm", "depends_on": ["a"]},
]
with pytest.raises(CyclicDependencyError):
GraphBuilder.build_graph(nodes)
class TestErrorRecovery:
"""Tests for silent error recovery."""
def test_invalid_dependency_removed(self):
"""Invalid dependencies (non-existent nodes) are silently removed."""
nodes = [
{"id": "process", "type": "llm", "depends_on": ["nonexistent"]},
]
# Should not raise, invalid dependency silently removed
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Process should connect from start (since invalid dep was removed)
assert any(e["source"] == "start" and e["target"] == "process" for e in result_edges)
def test_depends_on_as_string(self):
"""depends_on as string is converted to list."""
nodes = [
{"id": "fetch", "type": "http-request", "depends_on": []},
{"id": "process", "type": "llm", "depends_on": "fetch"}, # String instead of list
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should work correctly
assert any(e["source"] == "fetch" and e["target"] == "process" for e in result_edges)
class TestContainerNodes:
"""Tests for container nodes (iteration, loop)."""
def test_iteration_node_as_regular_node(self):
"""Iteration nodes behave as regular single-in-single-out nodes."""
nodes = [
{"id": "prepare", "type": "code", "depends_on": []},
{
"id": "loop",
"type": "iteration",
"config": {"iterator_selector": ["prepare", "items"]},
"depends_on": ["prepare"],
},
{"id": "process_result", "type": "llm", "depends_on": ["loop"]},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should have standard edges: start->prepare, prepare->loop, loop->process_result, process_result->end
edge_pairs = [(e["source"], e["target"]) for e in result_edges]
assert ("start", "prepare") in edge_pairs
assert ("prepare", "loop") in edge_pairs
assert ("loop", "process_result") in edge_pairs
assert ("process_result", "end") in edge_pairs
def test_loop_node_as_regular_node(self):
"""Loop nodes behave as regular single-in-single-out nodes."""
nodes = [
{"id": "init", "type": "code", "depends_on": []},
{
"id": "repeat",
"type": "loop",
"config": {"loop_count": 5},
"depends_on": ["init"],
},
{"id": "finish", "type": "llm", "depends_on": ["repeat"]},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Standard edge flow
edge_pairs = [(e["source"], e["target"]) for e in result_edges]
assert ("init", "repeat") in edge_pairs
assert ("repeat", "finish") in edge_pairs
def test_iteration_with_variable_inference(self):
"""Iteration node dependencies can be inferred from iterator_selector."""
nodes = [
{"id": "data_source", "type": "http-request", "depends_on": []},
{
"id": "process_each",
"type": "iteration",
"config": {
"iterator_selector": ["data_source", "items"],
},
# No explicit depends_on, but references data_source
},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should infer dependency from iterator_selector reference
# Note: iterator_selector format is different from {{#...#}}, so this tests
# that explicit depends_on is properly handled when not provided
# In this case, process_each has no depends_on, so it connects to start
edge_pairs = [(e["source"], e["target"]) for e in result_edges]
# Without explicit depends_on, connects to start
assert ("start", "process_each") in edge_pairs or ("data_source", "process_each") in edge_pairs
def test_loop_node_self_reference_not_cycle(self):
"""Loop nodes referencing their own outputs should not create cycle."""
nodes = [
{"id": "init", "type": "code", "depends_on": []},
{
"id": "my_loop",
"type": "loop",
"config": {
"loop_count": 5,
# Loop node referencing its own output (common pattern)
"prompt": "Previous: {{#my_loop.output#}}, continue...",
},
"depends_on": ["init"],
},
{"id": "finish", "type": "llm", "depends_on": ["my_loop"]},
]
# Should NOT raise CyclicDependencyError
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Verify the graph is built correctly
assert len(result_nodes) == 5 # start + 3 + end
edge_pairs = [(e["source"], e["target"]) for e in result_edges]
assert ("init", "my_loop") in edge_pairs
assert ("my_loop", "finish") in edge_pairs
class TestEdgeStructure:
"""Tests for edge structure correctness."""
def test_edge_has_required_fields(self):
"""Edges have all required fields."""
nodes = [
{"id": "node1", "type": "llm", "depends_on": []},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
for edge in result_edges:
assert "id" in edge
assert "source" in edge
assert "target" in edge
assert "sourceHandle" in edge
assert "targetHandle" in edge
def test_edge_id_unique(self):
"""Each edge has a unique ID."""
nodes = [
{"id": "a", "type": "llm", "depends_on": []},
{"id": "b", "type": "llm", "depends_on": []},
{"id": "c", "type": "llm", "depends_on": ["a", "b"]},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
edge_ids = [e["id"] for e in result_edges]
assert len(edge_ids) == len(set(edge_ids)) # All unique

View File

@@ -0,0 +1,287 @@
"""
Unit tests for the Mermaid Generator.
Tests cover:
- Basic workflow rendering
- Reserved word handling ('end' → 'end_node')
- Question classifier multi-branch edges
- If-else branch labels
- Edge validation and skipping
- Tool node formatting
"""
from core.workflow.generator.utils.mermaid_generator import generate_mermaid
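# Hedged sketch of the output shape generate_mermaid() is expected to produce for a
# trivial workflow, as implied by the assertions in this file. The exact line order and
# indentation are assumptions; the tests only check for substrings.
def _example_mermaid_output():
    workflow_data = {
        "nodes": [
            {"id": "start", "type": "start", "title": "Start"},
            {"id": "end", "type": "end", "title": "End"},
        ],
        "edges": [{"source": "start", "target": "end"}],
    }
    # Expected to contain lines such as:
    #   flowchart TD
    #   start["type=start|title=Start"]
    #   end_node["type=end|title=End"]
    #   start --> end_node
    return generate_mermaid(workflow_data)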
class TestBasicWorkflow:
"""Tests for basic workflow Mermaid generation."""
def test_simple_start_end_workflow(self):
"""Test simple Start → End workflow."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "title": "Start"},
{"id": "end", "type": "end", "title": "End"},
],
"edges": [{"source": "start", "target": "end"}],
}
result = generate_mermaid(workflow_data)
assert "flowchart TD" in result
assert 'start["type=start|title=Start"]' in result
assert 'end_node["type=end|title=End"]' in result
assert "start --> end_node" in result
def test_start_llm_end_workflow(self):
"""Test Start → LLM → End workflow."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "title": "Start"},
{"id": "llm", "type": "llm", "title": "Generate"},
{"id": "end", "type": "end", "title": "End"},
],
"edges": [
{"source": "start", "target": "llm"},
{"source": "llm", "target": "end"},
],
}
result = generate_mermaid(workflow_data)
assert 'llm["type=llm|title=Generate"]' in result
assert "start --> llm" in result
assert "llm --> end_node" in result
def test_empty_workflow(self):
"""Test empty workflow returns minimal output."""
workflow_data = {"nodes": [], "edges": []}
result = generate_mermaid(workflow_data)
assert result == "flowchart TD"
def test_missing_keys_handled(self):
"""Test workflow with missing keys doesn't crash."""
workflow_data = {}
result = generate_mermaid(workflow_data)
assert "flowchart TD" in result
class TestReservedWords:
"""Tests for reserved word handling in node IDs."""
def test_end_node_id_is_replaced(self):
"""Test 'end' node ID is replaced with 'end_node'."""
workflow_data = {
"nodes": [{"id": "end", "type": "end", "title": "End"}],
"edges": [],
}
result = generate_mermaid(workflow_data)
# Should use end_node instead of end
assert "end_node[" in result
assert '"type=end|title=End"' in result
def test_subgraph_node_id_is_replaced(self):
"""Test 'subgraph' node ID is replaced with 'subgraph_node'."""
workflow_data = {
"nodes": [{"id": "subgraph", "type": "code", "title": "Process"}],
"edges": [],
}
result = generate_mermaid(workflow_data)
assert "subgraph_node[" in result
def test_edge_uses_safe_ids(self):
"""Test edges correctly reference safe IDs after replacement."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "title": "Start"},
{"id": "end", "type": "end", "title": "End"},
],
"edges": [{"source": "start", "target": "end"}],
}
result = generate_mermaid(workflow_data)
# Edge should use end_node, not end
assert "start --> end_node" in result
assert "start --> end\n" not in result
class TestBranchEdges:
"""Tests for branching node edge labels."""
def test_question_classifier_source_handles(self):
"""Test question-classifier edges with sourceHandle labels."""
workflow_data = {
"nodes": [
{"id": "classifier", "type": "question-classifier", "title": "Classify"},
{"id": "refund", "type": "llm", "title": "Handle Refund"},
{"id": "inquiry", "type": "llm", "title": "Handle Inquiry"},
],
"edges": [
{"source": "classifier", "target": "refund", "sourceHandle": "refund"},
{"source": "classifier", "target": "inquiry", "sourceHandle": "inquiry"},
],
}
result = generate_mermaid(workflow_data)
assert "classifier -->|refund| refund" in result
assert "classifier -->|inquiry| inquiry" in result
def test_if_else_true_false_handles(self):
"""Test if-else edges with true/false labels."""
workflow_data = {
"nodes": [
{"id": "ifelse", "type": "if-else", "title": "Check"},
{"id": "yes_branch", "type": "llm", "title": "Yes"},
{"id": "no_branch", "type": "llm", "title": "No"},
],
"edges": [
{"source": "ifelse", "target": "yes_branch", "sourceHandle": "true"},
{"source": "ifelse", "target": "no_branch", "sourceHandle": "false"},
],
}
result = generate_mermaid(workflow_data)
assert "ifelse -->|true| yes_branch" in result
assert "ifelse -->|false| no_branch" in result
def test_source_handle_source_is_ignored(self):
"""Test sourceHandle='source' doesn't add label."""
workflow_data = {
"nodes": [
{"id": "llm1", "type": "llm", "title": "LLM 1"},
{"id": "llm2", "type": "llm", "title": "LLM 2"},
],
"edges": [{"source": "llm1", "target": "llm2", "sourceHandle": "source"}],
}
result = generate_mermaid(workflow_data)
# Should be plain arrow without label
assert "llm1 --> llm2" in result
assert "llm1 -->|source|" not in result
class TestEdgeValidation:
"""Tests for edge validation and error handling."""
def test_edge_with_missing_source_is_skipped(self):
"""Test edge with non-existent source node is skipped."""
workflow_data = {
"nodes": [{"id": "end", "type": "end", "title": "End"}],
"edges": [{"source": "nonexistent", "target": "end"}],
}
result = generate_mermaid(workflow_data)
# Should not contain the invalid edge
assert "nonexistent" not in result
assert "-->" not in result or "nonexistent" not in result
def test_edge_with_missing_target_is_skipped(self):
"""Test edge with non-existent target node is skipped."""
workflow_data = {
"nodes": [{"id": "start", "type": "start", "title": "Start"}],
"edges": [{"source": "start", "target": "nonexistent"}],
}
result = generate_mermaid(workflow_data)
# Edge should be skipped
assert "start --> nonexistent" not in result
def test_edge_without_source_or_target_is_skipped(self):
"""Test edge missing source or target is skipped."""
workflow_data = {
"nodes": [{"id": "start", "type": "start", "title": "Start"}],
"edges": [{"source": "start"}, {"target": "start"}, {}],
}
result = generate_mermaid(workflow_data)
# No edges should be rendered
assert result.count("-->") == 0
class TestToolNodes:
"""Tests for tool node formatting."""
def test_tool_node_includes_tool_key(self):
"""Test tool node includes tool_key in label."""
workflow_data = {
"nodes": [
{
"id": "search",
"type": "tool",
"title": "Search",
"config": {"tool_key": "google/search"},
}
],
"edges": [],
}
result = generate_mermaid(workflow_data)
assert 'search["type=tool|title=Search|tool=google/search"]' in result
def test_tool_node_with_tool_name_fallback(self):
"""Test tool node uses tool_name as fallback."""
workflow_data = {
"nodes": [
{
"id": "tool1",
"type": "tool",
"title": "My Tool",
"config": {"tool_name": "my_tool"},
}
],
"edges": [],
}
result = generate_mermaid(workflow_data)
assert "tool=my_tool" in result
def test_tool_node_missing_tool_key_shows_unknown(self):
"""Test tool node without tool_key shows 'unknown'."""
workflow_data = {
"nodes": [{"id": "tool1", "type": "tool", "title": "Tool", "config": {}}],
"edges": [],
}
result = generate_mermaid(workflow_data)
assert "tool=unknown" in result
class TestNodeFormatting:
"""Tests for node label formatting."""
def test_quotes_in_title_are_escaped(self):
"""Test double quotes in title are replaced with single quotes."""
workflow_data = {
"nodes": [{"id": "llm", "type": "llm", "title": 'Say "Hello"'}],
"edges": [],
}
result = generate_mermaid(workflow_data)
# Double quotes should be replaced
assert "Say 'Hello'" in result
assert 'Say "Hello"' not in result
def test_node_without_id_is_skipped(self):
"""Test node without id is skipped."""
workflow_data = {
"nodes": [{"type": "llm", "title": "No ID"}],
"edges": [],
}
result = generate_mermaid(workflow_data)
# Should only have flowchart header
lines = [line for line in result.split("\n") if line.strip()]
assert len(lines) == 1
def test_node_default_values(self):
"""Test node with missing type/title uses defaults."""
workflow_data = {
"nodes": [{"id": "node1"}],
"edges": [],
}
result = generate_mermaid(workflow_data)
assert "type=unknown" in result
assert "title=Untitled" in result

View File

@@ -0,0 +1,81 @@
from core.workflow.generator.utils.node_repair import NodeRepair
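# Hedged sketch: the operator normalization these tests imply. This mapping is
# reconstructed from the assertions below (ASCII operators rewritten to the canonical
# forms used by if-else conditions); the authoritative table lives in NodeRepair.
_ASSUMED_OPERATOR_MAP = {
    ">=": "≥",
    "<=": "≤",
    "!=": "≠",
    "==": "=",
}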
class TestNodeRepair:
"""Tests for NodeRepair utility."""
def test_repair_if_else_valid_operators(self):
"""Test that valid operators remain unchanged."""
nodes = [
{
"id": "node1",
"type": "if-else",
"config": {
"cases": [
{
"conditions": [
{"comparison_operator": "", "value": "1"},
{"comparison_operator": "=", "value": "2"},
]
}
]
},
}
]
result = NodeRepair.repair(nodes)
assert result.was_repaired is False
assert result.nodes == nodes
def test_repair_if_else_invalid_operators(self):
"""Test that invalid operators are normalized."""
nodes = [
{
"id": "node1",
"type": "if-else",
"config": {
"cases": [
{
"conditions": [
{"comparison_operator": ">=", "value": "1"},
{"comparison_operator": "<=", "value": "2"},
{"comparison_operator": "!=", "value": "3"},
{"comparison_operator": "==", "value": "4"},
]
}
]
},
}
]
result = NodeRepair.repair(nodes)
assert result.was_repaired is True
assert len(result.repairs_made) == 4
conditions = result.nodes[0]["config"]["cases"][0]["conditions"]
assert conditions[0]["comparison_operator"] == "≥"
assert conditions[1]["comparison_operator"] == "≤"
assert conditions[2]["comparison_operator"] == "≠"
assert conditions[3]["comparison_operator"] == "="
def test_repair_ignores_other_nodes(self):
"""Test that other node types are ignored."""
nodes = [{"id": "node1", "type": "llm", "config": {"some_field": ">="}}]
result = NodeRepair.repair(nodes)
assert result.was_repaired is False
assert result.nodes[0]["config"]["some_field"] == ">="
def test_repair_handles_missing_config(self):
"""Test robustness against missing fields."""
nodes = [
{
"id": "node1",
"type": "if-else",
# Missing config
},
{
"id": "node2",
"type": "if-else",
"config": {}, # Missing cases
},
]
result = NodeRepair.repair(nodes)
assert result.was_repaired is False

View File

@@ -0,0 +1,99 @@
"""
Tests for node schemas validation.
Ensures that the node configuration stays in sync with registered node types.
"""
from core.workflow.generator.config.node_schemas import (
get_builtin_node_schemas,
validate_node_schemas,
)
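# Hedged sketch of the schema shape these tests assume: get_builtin_node_schemas() maps
# each node type to a dict with a required "description" and an optional "parameters"
# dict. The entry below is illustrative only; the real definitions live in node_schemas.py.
_EXAMPLE_SCHEMA_SHAPE = {
    "llm": {
        "description": "Call a language model with a prompt template.",
        "parameters": {"prompt_template": "list of role/text messages"},
    },
}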
class TestNodeSchemasValidation:
"""Tests for node schema validation utilities."""
def test_validate_node_schemas_returns_no_warnings(self):
"""Ensure all registered node types have corresponding schemas."""
warnings = validate_node_schemas()
# If this test fails, it means a new node type was added but
# no schema was defined for it in node_schemas.py
assert len(warnings) == 0, (
f"Missing schemas for node types: {warnings}. "
"Please add schemas for these node types in node_schemas.py "
"or add them to _INTERNAL_NODE_TYPES if they don't need schemas."
)
def test_builtin_node_schemas_not_empty(self):
"""Ensure BUILTIN_NODE_SCHEMAS contains expected node types."""
# get_builtin_node_schemas() includes dynamic schemas
all_schemas = get_builtin_node_schemas()
assert len(all_schemas) > 0
# Core node types should always be present
expected_types = ["llm", "code", "http-request", "if-else"]
for node_type in expected_types:
assert node_type in all_schemas, f"Missing schema for core node type: {node_type}"
def test_schema_structure(self):
"""Ensure each schema has required fields."""
all_schemas = get_builtin_node_schemas()
for node_type, schema in all_schemas.items():
assert "description" in schema, f"Missing 'description' in schema for {node_type}"
# 'parameters' is optional but if present should be a dict
if "parameters" in schema:
assert isinstance(schema["parameters"], dict), (
f"'parameters' in schema for {node_type} should be a dict"
)
class TestNodeSchemasMerged:
"""Tests to verify the merged configuration works correctly."""
def test_fallback_rules_available(self):
"""Ensure FALLBACK_RULES is available from node_schemas."""
from core.workflow.generator.config.node_schemas import FALLBACK_RULES
assert len(FALLBACK_RULES) > 0
assert "http-request" in FALLBACK_RULES
assert "code" in FALLBACK_RULES
assert "llm" in FALLBACK_RULES
def test_node_type_aliases_available(self):
"""Ensure NODE_TYPE_ALIASES is available from node_schemas."""
from core.workflow.generator.config.node_schemas import NODE_TYPE_ALIASES
assert len(NODE_TYPE_ALIASES) > 0
assert NODE_TYPE_ALIASES.get("gpt") == "llm"
assert NODE_TYPE_ALIASES.get("api") == "http-request"
def test_field_name_corrections_available(self):
"""Ensure FIELD_NAME_CORRECTIONS is available from node_schemas."""
from core.workflow.generator.config.node_schemas import (
FIELD_NAME_CORRECTIONS,
get_corrected_field_name,
)
assert len(FIELD_NAME_CORRECTIONS) > 0
# Test the helper function
assert get_corrected_field_name("http-request", "text") == "body"
assert get_corrected_field_name("llm", "response") == "text"
assert get_corrected_field_name("code", "unknown") == "unknown"
def test_config_init_exports(self):
"""Ensure config __init__.py exports all needed symbols."""
from core.workflow.generator.config import (
BUILTIN_NODE_SCHEMAS,
FALLBACK_RULES,
FIELD_NAME_CORRECTIONS,
NODE_TYPE_ALIASES,
get_corrected_field_name,
validate_node_schemas,
)
# Just verify imports work
assert BUILTIN_NODE_SCHEMAS is not None
assert FALLBACK_RULES is not None
assert FIELD_NAME_CORRECTIONS is not None
assert NODE_TYPE_ALIASES is not None
assert callable(get_corrected_field_name)
assert callable(validate_node_schemas)

View File

@@ -0,0 +1,172 @@
"""
Unit tests for the Planner Prompts.
Tests cover:
- Tool formatting for planner context
- Edge cases with missing fields
- Empty tool lists
"""
from core.workflow.generator.prompts.planner_prompts import format_tools_for_planner
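# Hedged sketch of what format_tools_for_planner() is expected to emit for one tool.
# The tests below only assert substrings ("[provider/key]", label, description), so the
# exact separators and ordering suggested in the comment are assumptions.
def _example_planner_tool_line():
    tools = [
        {
            "provider_id": "google",
            "tool_key": "search",
            "tool_label": "Google Search",
            "tool_description": "Search the web using Google",
        }
    ]
    # Expected to contain "[google/search]", "Google Search" and the description, e.g.
    # a line roughly like: "- [google/search] Google Search: Search the web using Google"
    return format_tools_for_planner(tools)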
class TestFormatToolsForPlanner:
"""Tests for format_tools_for_planner function."""
def test_empty_tools_returns_default_message(self):
"""Test empty tools list returns default message."""
result = format_tools_for_planner([])
assert result == "No external tools available."
def test_none_tools_returns_default_message(self):
"""Test None tools list returns default message."""
result = format_tools_for_planner(None)
assert result == "No external tools available."
def test_single_tool_formatting(self):
"""Test single tool is formatted correctly."""
tools = [
{
"provider_id": "google",
"tool_key": "search",
"tool_label": "Google Search",
"tool_description": "Search the web using Google",
}
]
result = format_tools_for_planner(tools)
assert "[google/search]" in result
assert "Google Search" in result
assert "Search the web using Google" in result
def test_multiple_tools_formatting(self):
"""Test multiple tools are formatted correctly."""
tools = [
{
"provider_id": "google",
"tool_key": "search",
"tool_label": "Search",
"tool_description": "Web search",
},
{
"provider_id": "slack",
"tool_key": "send_message",
"tool_label": "Send Message",
"tool_description": "Send a Slack message",
},
]
result = format_tools_for_planner(tools)
lines = result.strip().split("\n")
assert len(lines) == 2
assert "[google/search]" in result
assert "[slack/send_message]" in result
def test_tool_without_provider_uses_key_only(self):
"""Test tool without provider_id uses tool_key only."""
tools = [
{
"tool_key": "my_tool",
"tool_label": "My Tool",
"tool_description": "A custom tool",
}
]
result = format_tools_for_planner(tools)
# Should format as [my_tool] without provider prefix
assert "[my_tool]" in result
assert "My Tool" in result
def test_tool_with_tool_name_fallback(self):
"""Test tool uses tool_name when tool_key is missing."""
tools = [
{
"tool_name": "fallback_tool",
"description": "Fallback description",
}
]
result = format_tools_for_planner(tools)
assert "fallback_tool" in result
assert "Fallback description" in result
def test_tool_with_missing_description(self):
"""Test tool with missing description doesn't crash."""
tools = [
{
"provider_id": "test",
"tool_key": "tool1",
"tool_label": "Tool 1",
}
]
result = format_tools_for_planner(tools)
assert "[test/tool1]" in result
assert "Tool 1" in result
def test_tool_with_all_missing_fields(self):
"""Test tool with all fields missing uses defaults."""
tools = [{}]
result = format_tools_for_planner(tools)
# Should not crash, may produce minimal output
assert isinstance(result, str)
def test_tool_uses_provider_fallback(self):
"""Test tool uses 'provider' when 'provider_id' is missing."""
tools = [
{
"provider": "openai",
"tool_key": "dalle",
"tool_label": "DALL-E",
"tool_description": "Generate images",
}
]
result = format_tools_for_planner(tools)
assert "[openai/dalle]" in result
def test_tool_label_fallback_to_key(self):
"""Test tool_label falls back to tool_key when missing."""
tools = [
{
"provider_id": "test",
"tool_key": "my_key",
"tool_description": "Description here",
}
]
result = format_tools_for_planner(tools)
# Label should fallback to key
assert "my_key" in result
assert "Description here" in result
class TestPlannerPromptConstants:
"""Tests for planner prompt constant availability."""
def test_planner_system_prompt_exists(self):
"""Test PLANNER_SYSTEM_PROMPT is defined."""
from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT
assert PLANNER_SYSTEM_PROMPT is not None
assert len(PLANNER_SYSTEM_PROMPT) > 0
assert "{tools_summary}" in PLANNER_SYSTEM_PROMPT
def test_planner_user_prompt_exists(self):
"""Test PLANNER_USER_PROMPT is defined."""
from core.workflow.generator.prompts.planner_prompts import PLANNER_USER_PROMPT
assert PLANNER_USER_PROMPT is not None
assert "{instruction}" in PLANNER_USER_PROMPT
def test_planner_system_prompt_has_required_sections(self):
"""Test PLANNER_SYSTEM_PROMPT has required XML sections."""
from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT
assert "<role>" in PLANNER_SYSTEM_PROMPT
assert "<task>" in PLANNER_SYSTEM_PROMPT
assert "<available_tools>" in PLANNER_SYSTEM_PROMPT
assert "<response_format>" in PLANNER_SYSTEM_PROMPT

View File

@@ -0,0 +1,510 @@
"""
Unit tests for the Validation Rule Engine.
Tests cover:
- Structure rules (required fields, types, formats)
- Semantic rules (variable references, edge connections)
- Reference rules (model exists, tool configured, dataset valid)
- ValidationEngine integration
"""
from core.workflow.generator.validation import (
ValidationContext,
ValidationEngine,
)
from core.workflow.generator.validation.rules import (
extract_variable_refs,
is_placeholder,
)
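# Hedged sketch of the minimal ValidationEngine flow exercised by these tests. The
# attribute names (all_errors, rule_id, is_fixable) are taken from the assertions below;
# the node payload is illustrative.
def _example_validation_run():
    ctx = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
    result = ValidationEngine().validate(ctx)
    # A bare llm node is expected to surface fixable errors (missing prompt_template/model).
    return [(e.rule_id, e.is_fixable) for e in result.all_errors]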
class TestPlaceholderDetection:
"""Tests for placeholder detection utility."""
def test_detects_please_select(self):
assert is_placeholder("PLEASE_SELECT_YOUR_MODEL") is True
def test_detects_your_prefix(self):
assert is_placeholder("YOUR_API_KEY") is True
def test_detects_todo(self):
assert is_placeholder("TODO: fill this in") is True
def test_detects_placeholder(self):
assert is_placeholder("PLACEHOLDER_VALUE") is True
def test_detects_example_prefix(self):
assert is_placeholder("EXAMPLE_URL") is True
def test_detects_replace_prefix(self):
assert is_placeholder("REPLACE_WITH_ACTUAL") is True
def test_case_insensitive(self):
assert is_placeholder("please_select") is True
assert is_placeholder("Please_Select") is True
def test_valid_values_not_detected(self):
assert is_placeholder("https://api.example.com") is False
assert is_placeholder("gpt-4") is False
assert is_placeholder("my_variable") is False
def test_non_string_returns_false(self):
assert is_placeholder(123) is False
assert is_placeholder(None) is False
assert is_placeholder(["list"]) is False
class TestVariableRefExtraction:
"""Tests for variable reference extraction."""
def test_extracts_simple_ref(self):
refs = extract_variable_refs("Hello {{#start.query#}}")
assert refs == [("start", "query")]
def test_extracts_multiple_refs(self):
refs = extract_variable_refs("{{#node1.output#}} and {{#node2.text#}}")
assert refs == [("node1", "output"), ("node2", "text")]
def test_extracts_nested_field(self):
refs = extract_variable_refs("{{#http_request.body#}}")
assert refs == [("http_request", "body")]
def test_no_refs_returns_empty(self):
refs = extract_variable_refs("No references here")
assert refs == []
def test_handles_malformed_refs(self):
refs = extract_variable_refs("{{#invalid}} and {{incomplete#}}")
assert refs == []
class TestValidationContext:
"""Tests for ValidationContext."""
def test_node_map_lookup(self):
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start"},
{"id": "llm_1", "type": "llm"},
]
)
assert ctx.get_node("start") == {"id": "start", "type": "start"}
assert ctx.get_node("nonexistent") is None
def test_model_set(self):
ctx = ValidationContext(
available_models=[
{"provider": "openai", "model": "gpt-4"},
{"provider": "anthropic", "model": "claude-3"},
]
)
assert ctx.has_model("openai", "gpt-4") is True
assert ctx.has_model("anthropic", "claude-3") is True
assert ctx.has_model("openai", "gpt-3.5") is False
def test_tool_set(self):
ctx = ValidationContext(
available_tools=[
{"provider_id": "google", "tool_key": "search", "is_team_authorization": True},
{"provider_id": "slack", "tool_key": "send_message", "is_team_authorization": False},
]
)
assert ctx.has_tool("google/search") is True
assert ctx.has_tool("search") is True
assert ctx.is_tool_configured("google/search") is True
assert ctx.is_tool_configured("slack/send_message") is False
def test_upstream_downstream_nodes(self):
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start"},
{"id": "llm", "type": "llm"},
{"id": "end", "type": "end"},
],
edges=[
{"source": "start", "target": "llm"},
{"source": "llm", "target": "end"},
],
)
assert ctx.get_upstream_nodes("llm") == ["start"]
assert ctx.get_downstream_nodes("llm") == ["end"]
class TestStructureRules:
"""Tests for structure validation rules."""
def test_llm_missing_prompt_template(self):
ctx = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
engine = ValidationEngine()
result = engine.validate(ctx)
assert result.has_errors
errors = [e for e in result.all_errors if e.rule_id == "llm.prompt_template.required"]
assert len(errors) == 1
assert errors[0].is_fixable is True
def test_llm_with_prompt_template_passes(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {
"prompt_template": [
{"role": "system", "text": "You are helpful"},
{"role": "user", "text": "Hello"},
]
},
}
]
)
engine = ValidationEngine()
result = engine.validate(ctx)
# No prompt_template errors
errors = [e for e in result.all_errors if "prompt_template" in e.rule_id]
assert len(errors) == 0
def test_http_request_missing_url(self):
ctx = ValidationContext(nodes=[{"id": "http_1", "type": "http-request", "config": {}}])
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "http.url" in e.rule_id]
assert len(errors) == 1
assert errors[0].is_fixable is True
def test_http_request_placeholder_url(self):
ctx = ValidationContext(
nodes=[
{
"id": "http_1",
"type": "http-request",
"config": {"url": "PLEASE_SELECT_YOUR_URL", "method": "GET"},
}
]
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "placeholder" in e.rule_id]
assert len(errors) == 1
def test_code_node_missing_fields(self):
ctx = ValidationContext(nodes=[{"id": "code_1", "type": "code", "config": {}}])
engine = ValidationEngine()
result = engine.validate(ctx)
error_rules = {e.rule_id for e in result.all_errors}
assert "code.code.required" in error_rules
assert "code.language.required" in error_rules
def test_knowledge_retrieval_missing_dataset(self):
ctx = ValidationContext(nodes=[{"id": "kb_1", "type": "knowledge-retrieval", "config": {}}])
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "knowledge.dataset" in e.rule_id]
assert len(errors) == 1
assert errors[0].is_fixable is False # User must configure
class TestSemanticRules:
"""Tests for semantic validation rules."""
def test_valid_variable_reference(self):
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start", "config": {}},
{
"id": "llm_1",
"type": "llm",
"config": {"prompt_template": [{"role": "user", "text": "Process: {{#start.query#}}"}]},
},
]
)
engine = ValidationEngine()
result = engine.validate(ctx)
# No variable reference errors
errors = [e for e in result.all_errors if "variable.ref" in e.rule_id]
assert len(errors) == 0
def test_invalid_variable_reference(self):
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start", "config": {}},
{
"id": "llm_1",
"type": "llm",
"config": {"prompt_template": [{"role": "user", "text": "Process: {{#nonexistent.field#}}"}]},
},
]
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "variable.ref" in e.rule_id]
assert len(errors) == 1
assert "nonexistent" in errors[0].message
def test_edge_validation(self):
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start", "config": {}},
{"id": "end", "type": "end", "config": {}},
],
edges=[
{"source": "start", "target": "end"},
{"source": "nonexistent", "target": "end"},
],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "edge" in e.rule_id]
assert len(errors) == 1
assert "nonexistent" in errors[0].message
class TestReferenceRules:
"""Tests for reference validation rules (models, tools)."""
def test_llm_missing_model_with_available(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
}
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if e.rule_id == "model.required"]
assert len(errors) == 1
assert errors[0].is_fixable is True
def test_llm_missing_model_no_available(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
}
],
available_models=[], # No models available
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if e.rule_id == "model.no_available"]
assert len(errors) == 1
assert errors[0].is_fixable is False
def test_llm_with_valid_model(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {
"prompt_template": [{"role": "user", "text": "Hi"}],
"model": {"provider": "openai", "name": "gpt-4"},
},
}
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "model" in e.rule_id]
assert len(errors) == 0
def test_llm_with_invalid_model(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {
"prompt_template": [{"role": "user", "text": "Hi"}],
"model": {"provider": "openai", "name": "gpt-99"},
},
}
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if e.rule_id == "model.not_found"]
assert len(errors) == 1
assert errors[0].is_fixable is True
def test_tool_node_not_found(self):
ctx = ValidationContext(
nodes=[
{
"id": "tool_1",
"type": "tool",
"config": {"tool_key": "nonexistent/tool"},
}
],
available_tools=[],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if e.rule_id == "tool.not_found"]
assert len(errors) == 1
def test_tool_node_not_configured(self):
ctx = ValidationContext(
nodes=[
{
"id": "tool_1",
"type": "tool",
"config": {"tool_key": "google/search"},
}
],
available_tools=[{"provider_id": "google", "tool_key": "search", "is_team_authorization": False}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if e.rule_id == "tool.not_configured"]
assert len(errors) == 1
assert errors[0].is_fixable is False
class TestValidationResult:
"""Tests for ValidationResult classification."""
def test_has_errors(self):
ctx = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
engine = ValidationEngine()
result = engine.validate(ctx)
assert result.has_errors is True
assert result.is_valid is False
def test_has_fixable_errors(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
}
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
assert result.has_fixable_errors is True
assert len(result.fixable_errors) > 0
def test_get_fixable_by_node(self):
ctx = ValidationContext(
nodes=[
{"id": "llm_1", "type": "llm", "config": {}},
{"id": "http_1", "type": "http-request", "config": {}},
]
)
engine = ValidationEngine()
result = engine.validate(ctx)
by_node = result.get_fixable_by_node()
assert "llm_1" in by_node
assert "http_1" in by_node
def test_to_dict(self):
ctx = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
engine = ValidationEngine()
result = engine.validate(ctx)
d = result.to_dict()
assert "fixable" in d
assert "user_required" in d
assert "warnings" in d
assert "all_warnings" in d
assert "stats" in d
class TestIntegration:
"""Integration tests for the full validation pipeline."""
def test_complete_workflow_validation(self):
"""Test validation of a complete workflow."""
ctx = ValidationContext(
nodes=[
{
"id": "start",
"type": "start",
"config": {"variables": [{"variable": "query", "type": "text-input"}]},
},
{
"id": "llm_1",
"type": "llm",
"config": {
"model": {"provider": "openai", "name": "gpt-4"},
"prompt_template": [{"role": "user", "text": "{{#start.query#}}"}],
},
},
{
"id": "end",
"type": "end",
"config": {"outputs": [{"variable": "result", "value_selector": ["llm_1", "text"]}]},
},
],
edges=[
{"source": "start", "target": "llm_1"},
{"source": "llm_1", "target": "end"},
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
# Should have no errors
assert result.is_valid is True
assert len(result.fixable_errors) == 0
assert len(result.user_required_errors) == 0
def test_workflow_with_multiple_errors(self):
"""Test workflow with multiple types of errors."""
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start", "config": {}},
{
"id": "llm_1",
"type": "llm",
"config": {}, # Missing prompt_template and model
},
{
"id": "kb_1",
"type": "knowledge-retrieval",
"config": {"dataset_ids": ["PLEASE_SELECT_YOUR_DATASET"]},
},
{"id": "end", "type": "end", "config": {}},
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
# Should have multiple errors
assert result.has_errors is True
assert len(result.fixable_errors) >= 2 # model, prompt_template
assert len(result.user_required_errors) >= 1 # dataset placeholder
# Check stats
assert result.stats["total_nodes"] == 4
assert result.stats["total_errors"] >= 3

View File

@@ -0,0 +1,197 @@
from unittest.mock import MagicMock, patch
import pytest
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.agent.agent_node import AgentNode
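# Hedged sketch: the database lookup order these tests imply for
# AgentNode._infer_tool_provider_type when the config carries no explicit "type". The
# order is reconstructed from the mocked scalar() call counts below and is an assumption
# about the implementation, not a documented contract.
_ASSUMED_DB_LOOKUP_ORDER = [
    ("WorkflowToolProvider", ToolProviderType.WORKFLOW),
    ("MCPToolProvider", ToolProviderType.MCP),
    ("ApiToolProvider", ToolProviderType.API),
    ("BuiltinToolProvider", ToolProviderType.BUILT_IN),
]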
class TestInferToolProviderType:
"""Test cases for AgentNode._infer_tool_provider_type method."""
def test_infer_type_from_config_workflow(self):
"""Test inferring workflow provider type from config."""
tool_config = {
"type": "workflow",
"provider_name": "workflow-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.WORKFLOW
def test_infer_type_from_config_builtin(self):
"""Test inferring builtin provider type from config."""
tool_config = {
"type": "builtin",
"provider_name": "builtin-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.BUILT_IN
def test_infer_type_from_config_api(self):
"""Test inferring API provider type from config."""
tool_config = {
"type": "api",
"provider_name": "api-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.API
def test_infer_type_from_config_mcp(self):
"""Test inferring MCP provider type from config."""
tool_config = {
"type": "mcp",
"provider_name": "mcp-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.MCP
def test_infer_type_invalid_config_value_raises_error(self):
"""Test that invalid type value in config raises ValueError."""
tool_config = {
"type": "invalid-type",
"provider_name": "workflow-provider-id",
}
tenant_id = "test-tenant"
with pytest.raises(ValueError):
AgentNode._infer_tool_provider_type(tool_config, tenant_id)
def test_infer_workflow_type_from_database(self):
"""Test inferring workflow provider type from database."""
tool_config = {
"provider_name": "workflow-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First query (WorkflowToolProvider) returns a result
mock_session.scalar.return_value = True
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.WORKFLOW
# Should only query once (after finding WorkflowToolProvider)
assert mock_session.scalar.call_count == 1
def test_infer_mcp_type_from_database(self):
"""Test inferring MCP provider type from database."""
tool_config = {
"provider_name": "mcp-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First query (WorkflowToolProvider) returns None
# Second query (MCPToolProvider) returns a result
mock_session.scalar.side_effect = [None, True]
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.MCP
assert mock_session.scalar.call_count == 2
def test_infer_api_type_from_database(self):
"""Test inferring API provider type from database."""
tool_config = {
"provider_name": "api-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First query (WorkflowToolProvider) returns None
# Second query (MCPToolProvider) returns None
# Third query (ApiToolProvider) returns a result
mock_session.scalar.side_effect = [None, None, True]
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.API
assert mock_session.scalar.call_count == 3
def test_infer_builtin_type_from_database(self):
"""Test inferring builtin provider type from database."""
tool_config = {
"provider_name": "builtin-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First three queries return None
# Fourth query (BuiltinToolProvider) returns a result
mock_session.scalar.side_effect = [None, None, None, True]
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.BUILT_IN
assert mock_session.scalar.call_count == 4
def test_infer_type_default_when_not_found(self):
"""Test raising AgentNodeError when provider is not found in database."""
tool_config = {
"provider_name": "unknown-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# All queries return None
mock_session.scalar.return_value = None
# Current implementation raises AgentNodeError when provider not found
from core.workflow.nodes.agent.exc import AgentNodeError
with pytest.raises(AgentNodeError, match="Tool provider with ID 'unknown-provider-id' not found"):
AgentNode._infer_tool_provider_type(tool_config, tenant_id)
def test_infer_type_default_when_no_provider_name(self):
"""Test defaulting to BUILT_IN when provider_name is missing."""
tool_config = {}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.BUILT_IN
def test_infer_type_database_exception_propagates(self):
"""Test that database exception propagates (current implementation doesn't catch it)."""
tool_config = {
"provider_name": "provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# Database query raises exception
mock_session.scalar.side_effect = Exception("Database error")
# Current implementation doesn't catch exceptions, so it propagates
with pytest.raises(Exception, match="Database error"):
AgentNode._infer_tool_provider_type(tool_config, tenant_id)