refactor(api): continue decoupling dify_graph from API concerns (#33580)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
-LAN- authored on 2026-03-25 20:32:24 +08:00, committed by GitHub
parent b7b9b003c9 · commit 56593f20b0
487 changed files with 17999 additions and 9186 deletions
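The recurring change in this commit is replacing bare `ModelManager()` construction with the tenant-bound `ModelManager.for_tenant(...)` factory, so tenant and user identity are attached once at construction instead of being threaded through every call. A minimal standalone sketch of that pattern, with toy classes and example ids rather than Dify's real implementation:

from dataclasses import dataclass


@dataclass
class ModelManager:
    # Toy stand-in for core.model_manager.ModelManager; only the
    # identity-binding pattern is illustrated here.
    tenant_id: str | None = None
    user_id: str | None = None

    @classmethod
    def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
        # Bind identity once so downstream calls no longer need to pass
        # `user` explicitly on every invocation.
        return cls(tenant_id=tenant_id, user_id=user_id)


manager = ModelManager.for_tenant(tenant_id="tenant-123", user_id="user-456")
assert manager.user_id == "user-456"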


@@ -56,6 +56,7 @@ from core.rag.retrieval.template_prompts import (
 )
 from core.tools.signature import sign_upload_file
 from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
+from core.workflow.file_reference import build_file_reference
 from core.workflow.nodes.knowledge_retrieval import exc
 from core.workflow.nodes.knowledge_retrieval.retrieval import (
     KnowledgeRetrievalRequest,
@@ -160,7 +161,7 @@ class DatasetRetrieval:
         if request.model_provider is None or request.model_name is None or request.query is None:
             raise ValueError("model_provider, model_name, and query are required for single retrieval mode")
-        model_manager = ModelManager()
+        model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id, user_id=request.user_id)
         model_instance = model_manager.get_model_instance(
             tenant_id=request.tenant_id,
             model_type=ModelType.LLM,
@@ -383,23 +384,27 @@ class DatasetRetrieval:
             return None, []
         retrieve_config = config.retrieve_config
         # check if the model supports tool calling
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-        model_manager = ModelManager()
+        model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
         )
+        model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
         # get model schema
+        # Reuse the caller-bound model instance for both schema resolution and
+        # downstream planner/invoke calls so a single request never mixes
+        # tenant-scope and request-bound runtimes.
         model_schema = model_type_instance.get_model_schema(
-            model=model_config.model, credentials=model_config.credentials
+            model=model_instance.model_name,
+            credentials=model_instance.credentials,
         )
         if not model_schema:
             return None, []
+        model_config.provider_model_bundle = model_instance.provider_model_bundle
+        model_config.credentials = model_instance.credentials
+        model_config.model_schema = model_schema
         planning_strategy = PlanningStrategy.REACT_ROUTER
         features = model_schema.features
         if features:
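One detail worth calling out in the hunk above: the schema lookup now runs against the resolved `model_instance`, and the resolved bundle and credentials are written back onto `model_config` so later consumers see a single consistent runtime. A minimal sketch of that resolve-once, write-back flow, using hypothetical stand-in types:

from dataclasses import dataclass, field


@dataclass
class ResolvedInstance:
    # Hypothetical stand-in for the resolved model instance.
    model_name: str
    credentials: dict = field(default_factory=dict)


@dataclass
class ToyModelConfig:
    model: str
    credentials: dict = field(default_factory=dict)
    model_schema: dict | None = None


def bind_config_to_instance(config: ToyModelConfig, instance: ResolvedInstance) -> None:
    # The resolved instance is the single source of truth; the raw config
    # is updated in place so downstream readers never see stale values.
    config.credentials = instance.credentials
    config.model_schema = {"model": instance.model_name}


instance = ResolvedInstance(model_name="gpt-4o", credentials={"api_key": "..."})
config = ToyModelConfig(model="gpt-4o")
bind_config_to_instance(config, instance)
assert config.credentials is instance.credentials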
@@ -517,11 +522,12 @@ class DatasetRetrieval:
                 filename=upload_file.name,
                 extension="." + upload_file.extension,
                 mime_type=upload_file.mime_type,
-                tenant_id=segment.tenant_id,
                 type=FileType.IMAGE,
                 transfer_method=FileTransferMethod.LOCAL_FILE,
                 remote_url=upload_file.source_url,
-                related_id=upload_file.id,
+                reference=build_file_reference(
+                    record_id=str(upload_file.id),
+                ),
                 size=upload_file.size,
                 storage_key=upload_file.key,
                 url=sign_upload_file(upload_file.id, upload_file.extension),
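`build_file_reference` replaces the raw `related_id` (and the tenant-scoped column) with an opaque reference object, keeping storage-table details out of the graph-side File. A hypothetical sketch of what such a wrapper could look like; the real helper lives in core.workflow.file_reference and may differ:

from dataclasses import dataclass


@dataclass(frozen=True)
class FileReference:
    # Hypothetical shape; assumed for illustration only.
    record_id: str


def build_file_reference(record_id: str) -> FileReference:
    # Keep the reference opaque: the graph-side File carries an id wrapper
    # instead of raw storage columns such as related_id or tenant_id.
    return FileReference(record_id=record_id)


ref = build_file_reference(record_id="d41d8cd9")
assert ref.record_id == "d41d8cd9"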
@@ -986,6 +992,24 @@ class DatasetRetrieval:
             )
         )

+    @staticmethod
+    def _resolve_creator_user_role(user_from: str) -> CreatorUserRole | None:
+        """Map runtime user source values to dataset query audit roles.
+
+        Workflow run context uses the hyphenated ``end-user`` value, while
+        ``DatasetQuery.created_by_role`` persists the underscore-based
+        ``CreatorUserRole.END_USER`` enum. Query logging is a side effect, so an
+        unsupported value should be skipped instead of aborting retrieval.
+        """
+        normalized_user_from = str(user_from).strip().lower().replace("-", "_")
+        if normalized_user_from == CreatorUserRole.ACCOUNT.value:
+            return CreatorUserRole.ACCOUNT
+        if normalized_user_from == CreatorUserRole.END_USER.value:
+            return CreatorUserRole.END_USER
+        logger.warning("Skipping dataset query audit log for unsupported user_from=%r", user_from)
+        return None
+
     def _on_query(
         self,
         query: str | None,
@@ -996,10 +1020,13 @@ class DatasetRetrieval:
         user_id: str,
     ):
         """
-        Handle query.
+        Persist dataset query audit rows for retrieval requests.
         """
         if not query and not attachment_ids:
             return
+        created_by_role = self._resolve_creator_user_role(user_from)
+        if created_by_role is None:
+            return
         dataset_queries = []
         for dataset_id in dataset_ids:
             contents = []
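Taken together with `_resolve_creator_user_role` above, the early return makes audit logging best-effort: unknown sources are logged and skipped rather than raising mid-retrieval. A self-contained check of the normalization rule, with a toy enum mirroring the two audited roles:

from enum import Enum


class CreatorUserRole(Enum):
    # Toy mirror of the real enum's two audited roles.
    ACCOUNT = "account"
    END_USER = "end_user"


def resolve(user_from: str) -> CreatorUserRole | None:
    # Same normalization as _resolve_creator_user_role: trim, lowercase,
    # and map hyphens onto the underscore-based enum values.
    normalized = str(user_from).strip().lower().replace("-", "_")
    for role in (CreatorUserRole.ACCOUNT, CreatorUserRole.END_USER):
        if normalized == role.value:
            return role
    return None  # audit logging is best-effort, so unknown sources are skipped


assert resolve("end-user") is CreatorUserRole.END_USER
assert resolve(" Account ") is CreatorUserRole.ACCOUNT
assert resolve("api") is None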
@@ -1014,7 +1041,7 @@ class DatasetRetrieval:
                 content=json.dumps(contents),
                 source=DatasetQuerySource.APP,
                 source_app_id=app_id,
-                created_by_role=CreatorUserRole(user_from),
+                created_by_role=created_by_role,
                 created_by=user_id,
             )
             dataset_queries.append(dataset_query)
@@ -1411,7 +1438,7 @@ class DatasetRetrieval:
             raise ValueError("metadata_model_config is required")
         # get metadata model instance
         # fetch model config
-        model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
+        model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config, user_id=user_id)

         # fetch prompt messages
         prompt_messages, stop = self._get_prompt_template(
@@ -1430,7 +1457,6 @@ class DatasetRetrieval:
                 model_parameters=model_config.parameters,
                 stop=stop,
                 stream=True,
-                user=user_id,
             ),
         )
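Dropping `user=user_id` here is safe because the instance now comes from a user-bound manager (see `_fetch_model_config` below). A toy illustration of why the per-call kwarg becomes redundant once identity lives on the instance:

from dataclasses import dataclass


@dataclass
class BoundInstance:
    # Toy stand-in for a model instance fetched via a user-bound manager.
    tenant_id: str
    user_id: str | None

    def invoke_llm(self, prompt: str) -> str:
        # The bound user_id is already available here, which is what makes
        # the per-call `user=` kwarg redundant in the hunk above.
        who = self.user_id or "anonymous"
        return f"[{self.tenant_id}/{who}] {prompt}"


instance = BoundInstance(tenant_id="t1", user_id="u1")
assert instance.invoke_llm("hello") == "[t1/u1] hello"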
@@ -1533,7 +1559,7 @@ class DatasetRetrieval:
         return filters

     def _fetch_model_config(
-        self, tenant_id: str, model: ModelConfig
+        self, tenant_id: str, model: ModelConfig, user_id: str | None = None
     ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
         """
         Fetch model config
@@ -1543,7 +1569,7 @@ class DatasetRetrieval:
         model_name = model.name
         provider_name = model.provider
-        model_manager = ModelManager()
+        model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
         )
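Note the `user_id: str | None = None` default on `_fetch_model_config`: existing call sites keep working unchanged while updated ones thread identity through. A minimal stand-in showing both call styles:

def fetch_model_config(tenant_id: str, model: str, user_id: str | None = None) -> tuple[str, str | None]:
    # Backward-compatible signature extension: the new parameter is optional.
    return model, user_id


assert fetch_model_config("t1", "gpt-4o") == ("gpt-4o", None)  # legacy call
assert fetch_model_config("t1", "gpt-4o", user_id="u1") == ("gpt-4o", "u1")  # updated call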


@@ -3,13 +3,14 @@ from typing import Union
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.app.llm import deduct_llm_quota
-from core.model_manager import ModelInstance
+from core.model_manager import ModelInstance, ModelManager
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
 from core.rag.retrieval.output_parser.react_output import ReactAction
 from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
 from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
 from dify_graph.model_runtime.entities.model_entities import ModelType

 PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
@@ -119,6 +120,7 @@ class ReactMultiDatasetRouter:
             memory_config=None,
             memory=None,
             model_config=model_config,
+            model_instance=model_instance,
         )
         result_text, usage = self._invoke_llm(
             completion_param=model_config.parameters,
@@ -150,19 +152,24 @@ class ReactMultiDatasetRouter:
         :param stop: stop
         :return:
         """
-        invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm(
+        bound_model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
+            tenant_id=tenant_id,
+            provider=model_instance.provider,
+            model_type=ModelType.LLM,
+            model=model_instance.model_name,
+        )
+        invoke_result: Generator[LLMResult, None, None] = bound_model_instance.invoke_llm(
             prompt_messages=prompt_messages,
             model_parameters=completion_param,
             stop=stop,
             stream=True,
-            user=user_id,
         )

         # handle invoke result
         text, usage = self._handle_invoke_result(invoke_result=invoke_result)

         # deduct quota
-        deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
+        deduct_llm_quota(tenant_id=tenant_id, model_instance=bound_model_instance, usage=usage)
         return text, usage
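The router re-fetches the instance through a tenant/user-bound manager, keeping only the provider and model name from the caller's instance, so invocation and `deduct_llm_quota` share one identity. A toy sketch of that rebinding step, with illustrative types and values only:

from dataclasses import dataclass


@dataclass
class Instance:
    # Illustrative type only; the real instances come from core.model_manager.
    provider: str
    model_name: str
    tenant_id: str | None = None
    user_id: str | None = None


def rebind(instance: Instance, tenant_id: str, user_id: str) -> Instance:
    # Keep the model coordinates, refresh the identity binding.
    return Instance(
        provider=instance.provider,
        model_name=instance.model_name,
        tenant_id=tenant_id,
        user_id=user_id,
    )


unbound = Instance(provider="openai", model_name="gpt-4o-mini")
bound = rebind(unbound, tenant_id="t1", user_id="u1")
assert (bound.provider, bound.model_name) == ("openai", "gpt-4o-mini")
assert (bound.tenant_id, bound.user_id) == ("t1", "u1")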