fix style check

This commit is contained in:
jyong
2025-09-17 17:34:05 +08:00
parent f963eb525c
commit 69a402ba99
28 changed files with 85 additions and 371 deletions

View File

@ -288,7 +288,8 @@ class DatasetService:
names,
"Untitled",
)
if not current_user or not current_user.id:
raise ValueError("Current user or current user id not found")
pipeline = Pipeline(
tenant_id=tenant_id,
name=rag_pipeline_dataset_create_entity.name,
@ -814,6 +815,8 @@ class DatasetService:
def update_rag_pipeline_dataset_settings(
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
):
if not current_user or not current_user.current_tenant_id:
raise ValueError("Current user or current tenant not found")
dataset = session.merge(dataset)
if not has_published:
dataset.chunk_structure = knowledge_configuration.chunk_structure
@ -821,7 +824,7 @@ class DatasetService:
if knowledge_configuration.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_user.current_tenant_id, # ignore type error
provider=knowledge_configuration.embedding_model_provider or "",
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_configuration.embedding_model or "",
@ -895,6 +898,7 @@ class DatasetService:
):
action = "update"
model_manager = ModelManager()
embedding_model = None
try:
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
@ -908,14 +912,15 @@ class DatasetService:
# Skip the rest of the embedding model update
skip_embedding_update = True
if not skip_embedding_update:
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
if embedding_model:
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
)
)
dataset.collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
@ -1014,6 +1019,8 @@ class DatasetService:
if dataset is None:
raise NotFound("Dataset not found.")
dataset.enable_api = status
if not current_user or not current_user.id:
raise ValueError("Current user or current user id not found")
dataset.updated_by = current_user.id
dataset.updated_at = naive_utc_now()
db.session.commit()
@ -1350,6 +1357,8 @@ class DocumentService:
redis_client.setex(retry_indexing_cache_key, 600, 1)
# trigger async task
document_ids = [document.id for document in documents]
if not current_user or not current_user.id:
raise ValueError("Current user or current user id not found")
retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id)
@staticmethod

View File

@ -1,5 +1,6 @@
import logging
import time
from collections.abc import Mapping
from typing import Any
from flask_login import current_user
@ -68,11 +69,13 @@ class DatasourceProviderService:
tenant_id: str,
provider: str,
plugin_id: str,
raw_credentials: dict[str, Any],
raw_credentials: Mapping[str, Any],
datasource_provider: DatasourceProvider,
) -> dict[str, Any]:
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}", credential_type=datasource_provider.auth_type
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)
encrypted_credentials = raw_credentials.copy()
for key, value in encrypted_credentials.items():

View File

@ -241,6 +241,9 @@ class MessageService:
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
if not app_config.additional_features:
raise ValueError("Additional features not found")
if not app_config.additional_features.suggested_questions_after_answer:
raise SuggestedQuestionsAfterAnswerDisabledError()

View File

@ -828,10 +828,10 @@ class RagPipelineService:
)
error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e:
node_instance = e._node
node_instance = e._node # type: ignore
run_succeeded = False
node_run_result = None
error = e._error
error = e._error # type: ignore
workflow_node_execution = WorkflowNodeExecution(
id=str(uuid4()),
@ -1253,7 +1253,7 @@ class RagPipelineService:
repository.save(workflow_node_execution)
# Convert node_execution to WorkflowNodeExecution after save
workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution)
workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution) # type: ignore
with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver(

View File

@ -47,6 +47,8 @@ class RagPipelineTransformService:
self._deal_dependencies(pipeline_yaml, dataset.tenant_id)
# Extract app data
workflow_data = pipeline_yaml.get("workflow")
if not workflow_data:
raise ValueError("Missing workflow data for rag pipeline")
graph = workflow_data.get("graph", {})
nodes = graph.get("nodes", [])
new_nodes = []
@ -252,7 +254,7 @@ class RagPipelineTransformService:
plugin_unique_identifier = dependency.get("value", {}).get("plugin_unique_identifier")
plugin_id = plugin_unique_identifier.split(":")[0]
if plugin_id not in installed_plugins_ids:
plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(plugin_id)
plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(plugin_id) # type: ignore
if plugin_unique_identifier:
need_install_plugin_unique_identifiers.append(plugin_unique_identifier)
if need_install_plugin_unique_identifiers:

View File

@ -1,7 +1,6 @@
import dataclasses
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, overload
from typing import Any, Generic, TypeAlias, TypeVar, overload
from configs import dify_config
from core.file.models import File
@ -39,30 +38,6 @@ class _PCKeys:
CHILD_CONTENTS = "child_contents"
class _QAStructureItem(TypedDict):
question: str
answer: str
class _QAStructure(TypedDict):
qa_chunks: list[_QAStructureItem]
class _ParentChildChunkItem(TypedDict):
parent_content: str
child_contents: list[str]
class _ParentChildStructure(TypedDict):
parent_mode: str
parent_child_chunks: list[_ParentChildChunkItem]
class _SpecialChunkType(StrEnum):
parent_child = "parent_child"
qa = "qa"
_T = TypeVar("_T")
@ -392,7 +367,7 @@ class VariableTruncator:
def _truncate_json_primitives(self, val: dict, target_size: int) -> _PartResult[dict]: ...
@overload
def _truncate_json_primitives(self, val: bool, target_size: int) -> _PartResult[bool]: ...
def _truncate_json_primitives(self, val: bool, target_size: int) -> _PartResult[bool]: ... # type: ignore
@overload
def _truncate_json_primitives(self, val: int, target_size: int) -> _PartResult[int]: ...

View File

@ -146,7 +146,7 @@ class WorkflowConverter:
graph=graph,
model_config=app_config.model,
prompt_template=app_config.prompt_template,
file_upload=app_config.additional_features.file_upload,
file_upload=app_config.additional_features.file_upload if app_config.additional_features else None,
external_data_variable_node_mapping=external_data_variable_node_mapping,
)

View File

@ -430,6 +430,10 @@ class WorkflowDraftVariableService:
.where(WorkflowDraftVariable.id == variable.id)
)
variable_reloaded = self._session.execute(variable_query).scalars().first()
if variable_reloaded is None:
logger.warning("Associated WorkflowDraftVariable not found, draft_var_id=%s", variable.id)
self._session.delete(variable)
return
variable_file = variable_reloaded.variable_file
if variable_file is None:
logger.warning(

View File

@ -811,7 +811,7 @@ class WorkflowService:
return node, node_run_result, run_succeeded, error
except WorkflowNodeRunFailedError as e:
return e._node, None, False, e._error
return e._node, None, False, e._error # type: ignore
def _apply_error_strategy(self, node: Node, node_run_result: NodeRunResult) -> NodeRunResult:
"""Apply error strategy when node execution fails."""