mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
Merge branch 'feat/queue-based-graph-engine' into feat/rag-2
# Conflicts: # api/core/app/apps/advanced_chat/generate_task_pipeline.py # api/pyproject.toml # api/uv.lock # docker/docker-compose-template.yaml # docker/docker-compose.yaml # web/package.json
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import jsonschema
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
@ -434,7 +435,7 @@ class TestUtilityFunctions:
|
||||
assert parameters["category"]["enum"] == ["A", "B", "C"]
|
||||
|
||||
assert "count" in parameters
|
||||
assert parameters["count"]["type"] == "float"
|
||||
assert parameters["count"]["type"] == "number"
|
||||
|
||||
# FILE type should be skipped - it creates empty dict but gets filtered later
|
||||
# Check that it doesn't have any meaningful content
|
||||
@ -447,3 +448,65 @@ class TestUtilityFunctions:
|
||||
assert "category" not in required
|
||||
|
||||
# Note: _get_request_id function has been removed as request_id is now passed as parameter
|
||||
|
||||
def test_convert_input_form_to_parameters_jsonschema_validation_ok(self):
|
||||
"""Current schema uses 'number' for numeric fields; it should be a valid JSON Schema."""
|
||||
user_input_form = [
|
||||
VariableEntity(
|
||||
type=VariableEntityType.NUMBER,
|
||||
variable="count",
|
||||
description="Count",
|
||||
label="Count",
|
||||
required=True,
|
||||
),
|
||||
VariableEntity(
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
variable="name",
|
||||
description="User name",
|
||||
label="Name",
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
|
||||
parameters_dict = {
|
||||
"count": "Enter count",
|
||||
"name": "Enter your name",
|
||||
}
|
||||
|
||||
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
|
||||
|
||||
# Build a complete JSON Schema
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": parameters,
|
||||
"required": required,
|
||||
}
|
||||
|
||||
# 1) The schema itself must be valid
|
||||
jsonschema.Draft202012Validator.check_schema(schema)
|
||||
|
||||
# 2) Both float and integer instances should pass validation
|
||||
jsonschema.validate(instance={"count": 3.14, "name": "alice"}, schema=schema)
|
||||
jsonschema.validate(instance={"count": 2, "name": "bob"}, schema=schema)
|
||||
|
||||
def test_legacy_float_type_schema_is_invalid(self):
|
||||
"""Legacy/buggy behavior: using 'float' should produce an invalid JSON Schema."""
|
||||
# Manually construct a legacy/incorrect schema (simulating old behavior)
|
||||
bad_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"count": {
|
||||
"type": "float", # Invalid type: JSON Schema does not support 'float'
|
||||
"description": "Enter count",
|
||||
}
|
||||
},
|
||||
"required": ["count"],
|
||||
}
|
||||
|
||||
# The schema itself should raise a SchemaError
|
||||
with pytest.raises(jsonschema.exceptions.SchemaError):
|
||||
jsonschema.Draft202012Validator.check_schema(bad_schema)
|
||||
|
||||
# Or validation should also raise SchemaError
|
||||
with pytest.raises(jsonschema.exceptions.SchemaError):
|
||||
jsonschema.validate(instance={"count": 1.23}, schema=bad_schema)
|
||||
|
||||
@ -0,0 +1,111 @@
|
||||
from time import time
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
class TestGraphRuntimeState:
|
||||
def test_property_getters_and_setters(self):
|
||||
# FIXME(-LAN-): Mock VariablePool if needed
|
||||
variable_pool = VariablePool()
|
||||
start_time = time()
|
||||
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=start_time)
|
||||
|
||||
# Test variable_pool property (read-only)
|
||||
assert state.variable_pool == variable_pool
|
||||
|
||||
# Test start_at property
|
||||
assert state.start_at == start_time
|
||||
new_time = time() + 100
|
||||
state.start_at = new_time
|
||||
assert state.start_at == new_time
|
||||
|
||||
# Test total_tokens property
|
||||
assert state.total_tokens == 0
|
||||
state.total_tokens = 100
|
||||
assert state.total_tokens == 100
|
||||
|
||||
# Test node_run_steps property
|
||||
assert state.node_run_steps == 0
|
||||
state.node_run_steps = 5
|
||||
assert state.node_run_steps == 5
|
||||
|
||||
def test_outputs_immutability(self):
|
||||
variable_pool = VariablePool()
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
|
||||
# Test that getting outputs returns a copy
|
||||
outputs1 = state.outputs
|
||||
outputs2 = state.outputs
|
||||
assert outputs1 == outputs2
|
||||
assert outputs1 is not outputs2 # Different objects
|
||||
|
||||
# Test that modifying retrieved outputs doesn't affect internal state
|
||||
outputs = state.outputs
|
||||
outputs["test"] = "value"
|
||||
assert "test" not in state.outputs
|
||||
|
||||
# Test set_output method
|
||||
state.set_output("key1", "value1")
|
||||
assert state.get_output("key1") == "value1"
|
||||
|
||||
# Test update_outputs method
|
||||
state.update_outputs({"key2": "value2", "key3": "value3"})
|
||||
assert state.get_output("key2") == "value2"
|
||||
assert state.get_output("key3") == "value3"
|
||||
|
||||
def test_llm_usage_immutability(self):
|
||||
variable_pool = VariablePool()
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
|
||||
# Test that getting llm_usage returns a copy
|
||||
usage1 = state.llm_usage
|
||||
usage2 = state.llm_usage
|
||||
assert usage1 is not usage2 # Different objects
|
||||
|
||||
def test_type_validation(self):
|
||||
variable_pool = VariablePool()
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
|
||||
# Test total_tokens validation
|
||||
with pytest.raises(ValueError):
|
||||
state.total_tokens = -1
|
||||
|
||||
# Test node_run_steps validation
|
||||
with pytest.raises(ValueError):
|
||||
state.node_run_steps = -1
|
||||
|
||||
def test_helper_methods(self):
|
||||
variable_pool = VariablePool()
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
|
||||
# Test increment_node_run_steps
|
||||
initial_steps = state.node_run_steps
|
||||
state.increment_node_run_steps()
|
||||
assert state.node_run_steps == initial_steps + 1
|
||||
|
||||
# Test add_tokens
|
||||
initial_tokens = state.total_tokens
|
||||
state.add_tokens(50)
|
||||
assert state.total_tokens == initial_tokens + 50
|
||||
|
||||
# Test add_tokens validation
|
||||
with pytest.raises(ValueError):
|
||||
state.add_tokens(-1)
|
||||
|
||||
def test_deep_copy_for_nested_objects(self):
|
||||
variable_pool = VariablePool()
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
|
||||
# Test deep copy for nested dict
|
||||
nested_data = {"level1": {"level2": {"value": "test"}}}
|
||||
state.set_output("nested", nested_data)
|
||||
|
||||
retrieved = state.get_output("nested")
|
||||
retrieved["level1"]["level2"]["value"] = "modified"
|
||||
|
||||
# Original should remain unchanged
|
||||
assert state.get_output("nested")["level1"]["level2"]["value"] == "test"
|
||||
@ -498,10 +498,10 @@ def test_layer_system_basic():
|
||||
|
||||
def test_layer_chaining():
|
||||
"""Test chaining multiple layers."""
|
||||
from core.workflow.graph_engine.layers import DebugLoggingLayer, Layer
|
||||
from core.workflow.graph_engine.layers import DebugLoggingLayer, GraphEngineLayer
|
||||
|
||||
# Create a custom test layer
|
||||
class TestLayer(Layer):
|
||||
class TestLayer(GraphEngineLayer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.events_received = []
|
||||
@ -560,10 +560,10 @@ def test_layer_chaining():
|
||||
|
||||
def test_layer_error_handling():
|
||||
"""Test that layer errors don't crash the engine."""
|
||||
from core.workflow.graph_engine.layers import Layer
|
||||
from core.workflow.graph_engine.layers import GraphEngineLayer
|
||||
|
||||
# Create a layer that throws errors
|
||||
class FaultyLayer(Layer):
|
||||
class FaultyLayer(GraphEngineLayer):
|
||||
def on_graph_start(self):
|
||||
raise RuntimeError("Intentional error in on_graph_start")
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ import time
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -135,15 +136,12 @@ class WorkflowRunner:
|
||||
raise ValueError(f"Fixtures directory does not exist: {self.fixtures_dir}")
|
||||
|
||||
def load_fixture(self, fixture_name: str) -> dict[str, Any]:
|
||||
"""Load a YAML fixture file."""
|
||||
"""Load a YAML fixture file with caching to avoid repeated parsing."""
|
||||
if not fixture_name.endswith(".yml") and not fixture_name.endswith(".yaml"):
|
||||
fixture_name = f"{fixture_name}.yml"
|
||||
|
||||
fixture_path = self.fixtures_dir / fixture_name
|
||||
if not fixture_path.exists():
|
||||
raise FileNotFoundError(f"Fixture file not found: {fixture_path}")
|
||||
|
||||
return load_yaml_file(str(fixture_path), ignore_error=False)
|
||||
return _load_fixture(fixture_path, fixture_name)
|
||||
|
||||
def create_graph_from_fixture(
|
||||
self,
|
||||
@ -709,3 +707,12 @@ class TableTestRunner:
|
||||
report.append("=" * 80)
|
||||
|
||||
return "\n".join(report)
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _load_fixture(fixture_path: Path, fixture_name: str) -> dict[str, Any]:
|
||||
"""Load a YAML fixture file with caching to avoid repeated parsing."""
|
||||
if not fixture_path.exists():
|
||||
raise FileNotFoundError(f"Fixture file not found: {fixture_path}")
|
||||
|
||||
return load_yaml_file(str(fixture_path), ignore_error=False)
|
||||
|
||||
Reference in New Issue
Block a user