Merge branch 'fix/chore-fix' into dev/plugin-deploy

This commit is contained in:
Yeuoly
2024-11-25 17:19:51 +08:00
250 changed files with 7636 additions and 1975 deletions

View File

@ -1,6 +1,6 @@
import uuid
from collections.abc import Generator
from datetime import datetime, timezone
from datetime import UTC, datetime, timezone
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
@ -29,7 +29,7 @@ def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngine
def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None))
route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(UTC).replace(tzinfo=None))
parallel_id = graph.node_parallel_mapping.get(next_node_id)
parallel_start_node_id = None
@ -68,7 +68,7 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
)
route_node_state.status = RouteNodeState.Status.SUCCESS
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
route_node_state.finished_at = datetime.now(UTC).replace(tzinfo=None)
yield NodeRunSucceededEvent(
id=node_execution_id,
node_id=next_node_id,

View File

@ -1,125 +1,431 @@
from collections.abc import Sequence
from typing import Optional
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
from core.file import File, FileTransferMethod, FileType
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageRole,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
VisionConfigOptions,
)
from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom
from models.provider import ProviderType
from models.workflow import WorkflowType
from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario
class TestLLMNode:
@pytest.fixture
def llm_node(self):
data = LLMNodeData(
title="Test LLM",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
prompt_template=[],
memory=None,
context=ContextConfig(enabled=False),
vision=VisionConfig(
enabled=True,
configs=VisionConfigOptions(
variable_selector=["sys", "files"],
detail=ImagePromptMessageContent.DETAIL.HIGH,
),
),
)
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
node = LLMNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)
return node
class MockTokenBufferMemory:
def __init__(self, history_messages=None):
self.history_messages = history_messages or []
def test_fetch_files_with_file_segment(self, llm_node):
file = File(
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> Sequence[PromptMessage]:
if message_limit is not None:
return self.history_messages[-message_limit * 2 :]
return self.history_messages
@pytest.fixture
def llm_node():
data = LLMNodeData(
title="Test LLM",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
prompt_template=[],
memory=None,
context=ContextConfig(enabled=False),
vision=VisionConfig(
enabled=True,
configs=VisionConfigOptions(
variable_selector=["sys", "files"],
detail=ImagePromptMessageContent.DETAIL.HIGH,
),
),
)
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
node = LLMNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)
return node
@pytest.fixture
def model_config():
# Create actual provider and model type instances
model_provider_factory = ModelProviderFactory()
provider_instance = model_provider_factory.get_provider_instance("openai")
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
# Create a ProviderModelBundle
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(
tenant_id="1",
provider=provider_instance.get_provider_schema(),
preferred_provider_type=ProviderType.CUSTOM,
using_provider_type=ProviderType.CUSTOM,
system_configuration=SystemConfiguration(enabled=False),
custom_configuration=CustomConfiguration(provider=None),
model_settings=[],
),
provider_instance=provider_instance,
model_type_instance=model_type_instance,
)
# Create and return a ModelConfigWithCredentialsEntity
return ModelConfigWithCredentialsEntity(
provider="openai",
model="gpt-3.5-turbo",
model_schema=AIModelEntity(
model="gpt-3.5-turbo",
label=I18nObject(en_US="GPT-3.5 Turbo"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={},
),
mode="chat",
credentials={},
parameters={},
provider_model_bundle=provider_model_bundle,
)
def test_fetch_files_with_file_segment(llm_node):
file = File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
)
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == [file]
def test_fetch_files_with_array_file_segment(llm_node):
files = [
File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test.jpg",
filename="test1.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
),
File(
id="2",
tenant_id="test",
type=FileType.IMAGE,
filename="test2.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="2",
),
]
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == files
def test_fetch_files_with_none_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_files_with_array_any_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_files_with_non_existent_variable(llm_node):
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
prompt_template = []
llm_node.node_data.prompt_template = prompt_template
fake_vision_detail = faker.random_element(
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
)
fake_remote_url = faker.url()
files = [
File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
)
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
]
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == [file]
fake_query = faker.sentence()
def test_fetch_files_with_array_file_segment(self, llm_node):
files = [
File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
),
File(
id="2",
tenant_id="test",
type=FileType.IMAGE,
filename="test2.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="2",
),
]
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=fake_query,
user_files=files,
context=None,
memory=None,
model_config=model_config,
prompt_template=prompt_template,
memory_config=None,
vision_enabled=False,
vision_detail=fake_vision_detail,
variable_pool=llm_node.graph_runtime_state.variable_pool,
jinja2_variables=[],
)
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == files
assert prompt_messages == [UserPromptMessage(content=fake_query)]
def test_fetch_files_with_none_segment(self, llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
# Setup dify config
dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
def test_fetch_files_with_array_any_segment(self, llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
# Generate fake values for prompt template
fake_assistant_prompt = faker.sentence()
fake_query = faker.sentence()
fake_context = faker.sentence()
fake_window_size = faker.random_int(min=1, max=3)
fake_vision_detail = faker.random_element(
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
)
fake_remote_url = faker.url()
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
# Setup mock memory with history messages
mock_history = [
UserPromptMessage(content=faker.sentence()),
AssistantPromptMessage(content=faker.sentence()),
UserPromptMessage(content=faker.sentence()),
AssistantPromptMessage(content=faker.sentence()),
UserPromptMessage(content=faker.sentence()),
AssistantPromptMessage(content=faker.sentence()),
]
def test_fetch_files_with_non_existent_variable(self, llm_node):
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
# Setup memory configuration
memory_config = MemoryConfig(
role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size),
query_prompt_template=None,
)
memory = MockTokenBufferMemory(history_messages=mock_history)
# Test scenarios covering different file input combinations
test_scenarios = [
LLMNodeTestScenario(
description="No files",
user_query=fake_query,
user_files=[],
features=[],
vision_enabled=False,
vision_detail=None,
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text=fake_context,
role=PromptMessageRole.SYSTEM,
edition_type="basic",
),
LLMNodeChatModelMessage(
text="{#context#}",
role=PromptMessageRole.USER,
edition_type="basic",
),
LLMNodeChatModelMessage(
text=fake_assistant_prompt,
role=PromptMessageRole.ASSISTANT,
edition_type="basic",
),
],
expected_messages=[
SystemPromptMessage(content=fake_context),
UserPromptMessage(content=fake_context),
AssistantPromptMessage(content=fake_assistant_prompt),
]
+ mock_history[fake_window_size * -2 :]
+ [
UserPromptMessage(content=fake_query),
],
),
LLMNodeTestScenario(
description="User files",
user_query=fake_query,
user_files=[
File(
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
)
],
vision_enabled=True,
vision_detail=fake_vision_detail,
features=[ModelFeature.VISION],
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text=fake_context,
role=PromptMessageRole.SYSTEM,
edition_type="basic",
),
LLMNodeChatModelMessage(
text="{#context#}",
role=PromptMessageRole.USER,
edition_type="basic",
),
LLMNodeChatModelMessage(
text=fake_assistant_prompt,
role=PromptMessageRole.ASSISTANT,
edition_type="basic",
),
],
expected_messages=[
SystemPromptMessage(content=fake_context),
UserPromptMessage(content=fake_context),
AssistantPromptMessage(content=fake_assistant_prompt),
]
+ mock_history[fake_window_size * -2 :]
+ [
UserPromptMessage(
content=[
TextPromptMessageContent(data=fake_query),
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
]
),
],
),
LLMNodeTestScenario(
description="Prompt template with variable selector of File",
user_query=fake_query,
user_files=[],
vision_enabled=False,
vision_detail=fake_vision_detail,
features=[ModelFeature.VISION],
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text="{{#input.image#}}",
role=PromptMessageRole.USER,
edition_type="basic",
),
],
expected_messages=[
UserPromptMessage(
content=[
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
]
),
]
+ mock_history[fake_window_size * -2 :]
+ [UserPromptMessage(content=fake_query)],
file_variables={
"input.image": File(
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
)
},
),
]
for scenario in test_scenarios:
model_config.model_schema.features = scenario.features
for k, v in scenario.file_variables.items():
selector = k.split(".")
llm_node.graph_runtime_state.variable_pool.add(selector, v)
# Call the method under test
prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=scenario.user_query,
user_files=scenario.user_files,
context=fake_context,
memory=memory,
model_config=model_config,
prompt_template=scenario.prompt_template,
memory_config=memory_config,
vision_enabled=scenario.vision_enabled,
vision_detail=scenario.vision_detail,
variable_pool=llm_node.graph_runtime_state.variable_pool,
jinja2_variables=[],
)
# Verify the result
assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
assert (
prompt_messages == scenario.expected_messages
), f"Message content mismatch in scenario: {scenario.description}"

View File

@ -0,0 +1,25 @@
from collections.abc import Mapping, Sequence
from pydantic import BaseModel, Field
from core.file import File
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelFeature
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
class LLMNodeTestScenario(BaseModel):
"""Test scenario for LLM node testing."""
description: str = Field(..., description="Description of the test scenario")
user_query: str = Field(..., description="User query input")
user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
window_size: int = Field(..., description="Window size for memory")
prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
file_variables: Mapping[str, File | Sequence[File]] = Field(
default_factory=dict, description="List of file variables"
)
expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")

View File

@ -1,47 +0,0 @@
import pytest
from packaging import version
from services.app_dsl_service import AppDslService
from services.app_dsl_service.exc import DSLVersionNotSupportedError
from services.app_dsl_service.service import _check_or_fix_dsl, current_dsl_version
class TestAppDSLService:
@pytest.mark.skip(reason="Test skipped")
def test_check_or_fix_dsl_missing_version(self):
import_data = {}
result = _check_or_fix_dsl(import_data)
assert result["version"] == "0.1.0"
assert result["kind"] == "app"
@pytest.mark.skip(reason="Test skipped")
def test_check_or_fix_dsl_missing_kind(self):
import_data = {"version": "0.1.0"}
result = _check_or_fix_dsl(import_data)
assert result["kind"] == "app"
@pytest.mark.skip(reason="Test skipped")
def test_check_or_fix_dsl_older_version(self):
import_data = {"version": "0.0.9", "kind": "app"}
result = _check_or_fix_dsl(import_data)
assert result["version"] == "0.0.9"
@pytest.mark.skip(reason="Test skipped")
def test_check_or_fix_dsl_current_version(self):
import_data = {"version": current_dsl_version, "kind": "app"}
result = _check_or_fix_dsl(import_data)
assert result["version"] == current_dsl_version
@pytest.mark.skip(reason="Test skipped")
def test_check_or_fix_dsl_newer_version(self):
current_version = version.parse(current_dsl_version)
newer_version = f"{current_version.major}.{current_version.minor + 1}.0"
import_data = {"version": newer_version, "kind": "app"}
with pytest.raises(DSLVersionNotSupportedError):
_check_or_fix_dsl(import_data)
@pytest.mark.skip(reason="Test skipped")
def test_check_or_fix_dsl_invalid_kind(self):
import_data = {"version": current_dsl_version, "kind": "invalid"}
result = _check_or_fix_dsl(import_data)
assert result["kind"] == "app"