mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
chore: improve type checking
This commit is contained in:
@ -7,9 +7,8 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.workflow.entities import AgentNodeStrategyInit
|
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
|
||||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||||
from core.workflow.graph_events import ToolCall, ToolResult
|
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,6 @@ from sqlalchemy import select
|
|||||||
|
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
@ -90,6 +89,8 @@ class DatasetIndexToolCallbackHandler:
|
|||||||
# TODO(-LAN-): Improve type check
|
# TODO(-LAN-): Improve type check
|
||||||
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
|
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
|
||||||
"""Handle return_retriever_resource_info."""
|
"""Handle return_retriever_resource_info."""
|
||||||
|
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||||
|
|
||||||
self._queue_manager.publish(
|
self._queue_manager.publish(
|
||||||
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
|
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
|
||||||
)
|
)
|
||||||
|
|||||||
@ -0,0 +1,148 @@
|
|||||||
|
import types
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
|
from core.workflow.entities import ToolCallResult
|
||||||
|
from core.workflow.entities.tool_entities import ToolResultStatus
|
||||||
|
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeEventBase
|
||||||
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
|
|
||||||
|
|
||||||
|
class _StubModelInstance:
|
||||||
|
"""Minimal stub to satisfy _stream_llm_events signature."""
|
||||||
|
|
||||||
|
provider_model_bundle = None
|
||||||
|
|
||||||
|
|
||||||
|
def _drain(generator: Generator[NodeEventBase, None, Any]):
|
||||||
|
events: list = []
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
events.append(next(generator))
|
||||||
|
except StopIteration as exc:
|
||||||
|
return events, exc.value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def patch_deduct_llm_quota(monkeypatch):
|
||||||
|
# Avoid touching real quota logic during unit tests
|
||||||
|
monkeypatch.setattr("core.workflow.nodes.llm.node.llm_utils.deduct_llm_quota", lambda **_: None)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_llm_node(reasoning_format: str) -> LLMNode:
|
||||||
|
node = LLMNode.__new__(LLMNode)
|
||||||
|
object.__setattr__(node, "_node_data", types.SimpleNamespace(reasoning_format=reasoning_format, tools=[]))
|
||||||
|
object.__setattr__(node, "tenant_id", "tenant")
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_llm_events_extracts_reasoning_for_tagged():
|
||||||
|
node = _make_llm_node(reasoning_format="tagged")
|
||||||
|
tagged_text = "<think>Thought</think>Answer"
|
||||||
|
usage = LLMUsage.empty_usage()
|
||||||
|
|
||||||
|
def generator():
|
||||||
|
yield ModelInvokeCompletedEvent(
|
||||||
|
text=tagged_text,
|
||||||
|
usage=usage,
|
||||||
|
finish_reason="stop",
|
||||||
|
reasoning_content="",
|
||||||
|
structured_output=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
events, returned = _drain(
|
||||||
|
node._stream_llm_events(generator(), model_instance=types.SimpleNamespace(provider_model_bundle=None))
|
||||||
|
)
|
||||||
|
|
||||||
|
assert events == []
|
||||||
|
clean_text, reasoning_content, gen_reasoning, gen_clean, ret_usage, finish_reason, structured, gen_data = returned
|
||||||
|
assert clean_text == tagged_text # original preserved for output
|
||||||
|
assert reasoning_content == "" # tagged mode keeps reasoning separate
|
||||||
|
assert gen_clean == "Answer" # stripped content for generation
|
||||||
|
assert gen_reasoning == "Thought" # reasoning extracted from <think> tag
|
||||||
|
assert ret_usage == usage
|
||||||
|
assert finish_reason == "stop"
|
||||||
|
assert structured is None
|
||||||
|
assert gen_data is None
|
||||||
|
|
||||||
|
# generation building should include reasoning and sequence
|
||||||
|
generation_content = gen_clean or clean_text
|
||||||
|
sequence = [
|
||||||
|
{"type": "reasoning", "index": 0},
|
||||||
|
{"type": "content", "start": 0, "end": len(generation_content)},
|
||||||
|
]
|
||||||
|
assert sequence == [
|
||||||
|
{"type": "reasoning", "index": 0},
|
||||||
|
{"type": "content", "start": 0, "end": len("Answer")},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_llm_events_no_reasoning_results_in_empty_sequence():
|
||||||
|
node = _make_llm_node(reasoning_format="tagged")
|
||||||
|
plain_text = "Hello world"
|
||||||
|
usage = LLMUsage.empty_usage()
|
||||||
|
|
||||||
|
def generator():
|
||||||
|
yield ModelInvokeCompletedEvent(
|
||||||
|
text=plain_text,
|
||||||
|
usage=usage,
|
||||||
|
finish_reason=None,
|
||||||
|
reasoning_content="",
|
||||||
|
structured_output=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
events, returned = _drain(
|
||||||
|
node._stream_llm_events(generator(), model_instance=types.SimpleNamespace(provider_model_bundle=None))
|
||||||
|
)
|
||||||
|
|
||||||
|
assert events == []
|
||||||
|
_, _, gen_reasoning, gen_clean, *_ = returned
|
||||||
|
generation_content = gen_clean or plain_text
|
||||||
|
assert gen_reasoning == ""
|
||||||
|
assert generation_content == plain_text
|
||||||
|
# Empty reasoning should imply empty sequence in generation construction
|
||||||
|
sequence = []
|
||||||
|
assert sequence == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_serialize_tool_call_strips_files_to_ids():
|
||||||
|
file_cls = pytest.importorskip("core.file").File
|
||||||
|
file_type = pytest.importorskip("core.file.enums").FileType
|
||||||
|
transfer_method = pytest.importorskip("core.file.enums").FileTransferMethod
|
||||||
|
|
||||||
|
file_with_id = file_cls(
|
||||||
|
id="f1",
|
||||||
|
tenant_id="t",
|
||||||
|
type=file_type.IMAGE,
|
||||||
|
transfer_method=transfer_method.REMOTE_URL,
|
||||||
|
remote_url="http://example.com/f1",
|
||||||
|
storage_key="k1",
|
||||||
|
)
|
||||||
|
file_with_related = file_cls(
|
||||||
|
id=None,
|
||||||
|
tenant_id="t",
|
||||||
|
type=file_type.IMAGE,
|
||||||
|
transfer_method=transfer_method.REMOTE_URL,
|
||||||
|
related_id="rel2",
|
||||||
|
remote_url="http://example.com/f2",
|
||||||
|
storage_key="k2",
|
||||||
|
)
|
||||||
|
tool_call = ToolCallResult(
|
||||||
|
id="tc",
|
||||||
|
name="do",
|
||||||
|
arguments='{"a":1}',
|
||||||
|
output="ok",
|
||||||
|
files=[file_with_id, file_with_related],
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
)
|
||||||
|
|
||||||
|
serialized = LLMNode._serialize_tool_call(tool_call)
|
||||||
|
|
||||||
|
assert serialized["files"] == ["f1", "rel2"]
|
||||||
|
assert serialized["id"] == "tc"
|
||||||
|
assert serialized["name"] == "do"
|
||||||
|
assert serialized["arguments"] == '{"a":1}'
|
||||||
|
assert serialized["output"] == "ok"
|
||||||
Reference in New Issue
Block a user