mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
refactor: decouple the business logic from datasource_node (#32515)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
135
api/tests/unit_tests/core/datasource/test_datasource_manager.py
Normal file
135
api/tests/unit_tests/core/datasource/test_datasource_manager.py
Normal file
@ -0,0 +1,135 @@
|
||||
import types
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent
|
||||
|
||||
|
||||
def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.TEXT,
|
||||
message=DatasourceMessage.TextMessage(text=text),
|
||||
meta=None,
|
||||
)
|
||||
|
||||
|
||||
def test_get_icon_url_calls_runtime(mocker):
|
||||
fake_runtime = mocker.Mock()
|
||||
fake_runtime.get_icon_url.return_value = "https://icon"
|
||||
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=fake_runtime)
|
||||
|
||||
url = DatasourceManager.get_icon_url(
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_document",
|
||||
)
|
||||
assert url == "https://icon"
|
||||
DatasourceManager.get_datasource_runtime.assert_called_once()
|
||||
|
||||
|
||||
def test_stream_online_results_yields_messages_online_document(mocker):
|
||||
# stub runtime to yield a text message
|
||||
def _doc_messages(**_):
|
||||
yield from _gen_messages_text_only("hello")
|
||||
|
||||
fake_runtime = mocker.Mock()
|
||||
fake_runtime.get_online_document_page_content.side_effect = _doc_messages
|
||||
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=fake_runtime)
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
gen = DatasourceManager.stream_online_results(
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_document",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
|
||||
online_drive_request=None,
|
||||
)
|
||||
msgs = list(gen)
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0].message.text == "hello"
|
||||
|
||||
|
||||
def test_stream_node_events_emits_events_online_document(mocker):
|
||||
# make manager's low-level stream produce TEXT only
|
||||
mocker.patch.object(
|
||||
DatasourceManager,
|
||||
"stream_online_results",
|
||||
return_value=_gen_messages_text_only("hello"),
|
||||
)
|
||||
|
||||
events = list(
|
||||
DatasourceManager.stream_node_events(
|
||||
node_id="nodeA",
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_document",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
parameters_for_log={"k": "v"},
|
||||
datasource_info={"user_id": "u1"},
|
||||
variable_pool=mocker.Mock(),
|
||||
datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
|
||||
online_drive_request=None,
|
||||
)
|
||||
)
|
||||
# should contain one StreamChunkEvent then a final chunk (empty) and a completed event
|
||||
assert isinstance(events[0], StreamChunkEvent)
|
||||
assert events[0].chunk == "hello"
|
||||
assert isinstance(events[-1], StreamCompletedEvent)
|
||||
assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
|
||||
|
||||
def test_get_upload_file_by_id_builds_file(mocker):
|
||||
# fake UploadFile row
|
||||
fake_row = types.SimpleNamespace(
|
||||
id="fid",
|
||||
name="f",
|
||||
extension="txt",
|
||||
mime_type="text/plain",
|
||||
size=1,
|
||||
key="k",
|
||||
source_url="http://x",
|
||||
)
|
||||
|
||||
class _Q:
|
||||
def __init__(self, row):
|
||||
self._row = row
|
||||
|
||||
def where(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._row
|
||||
|
||||
class _S:
|
||||
def __init__(self, row):
|
||||
self._row = row
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
def query(self, *_):
|
||||
return _Q(self._row)
|
||||
|
||||
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S(fake_row))
|
||||
|
||||
f = DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1")
|
||||
assert f.related_id == "fid"
|
||||
assert f.extension == ".txt"
|
||||
@ -0,0 +1,93 @@
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
|
||||
|
||||
class _VarSeg:
|
||||
def __init__(self, v):
|
||||
self.value = v
|
||||
|
||||
|
||||
class _VarPool:
|
||||
def __init__(self, mapping):
|
||||
self._m = mapping
|
||||
|
||||
def get(self, selector):
|
||||
d = self._m
|
||||
for k in selector:
|
||||
d = d[k]
|
||||
return _VarSeg(d)
|
||||
|
||||
def add(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class _GraphState:
|
||||
def __init__(self, var_pool):
|
||||
self.variable_pool = var_pool
|
||||
|
||||
|
||||
class _GraphParams:
|
||||
tenant_id = "t1"
|
||||
app_id = "app-1"
|
||||
workflow_id = "wf-1"
|
||||
graph_config = {}
|
||||
user_id = "u1"
|
||||
user_from = "account"
|
||||
invoke_from = "debugger"
|
||||
call_depth = 0
|
||||
|
||||
|
||||
def test_datasource_node_delegates_to_manager_stream(mocker):
|
||||
# prepare sys variables
|
||||
sys_vars = {
|
||||
"sys": {
|
||||
"datasource_type": "online_document",
|
||||
"datasource_info": {
|
||||
"workspace_id": "w",
|
||||
"page": {"page_id": "pg", "type": "t"},
|
||||
"credential_id": "",
|
||||
},
|
||||
}
|
||||
}
|
||||
var_pool = _VarPool(sys_vars)
|
||||
gs = _GraphState(var_pool)
|
||||
gp = _GraphParams()
|
||||
|
||||
# stub manager class
|
||||
class _Mgr:
|
||||
@classmethod
|
||||
def get_icon_url(cls, **_):
|
||||
return "icon"
|
||||
|
||||
@classmethod
|
||||
def stream_node_events(cls, **_):
|
||||
yield StreamChunkEvent(selector=["n", "text"], chunk="hi", is_final=False)
|
||||
yield StreamCompletedEvent(node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED))
|
||||
|
||||
@classmethod
|
||||
def get_upload_file_by_id(cls, **_):
|
||||
raise AssertionError("not called")
|
||||
|
||||
node = DatasourceNode(
|
||||
id="n",
|
||||
config={
|
||||
"id": "n",
|
||||
"data": {
|
||||
"type": "datasource",
|
||||
"version": "1",
|
||||
"title": "Datasource",
|
||||
"provider_type": "plugin",
|
||||
"provider_name": "p",
|
||||
"plugin_id": "plug",
|
||||
"datasource_name": "ds",
|
||||
},
|
||||
},
|
||||
graph_init_params=gp,
|
||||
graph_runtime_state=gs,
|
||||
datasource_manager=_Mgr,
|
||||
)
|
||||
|
||||
evts = list(node._run())
|
||||
assert isinstance(evts[0], StreamChunkEvent)
|
||||
assert isinstance(evts[-1], StreamCompletedEvent)
|
||||
Reference in New Issue
Block a user