This commit is contained in:
jyong
2025-06-03 19:02:57 +08:00
parent 309fffd1e4
commit 9cdd2cbb27
35 changed files with 229 additions and 300 deletions

View File

@ -1,4 +1,3 @@
from calendar import day_abbr
import copy
import datetime
import json
@ -7,7 +6,7 @@ import random
import time
import uuid
from collections import Counter
from typing import Any, Optional, cast
from typing import Any, Optional
from flask_login import current_user
from sqlalchemy import func, select
@ -282,7 +281,6 @@ class DatasetService:
db.session.commit()
return dataset
@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
@ -494,10 +492,9 @@ class DatasetService:
return dataset
@staticmethod
def update_rag_pipeline_dataset_settings(session: Session,
dataset: Dataset,
knowledge_configuration: KnowledgeConfiguration,
has_published: bool = False):
def update_rag_pipeline_dataset_settings(
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
):
dataset = session.merge(dataset)
if not has_published:
dataset.chunk_structure = knowledge_configuration.chunk_structure
@ -616,7 +613,6 @@ class DatasetService:
if action:
deal_dataset_index_update_task.delay(dataset.id, action)
@staticmethod
def delete_dataset(dataset_id, user):
dataset = DatasetService.get_dataset(dataset_id)

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
from flask_login import current_user
@ -22,11 +21,9 @@ class DatasourceProviderService:
def __init__(self) -> None:
self.provider_manager = PluginDatasourceManager()
def datasource_provider_credentials_validate(self,
tenant_id: str,
provider: str,
plugin_id: str,
credentials: dict) -> None:
def datasource_provider_credentials_validate(
self, tenant_id: str, provider: str, plugin_id: str, credentials: dict
) -> None:
"""
validate datasource provider credentials.
@ -34,29 +31,30 @@ class DatasourceProviderService:
:param provider:
:param credentials:
"""
credential_valid = self.provider_manager.validate_provider_credentials(tenant_id=tenant_id,
user_id=current_user.id,
provider=provider,
credentials=credentials)
credential_valid = self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id, user_id=current_user.id, provider=provider, credentials=credentials
)
if credential_valid:
# Get all provider configurations of the current workspace
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id).first()
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.first()
)
provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id,
provider=provider
)
provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
if not datasource_provider:
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider = DatasourceProvider(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
auth_type="api_key",
encrypted_credentials=credentials)
datasource_provider = DatasourceProvider(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
auth_type="api_key",
encrypted_credentials=credentials,
)
db.session.add(datasource_provider)
db.session.commit()
else:
@ -101,11 +99,15 @@ class DatasourceProviderService:
:return:
"""
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = db.session.query(DatasourceProvider).filter(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id
).all()
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
.filter(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.all()
)
if not datasource_providers:
return []
copy_credentials_list = []
@ -128,10 +130,7 @@ class DatasourceProviderService:
return copy_credentials_list
def remove_datasource_credentials(self,
tenant_id: str,
provider: str,
plugin_id: str) -> None:
def remove_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> None:
"""
remove datasource credentials.
@ -140,9 +139,11 @@ class DatasourceProviderService:
:param plugin_id: plugin id
:return:
"""
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id).first()
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.first()
)
if datasource_provider:
db.session.delete(datasource_provider)
db.session.commit()

View File

@ -107,6 +107,7 @@ class KnowledgeConfiguration(BaseModel):
"""
Knowledge Base Configuration.
"""
chunk_structure: str
indexing_technique: Literal["high_quality", "economy"]
embedding_model_provider: Optional[str] = ""

View File

@ -3,7 +3,6 @@ from typing import Any, Union
from configs import dify_config
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from models.dataset import Pipeline
from models.model import Account, App, EndUser

View File

@ -1,13 +1,12 @@
from typing import Optional
from flask_login import current_user
import yaml
from flask_login import current_user
from extensions.ext_database import db
from models.dataset import PipelineCustomizedTemplate
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
@ -43,7 +42,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
)
recommended_pipelines_results = []
for pipeline_customized_template in pipeline_customized_templates:
recommended_pipeline_result = {
"id": pipeline_customized_template.id,
"name": pipeline_customized_template.name,
@ -56,7 +54,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return {"pipeline_templates": recommended_pipelines_results}
@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
"""

View File

@ -38,7 +38,6 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
recommended_pipelines_results = []
for pipeline_built_in_template in pipeline_built_in_templates:
recommended_pipeline_result = {
"id": pipeline_built_in_template.id,
"name": pipeline_built_in_template.name,

View File

@ -35,7 +35,7 @@ from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
@ -57,9 +57,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import Pi
class RagPipelineService:
@classmethod
def get_pipeline_templates(
cls, type: str = "built-in", language: str = "en-US"
) -> dict:
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
@ -308,7 +306,7 @@ class RagPipelineService:
session=session,
dataset=dataset,
knowledge_configuration=knowledge_configuration,
has_published=pipeline.is_published
has_published=pipeline.is_published,
)
# return new workflow
return workflow
@ -444,12 +442,10 @@ class RagPipelineService:
)
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: GetOnlineDocumentPagesResponse = (
datasource_runtime._get_online_document_pages(
user_id=account.id,
datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(),
)
online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
user_id=account.id,
datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(),
)
return {
"result": [page.model_dump() for page in online_document_result.result],
@ -470,7 +466,6 @@ class RagPipelineService:
else:
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:
@ -689,8 +684,8 @@ class RagPipelineService:
WorkflowRun.app_id == pipeline.id,
or_(
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value
)
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
),
)
if args.get("last_id"):
@ -763,18 +758,17 @@ class RagPipelineService:
# Use the repository to get the node execution
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=db.engine,
app_id=pipeline.id,
user=user,
triggered_from=None
session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
)
# Use the repository to get the node executions with ordering
order_config = OrderConfig(order_by=["index"], order_direction="desc")
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id,
order_config=order_config,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN)
# Convert domain models to database models
node_executions = repository.get_by_workflow_run(
workflow_run_id=run_id,
order_config=order_config,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
# Convert domain models to database models
workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]
return workflow_node_executions

View File

@ -279,7 +279,11 @@ class RagPipelineDslService:
if node.get("data", {}).get("type") == "knowledge_index":
knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure:
if (
dataset
and pipeline.is_published
and dataset.chunk_structure != knowledge_configuration.chunk_structure
):
raise ValueError("Chunk structure is not compatible with the published pipeline")
else:
dataset = Dataset(
@ -304,8 +308,7 @@ class RagPipelineDslService:
.filter(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name
== knowledge_configuration.embedding_model,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)
@ -323,12 +326,8 @@ class RagPipelineDslService:
db.session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = (
knowledge_configuration.embedding_model
)
dataset.embedding_model_provider = (
knowledge_configuration.embedding_model_provider
)
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.pipeline_id = pipeline.id
@ -443,8 +442,7 @@ class RagPipelineDslService:
.filter(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name
== knowledge_configuration.embedding_model,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)
@ -462,12 +460,8 @@ class RagPipelineDslService:
db.session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = (
knowledge_configuration.embedding_model
)
dataset.embedding_model_provider = (
knowledge_configuration.embedding_model_provider
)
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.pipeline_id = pipeline.id
@ -538,7 +532,6 @@ class RagPipelineDslService:
icon_type = "emoji"
icon = str(pipeline_data.get("icon", ""))
# Initialize pipeline based on mode
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
@ -554,7 +547,6 @@ class RagPipelineDslService:
]
rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
@ -576,7 +568,6 @@ class RagPipelineDslService:
pipeline.description = pipeline_data.get("description", pipeline.description)
pipeline.updated_by = account.id
else:
if account.current_tenant_id is None:
raise ValueError("Current tenant is not set")
@ -636,7 +627,6 @@ class RagPipelineDslService:
# commit db session changes
db.session.commit()
return pipeline
@classmethod
@ -874,7 +864,6 @@ class RagPipelineDslService:
except Exception:
return None
@staticmethod
def create_rag_pipeline_dataset(
tenant_id: str,
@ -886,9 +875,7 @@ class RagPipelineDslService:
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
):
raise ValueError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)
raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)

View File

@ -12,12 +12,12 @@ class RagPipelineManageService:
# get all builtin providers
manager = PluginDatasourceManager()
datasources = manager.fetch_datasource_providers(tenant_id)
datasources = manager.fetch_datasource_providers(tenant_id)
for datasource in datasources:
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_datasource_credentials(tenant_id=tenant_id,
provider=datasource.provider,
plugin_id=datasource.plugin_id)
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
if credentials:
datasource.is_authorized = True
return datasources