make the model_runtime support reading and parsing the opaque_body from plugin LLM call (and fix the tool call parsing in streaming mode)
@@ -0,0 +1,20 @@
## Purpose

`core/model_runtime/model_providers/__base/large_language_model.py` defines the base `LargeLanguageModel` interface used by model providers, including plugin-backed providers via `PluginModelClient`.
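A minimal sketch of the call shape, mirroring the unit tests added in this commit (`llm` is a constructed `LargeLanguageModel` and `prompt_messages` is assumed to be in scope; the model name and empty credentials are illustrative):

```python
# Blocking: invoke() aggregates the plugin's chunk stream into one LLMResult.
result = llm.invoke(
    model="gpt-test",  # illustrative
    credentials={},
    prompt_messages=prompt_messages,
    model_parameters={},
    stream=False,
)

# Streaming: invoke() instead returns a generator of LLMResultChunk deltas.
for chunk in llm.invoke(
    model="gpt-test",
    credentials={},
    prompt_messages=prompt_messages,
    model_parameters={},
    stream=True,
):
    print(chunk.delta.message.content)
```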
## Plugin invocation flow

- For plugin-based providers, `invoke()` delegates to `PluginModelClient.invoke_llm(...)`, which streams `LLMResultChunk` objects from the plugin daemon.
- Dify yields chunks to callers and also aggregates them to fire `after_invoke` callbacks (and to construct a blocking `LLMResult` when `stream=False`); see the sketch below.
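A rough sketch of that yield-and-aggregate pattern; `_tee_chunks` is a hypothetical name, and the real wrapper additionally tracks usage, system fingerprint, tool calls, and `opaque_body` (see the diff hunks below):

```python
from collections.abc import Generator

from core.model_runtime.entities.llm_entities import LLMResultChunk


def _tee_chunks(chunks: Generator[LLMResultChunk, None, None]) -> Generator[LLMResultChunk, None, None]:
    # Illustrative: forward each chunk immediately while accumulating content
    # for the aggregated LLMResult handed to after_invoke callbacks.
    content = ""
    for chunk in chunks:
        if isinstance(chunk.delta.message.content, str):
            content += chunk.delta.message.content
        yield chunk  # the caller sees the delta as soon as it arrives
    # ...once the stream is exhausted, build the aggregate and fire callbacks.
```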
## Key invariants / edge cases

- When aggregating chunks into an `LLMResult`, preserve provider-specific fields on the assistant message:
  - `AssistantPromptMessage.opaque_body` (pass-through, uninterpreted JSON); the first non-`None` value seen wins.
  - Incremental `tool_calls` (merge deltas via `_increase_tool_call`; see the sketch after this list).
- Chunk `.prompt_messages` may be empty for plugin responses (a compat layer for the plugin daemon); Dify re-attaches the original request `prompt_messages` for downstream consumers.
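A standalone sketch of that merge behavior; `merge_tool_call_deltas` is a hypothetical helper that only assumes what the tests below exercise (a non-empty `id` starts a new tool call, an empty `id` extends the latest one):

```python
from core.model_runtime.entities.message_entities import AssistantPromptMessage

ToolCall = AssistantPromptMessage.ToolCall


def merge_tool_call_deltas(deltas: list[ToolCall], merged: list[ToolCall]) -> None:
    """Illustrative approximation of _increase_tool_call's observable behavior."""
    for delta in deltas:
        if delta.id:
            # A fresh id opens a new tool call entry.
            merged.append(delta.model_copy(deep=True))
        elif merged:
            # An empty id carries an argument fragment for the latest call.
            merged[-1].function.arguments += delta.function.arguments
```

The aggregation loops in the diff apply the same idea per chunk, alongside the first-non-`None` rule for `opaque_body`.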
@@ -0,0 +1,12 @@
## Purpose

Unit tests for plugin-backed `LargeLanguageModel.invoke()` behavior around preserving provider pass-through data.

## What it covers

- `AssistantPromptMessage.opaque_body` from plugin `LLMResultChunk` deltas is preserved:
  - On the returned `LLMResult` in blocking (`stream=False`) mode.
  - On the aggregated `LLMResult` passed to `on_after_invoke` callbacks in streaming mode.
- Streaming mode also verifies that Dify re-attaches the original request prompt messages as `chunk.prompt_messages` on each yielded chunk.
- Streaming aggregation merges incremental `tool_calls` across chunks.
@@ -164,6 +164,7 @@ class LargeLanguageModel(AIModel):
             usage = LLMUsage.empty_usage()
             system_fingerprint = None
             tools_calls: list[AssistantPromptMessage.ToolCall] = []
+            assistant_opaque_body = None
 
             for chunk in result:
                 if isinstance(chunk.delta.message.content, str):
@@ -172,6 +173,8 @@ class LargeLanguageModel(AIModel):
                     content_list.extend(chunk.delta.message.content)
                 if chunk.delta.message.tool_calls:
                     _increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
+                if assistant_opaque_body is None and chunk.delta.message.opaque_body is not None:
+                    assistant_opaque_body = chunk.delta.message.opaque_body
 
                 usage = chunk.delta.usage or LLMUsage.empty_usage()
                 system_fingerprint = chunk.system_fingerprint
@@ -183,6 +186,7 @@ class LargeLanguageModel(AIModel):
                 message=AssistantPromptMessage(
                     content=content or content_list,
                     tool_calls=tools_calls,
+                    opaque_body=assistant_opaque_body,
                 ),
                 usage=usage,
                 system_fingerprint=system_fingerprint,
@@ -261,6 +265,8 @@ class LargeLanguageModel(AIModel):
         usage = None
         system_fingerprint = None
         real_model = model
+        assistant_opaque_body = None
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
 
         def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None):
             if not content:
@@ -294,6 +300,10 @@ class LargeLanguageModel(AIModel):
                 )
 
                 _update_message_content(chunk.delta.message.content)
+                if chunk.delta.message.tool_calls:
+                    _increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
+                if assistant_opaque_body is None and chunk.delta.message.opaque_body is not None:
+                    assistant_opaque_body = chunk.delta.message.opaque_body
 
                 real_model = chunk.model
                 if chunk.delta.usage:
@@ -304,7 +314,11 @@ class LargeLanguageModel(AIModel):
         except Exception as e:
             raise self._transform_invoke_error(e)
 
-        assistant_message = AssistantPromptMessage(content=message_content)
+        assistant_message = AssistantPromptMessage(
+            content=message_content,
+            tool_calls=tools_calls,
+            opaque_body=assistant_opaque_body,
+        )
         self._trigger_after_invoke_callbacks(
             model=model,
             result=LLMResult(
@@ -0,0 +1,145 @@
from __future__ import annotations

from datetime import datetime
from typing import Any
from unittest.mock import patch

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity


class _CaptureAfterInvokeCallback(Callback):
    """Records the aggregated LLMResult passed to on_after_invoke."""

    after_result: LLMResult | None

    def __init__(self) -> None:
        self.after_result = None

    def on_before_invoke(self, **kwargs: Any) -> None:  # noqa: ANN401
        return None

    def on_new_chunk(self, **kwargs: Any) -> None:  # noqa: ANN401
        return None

    def on_after_invoke(self, result: LLMResult, **kwargs: Any) -> None:  # noqa: ANN401
        self.after_result = result

    def on_invoke_error(self, **kwargs: Any) -> None:  # noqa: ANN401
        return None


def _build_llm_instance() -> LargeLanguageModel:
    declaration = ProviderEntity(
        provider="test",
        label=I18nObject(en_US="test"),
        supported_model_types=[ModelType.LLM],
        configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
    )
    plugin_model_provider = PluginModelProviderEntity(
        id="pmp_1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        provider="test",
        tenant_id="tenant_1",
        plugin_unique_identifier="test/plugin",
        plugin_id="test/plugin",
        declaration=declaration,
    )

    return LargeLanguageModel(
        tenant_id="tenant_1",
        plugin_id="test/plugin",
        provider_name="test",
        plugin_model_provider=plugin_model_provider,
    )


def test_invoke_non_stream_preserves_assistant_opaque_body() -> None:
    llm = _build_llm_instance()
    prompt_messages: list[PromptMessage] = [UserPromptMessage(content="hi")]

    # prompt_messages is intentionally empty: the plugin daemon compat layer may
    # omit it, and Dify must re-attach the original request prompt messages.
    chunk = LLMResultChunk(
        model="gpt-test",
        prompt_messages=[],
        delta=LLMResultChunkDelta(
            index=0,
            message=AssistantPromptMessage(content="hello", opaque_body={"provider_message_id": "msg_123"}),
        ),
    )

    def _mock_invoke_llm(self, **kwargs: Any):  # noqa: ANN001, ANN401
        yield chunk

    with patch("core.plugin.impl.model.PluginModelClient.invoke_llm", new=_mock_invoke_llm):
        result = llm.invoke(
            model="gpt-test",
            credentials={},
            prompt_messages=prompt_messages,
            model_parameters={},
            stream=False,
        )

    assert isinstance(result, LLMResult)
    assert result.message.opaque_body == {"provider_message_id": "msg_123"}
    assert list(result.prompt_messages) == prompt_messages


def test_invoke_stream_preserves_assistant_opaque_body_in_after_callback() -> None:
    llm = _build_llm_instance()
    prompt_messages: list[PromptMessage] = [UserPromptMessage(content="hi")]
    callback = _CaptureAfterInvokeCallback()

    # Two deltas of one logical tool call: the second has an empty id/type, so
    # its argument fragment must be merged into the first call.
    tool_call_1 = AssistantPromptMessage.ToolCall(
        id="1",
        type="function",
        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": '),
    )
    tool_call_2 = AssistantPromptMessage.ToolCall(
        id="",
        type="",
        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='"value"}'),
    )

    chunk1 = LLMResultChunk(
        model="gpt-test",
        prompt_messages=[],
        delta=LLMResultChunkDelta(
            index=0,
            message=AssistantPromptMessage(
                content="h", tool_calls=[tool_call_1], opaque_body={"provider_message_id": "msg_123"}
            ),
        ),
    )
    chunk2 = LLMResultChunk(
        model="gpt-test",
        prompt_messages=[],
        delta=LLMResultChunkDelta(
            index=0,
            message=AssistantPromptMessage(content="i", tool_calls=[tool_call_2]),
        ),
    )

    def _mock_invoke_llm(self, **kwargs: Any):  # noqa: ANN001, ANN401
        yield chunk1
        yield chunk2

    with patch("core.plugin.impl.model.PluginModelClient.invoke_llm", new=_mock_invoke_llm):
        gen = llm.invoke(
            model="gpt-test",
            credentials={},
            prompt_messages=prompt_messages,
            model_parameters={},
            stream=True,
            callbacks=[callback],
        )
        chunks = list(gen)

    assert chunks[0].prompt_messages == prompt_messages
    assert callback.after_result is not None
    assert callback.after_result.message.opaque_body == {"provider_message_id": "msg_123"}
    assert len(callback.after_result.message.tool_calls) == 1
    assert callback.after_result.message.tool_calls[0].function.arguments == '{"arg1": "value"}'