Mirror of https://github.com/langgenius/dify.git
fix style check
@@ -288,7 +288,8 @@ class DatasetService:
                 names,
                 "Untitled",
             )
-
+        if not current_user or not current_user.id:
+            raise ValueError("Current user or current user id not found")
         pipeline = Pipeline(
             tenant_id=tenant_id,
             name=rag_pipeline_dataset_create_entity.name,
@@ -814,6 +815,8 @@ class DatasetService:
     def update_rag_pipeline_dataset_settings(
         session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
     ):
+        if not current_user or not current_user.current_tenant_id:
+            raise ValueError("Current user or current tenant not found")
         dataset = session.merge(dataset)
         if not has_published:
             dataset.chunk_structure = knowledge_configuration.chunk_structure
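
Review note: the guards added throughout this diff follow one pattern — flask_login's current_user is a request-bound proxy that neither the runtime nor a type checker can prove to be an authenticated account with a tenant, so the code raises early before touching its attributes. A minimal sketch of the pattern (the helper name is illustrative, not part of this commit):

    from flask_login import current_user

    def require_current_tenant_id() -> str:
        # Fail fast if no authenticated user is bound to the request, or the
        # user has no tenant; after this check the attribute access is safe
        # for both the runtime and the type checker.
        if not current_user or not current_user.current_tenant_id:
            raise ValueError("Current user or current tenant not found")
        return current_user.current_tenant_id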
@@ -821,7 +824,7 @@ class DatasetService:
             if knowledge_configuration.indexing_technique == "high_quality":
                 model_manager = ModelManager()
                 embedding_model = model_manager.get_model_instance(
-                    tenant_id=current_user.current_tenant_id,
+                    tenant_id=current_user.current_tenant_id,  # ignore type error
                     provider=knowledge_configuration.embedding_model_provider or "",
                     model_type=ModelType.TEXT_EMBEDDING,
                     model=knowledge_configuration.embedding_model or "",
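
Review note: "# ignore type error" is a plain prose comment and has no effect on mypy; only the literal "# type: ignore" pragma (optionally with an error code) suppresses a diagnostic. A two-line illustration:

    # mypy still reports an error on this line; the comment is not a pragma:
    x: int = "oops"  # ignore type error

    # the literal pragma is what actually suppresses the diagnostic:
    y: int = "oops"  # type: ignore[assignment]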
@@ -895,6 +898,7 @@ class DatasetService:
         ):
             action = "update"
             model_manager = ModelManager()
+            embedding_model = None
             try:
                 embedding_model = model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
@@ -908,14 +912,15 @@ class DatasetService:
                         # Skip the rest of the embedding model update
                         skip_embedding_update = True
                     if not skip_embedding_update:
-                        dataset.embedding_model = embedding_model.model
-                        dataset.embedding_model_provider = embedding_model.provider
-                        dataset_collection_binding = (
-                            DatasetCollectionBindingService.get_dataset_collection_binding(
-                                embedding_model.provider, embedding_model.model
-                            )
-                        )
-                        dataset.collection_binding_id = dataset_collection_binding.id
+                        if embedding_model:
+                            dataset.embedding_model = embedding_model.model
+                            dataset.embedding_model_provider = embedding_model.provider
+                            dataset_collection_binding = (
+                                DatasetCollectionBindingService.get_dataset_collection_binding(
+                                    embedding_model.provider, embedding_model.model
+                                )
+                            )
+                            dataset.collection_binding_id = dataset_collection_binding.id
                 except LLMBadRequestError:
                     raise ValueError(
                         "No Embedding Model available. Please configure a valid provider "
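
The "embedding_model = None" initialization in the previous hunk and the "if embedding_model:" wrapper here work together: the variable stays None when get_model_instance raises, so the guard narrows the optional before any attribute access. A self-contained sketch of the same shape, with illustrative names:

    from typing import Optional

    def fetch_model(name: str) -> str:
        if not name:
            raise LookupError("no model configured")
        return name

    model: Optional[str] = None
    try:
        model = fetch_model("text-embedding-3-small")
    except LookupError:
        pass  # fall through with model still None
    if model:
        # Optional[str] is narrowed to str here, so .upper() type-checks.
        print(model.upper())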
@@ -1014,6 +1019,8 @@ class DatasetService:
         if dataset is None:
             raise NotFound("Dataset not found.")
         dataset.enable_api = status
+        if not current_user or not current_user.id:
+            raise ValueError("Current user or current user id not found")
         dataset.updated_by = current_user.id
         dataset.updated_at = naive_utc_now()
         db.session.commit()
@@ -1350,6 +1357,8 @@ class DocumentService:
             redis_client.setex(retry_indexing_cache_key, 600, 1)
         # trigger async task
         document_ids = [document.id for document in documents]
+        if not current_user or not current_user.id:
+            raise ValueError("Current user or current user id not found")
         retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id)

     @staticmethod
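
The dispatch above follows the usual Celery shape: validate the actor, collect plain primary keys, then hand only serializable ids to .delay(), since the task runs in another process. A minimal sketch of a matching task signature (illustrative, not the repo's actual task body):

    from celery import shared_task

    @shared_task
    def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_id: str) -> None:
        # Re-resolve ORM objects inside the worker process by id;
        # ORM instances themselves are not safely serializable.
        for document_id in document_ids:
            ...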
@@ -1,5 +1,6 @@
 import logging
 import time
+from collections.abc import Mapping
 from typing import Any

 from flask_login import current_user
@@ -68,11 +69,13 @@ class DatasourceProviderService:
         tenant_id: str,
         provider: str,
         plugin_id: str,
-        raw_credentials: dict[str, Any],
+        raw_credentials: Mapping[str, Any],
         datasource_provider: DatasourceProvider,
     ) -> dict[str, Any]:
         provider_credential_secret_variables = self.extract_secret_variables(
-            tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}", credential_type=datasource_provider.auth_type
+            tenant_id=tenant_id,
+            provider_id=f"{plugin_id}/{provider}",
+            credential_type=CredentialType.of(datasource_provider.auth_type),
         )
         encrypted_credentials = raw_credentials.copy()
         for key, value in encrypted_credentials.items():
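
Widening raw_credentials from dict to Mapping signals that the function only reads its argument, so callers may pass any dict-like object. One caveat: Mapping does not itself promise a .copy() method, so dict(raw_credentials) would be the defensive spelling. A sketch of the pattern, with a stand-in for the real encryption step:

    from collections.abc import Mapping
    from typing import Any

    def encrypt_credentials(raw_credentials: Mapping[str, Any]) -> dict[str, Any]:
        encrypted = dict(raw_credentials)  # local, mutable copy
        for key, value in encrypted.items():
            if isinstance(value, str):
                encrypted[key] = "<encrypted>"  # stand-in for real encryption
        return encrypted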
@@ -241,6 +241,9 @@ class MessageService:

         app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

+        if not app_config.additional_features:
+            raise ValueError("Additional features not found")
+
         if not app_config.additional_features.suggested_questions_after_answer:
             raise SuggestedQuestionsAfterAnswerDisabledError()

@@ -828,10 +828,10 @@ class RagPipelineService:
             )
             error = node_run_result.error if not run_succeeded else None
         except WorkflowNodeRunFailedError as e:
-            node_instance = e._node
+            node_instance = e._node  # type: ignore
             run_succeeded = False
             node_run_result = None
-            error = e._error
+            error = e._error  # type: ignore

         workflow_node_execution = WorkflowNodeExecution(
             id=str(uuid4()),
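
The "# type: ignore" comments here paper over access to the exception's private _node and _error attributes. A common alternative, when the exception class can be changed, is to expose those fields as typed public attributes so no suppression is needed; a hypothetical reshaping, not what this commit does:

    class WorkflowNodeRunFailedError(Exception):
        # "Node" is the project's workflow node type, referenced here only
        # as a forward reference for the sake of a self-contained sketch.
        def __init__(self, node: "Node", error: str) -> None:
            super().__init__(error)
            self.node = node    # public and typed: e.node needs no ignore
            self.error = error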
@@ -1253,7 +1253,7 @@ class RagPipelineService:
         repository.save(workflow_node_execution)

         # Convert node_execution to WorkflowNodeExecution after save
-        workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution)
+        workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution)  # type: ignore

         with Session(bind=db.engine) as session, session.begin():
             draft_var_saver = DraftVariableSaver(
@@ -47,6 +47,8 @@ class RagPipelineTransformService:
         self._deal_dependencies(pipeline_yaml, dataset.tenant_id)
         # Extract app data
         workflow_data = pipeline_yaml.get("workflow")
+        if not workflow_data:
+            raise ValueError("Missing workflow data for rag pipeline")
         graph = workflow_data.get("graph", {})
         nodes = graph.get("nodes", [])
         new_nodes = []
@@ -252,7 +254,7 @@ class RagPipelineTransformService:
                 plugin_unique_identifier = dependency.get("value", {}).get("plugin_unique_identifier")
                 plugin_id = plugin_unique_identifier.split(":")[0]
                 if plugin_id not in installed_plugins_ids:
-                    plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(plugin_id)
+                    plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(plugin_id)  # type: ignore
                     if plugin_unique_identifier:
                         need_install_plugin_unique_identifiers.append(plugin_unique_identifier)
         if need_install_plugin_unique_identifiers:
@@ -1,7 +1,6 @@
 import dataclasses
 from collections.abc import Mapping
-from enum import StrEnum
-from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, overload
+from typing import Any, Generic, TypeAlias, TypeVar, overload

 from configs import dify_config
 from core.file.models import File
@@ -39,30 +38,6 @@ class _PCKeys:
     CHILD_CONTENTS = "child_contents"


-class _QAStructureItem(TypedDict):
-    question: str
-    answer: str
-
-
-class _QAStructure(TypedDict):
-    qa_chunks: list[_QAStructureItem]
-
-
-class _ParentChildChunkItem(TypedDict):
-    parent_content: str
-    child_contents: list[str]
-
-
-class _ParentChildStructure(TypedDict):
-    parent_mode: str
-    parent_child_chunks: list[_ParentChildChunkItem]
-
-
-class _SpecialChunkType(StrEnum):
-    parent_child = "parent_child"
-    qa = "qa"
-
-
 _T = TypeVar("_T")


@@ -392,7 +367,7 @@ class VariableTruncator:
     def _truncate_json_primitives(self, val: dict, target_size: int) -> _PartResult[dict]: ...

     @overload
-    def _truncate_json_primitives(self, val: bool, target_size: int) -> _PartResult[bool]: ...
+    def _truncate_json_primitives(self, val: bool, target_size: int) -> _PartResult[bool]: ...  # type: ignore

     @overload
     def _truncate_json_primitives(self, val: int, target_size: int) -> _PartResult[int]: ...
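
The suppression on the bool overload is most likely about overload overlap: bool is a subclass of int, and if _PartResult is invariant, _PartResult[bool] and _PartResult[int] are incompatible return types, which mypy reports as overlapping overload signatures. A self-contained reproduction of that shape:

    from typing import Generic, TypeVar, overload

    T = TypeVar("T")

    class Box(Generic[T]):  # invariant container, standing in for _PartResult
        def __init__(self, value: T) -> None:
            self.value = value

    @overload
    # bool is a subclass of int, so this signature overlaps the int one
    # below with an incompatible return type; hence the suppression.
    def wrap(val: bool) -> Box[bool]: ...  # type: ignore
    @overload
    def wrap(val: int) -> Box[int]: ...
    def wrap(val):
        return Box(val)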
@@ -146,7 +146,7 @@ class WorkflowConverter:
             graph=graph,
             model_config=app_config.model,
             prompt_template=app_config.prompt_template,
-            file_upload=app_config.additional_features.file_upload,
+            file_upload=app_config.additional_features.file_upload if app_config.additional_features else None,
             external_data_variable_node_mapping=external_data_variable_node_mapping,
         )

@@ -430,6 +430,10 @@ class WorkflowDraftVariableService:
             .where(WorkflowDraftVariable.id == variable.id)
         )
         variable_reloaded = self._session.execute(variable_query).scalars().first()
+        if variable_reloaded is None:
+            logger.warning("Associated WorkflowDraftVariable not found, draft_var_id=%s", variable.id)
+            self._session.delete(variable)
+            return
         variable_file = variable_reloaded.variable_file
         if variable_file is None:
             logger.warning(
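
The added branch handles a dangling reference: the row may have been deleted between loading variable and re-querying it, so the code logs, removes the stale object, and bails out instead of dereferencing None. The same guard shape in isolation (the session and model names follow the diff; the helper itself is illustrative):

    import logging

    from sqlalchemy import select

    logger = logging.getLogger(__name__)

    def reload_variable(session, variable):
        # WorkflowDraftVariable is the repo's ORM model, assumed imported.
        stmt = select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)
        reloaded = session.execute(stmt).scalars().first()
        if reloaded is None:
            # The backing row vanished; drop the stale object and stop early.
            logger.warning("Associated WorkflowDraftVariable not found, draft_var_id=%s", variable.id)
            session.delete(variable)
            return None
        return reloaded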
@@ -811,7 +811,7 @@ class WorkflowService:
             return node, node_run_result, run_succeeded, error

         except WorkflowNodeRunFailedError as e:
-            return e._node, None, False, e._error
+            return e._node, None, False, e._error  # type: ignore

     def _apply_error_strategy(self, node: Node, node_run_result: NodeRunResult) -> NodeRunResult:
         """Apply error strategy when node execution fails."""