mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
refactor(api): continue decoupling dify_graph from API concerns (#33580)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
@ -56,6 +56,7 @@ from core.rag.retrieval.template_prompts import (
|
||||
)
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.workflow.file_reference import build_file_reference
|
||||
from core.workflow.nodes.knowledge_retrieval import exc
|
||||
from core.workflow.nodes.knowledge_retrieval.retrieval import (
|
||||
KnowledgeRetrievalRequest,
|
||||
@ -160,7 +161,7 @@ class DatasetRetrieval:
|
||||
if request.model_provider is None or request.model_name is None or request.query is None:
|
||||
raise ValueError("model_provider, model_name, and query are required for single retrieval mode")
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id, user_id=request.user_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=request.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
@ -383,23 +384,27 @@ class DatasetRetrieval:
|
||||
return None, []
|
||||
retrieve_config = config.retrieve_config
|
||||
|
||||
# check model is support tool calling
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
|
||||
)
|
||||
model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
|
||||
# get model schema
|
||||
# Reuse the caller-bound model instance for both schema resolution and
|
||||
# downstream planner/invoke calls so a single request never mixes
|
||||
# tenant-scope and request-bound runtimes.
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model=model_config.model, credentials=model_config.credentials
|
||||
model=model_instance.model_name,
|
||||
credentials=model_instance.credentials,
|
||||
)
|
||||
|
||||
if not model_schema:
|
||||
return None, []
|
||||
|
||||
model_config.provider_model_bundle = model_instance.provider_model_bundle
|
||||
model_config.credentials = model_instance.credentials
|
||||
model_config.model_schema = model_schema
|
||||
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
features = model_schema.features
|
||||
if features:
|
||||
@ -517,11 +522,12 @@ class DatasetRetrieval:
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=segment.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
related_id=upload_file.id,
|
||||
reference=build_file_reference(
|
||||
record_id=str(upload_file.id),
|
||||
),
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
url=sign_upload_file(upload_file.id, upload_file.extension),
|
||||
@ -986,6 +992,24 @@ class DatasetRetrieval:
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_creator_user_role(user_from: str) -> CreatorUserRole | None:
|
||||
"""Map runtime user source values to dataset query audit roles.
|
||||
|
||||
Workflow run context uses the hyphenated ``end-user`` value, while
|
||||
``DatasetQuery.created_by_role`` persists the underscore-based
|
||||
``CreatorUserRole.END_USER`` enum. Query logging is a side effect, so an
|
||||
unsupported value should be skipped instead of aborting retrieval.
|
||||
"""
|
||||
normalized_user_from = str(user_from).strip().lower().replace("-", "_")
|
||||
if normalized_user_from == CreatorUserRole.ACCOUNT.value:
|
||||
return CreatorUserRole.ACCOUNT
|
||||
if normalized_user_from == CreatorUserRole.END_USER.value:
|
||||
return CreatorUserRole.END_USER
|
||||
|
||||
logger.warning("Skipping dataset query audit log for unsupported user_from=%r", user_from)
|
||||
return None
|
||||
|
||||
def _on_query(
|
||||
self,
|
||||
query: str | None,
|
||||
@ -996,10 +1020,13 @@ class DatasetRetrieval:
|
||||
user_id: str,
|
||||
):
|
||||
"""
|
||||
Handle query.
|
||||
Persist dataset query audit rows for retrieval requests.
|
||||
"""
|
||||
if not query and not attachment_ids:
|
||||
return
|
||||
created_by_role = self._resolve_creator_user_role(user_from)
|
||||
if created_by_role is None:
|
||||
return
|
||||
dataset_queries = []
|
||||
for dataset_id in dataset_ids:
|
||||
contents = []
|
||||
@ -1014,7 +1041,7 @@ class DatasetRetrieval:
|
||||
content=json.dumps(contents),
|
||||
source=DatasetQuerySource.APP,
|
||||
source_app_id=app_id,
|
||||
created_by_role=CreatorUserRole(user_from),
|
||||
created_by_role=created_by_role,
|
||||
created_by=user_id,
|
||||
)
|
||||
dataset_queries.append(dataset_query)
|
||||
@ -1411,7 +1438,7 @@ class DatasetRetrieval:
|
||||
raise ValueError("metadata_model_config is required")
|
||||
# get metadata model instance
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
|
||||
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config, user_id=user_id)
|
||||
|
||||
# fetch prompt messages
|
||||
prompt_messages, stop = self._get_prompt_template(
|
||||
@ -1430,7 +1457,6 @@ class DatasetRetrieval:
|
||||
model_parameters=model_config.parameters,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=user_id,
|
||||
),
|
||||
)
|
||||
|
||||
@ -1533,7 +1559,7 @@ class DatasetRetrieval:
|
||||
return filters
|
||||
|
||||
def _fetch_model_config(
|
||||
self, tenant_id: str, model: ModelConfig
|
||||
self, tenant_id: str, model: ModelConfig, user_id: str | None = None
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
"""
|
||||
Fetch model config
|
||||
@ -1543,7 +1569,7 @@ class DatasetRetrieval:
|
||||
model_name = model.name
|
||||
provider_name = model.provider
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
|
||||
)
|
||||
|
||||
@ -3,13 +3,14 @@ from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.rag.retrieval.output_parser.react_output import ReactAction
|
||||
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
|
||||
|
||||
@ -119,6 +120,7 @@ class ReactMultiDatasetRouter:
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
result_text, usage = self._invoke_llm(
|
||||
completion_param=model_config.parameters,
|
||||
@ -150,19 +152,24 @@ class ReactMultiDatasetRouter:
|
||||
:param stop: stop
|
||||
:return:
|
||||
"""
|
||||
invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm(
|
||||
bound_model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=model_instance.provider,
|
||||
model_type=ModelType.LLM,
|
||||
model=model_instance.model_name,
|
||||
)
|
||||
invoke_result: Generator[LLMResult, None, None] = bound_model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=completion_param,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
# deduct quota
|
||||
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
deduct_llm_quota(tenant_id=tenant_id, model_instance=bound_model_instance, usage=usage)
|
||||
|
||||
return text, usage
|
||||
|
||||
|
||||
Reference in New Issue
Block a user