jyong
2025-05-28 17:56:04 +08:00
parent 5fc2bc58a9
commit 7f59ffe7af
32 changed files with 680 additions and 202 deletions

View File

@ -51,7 +51,10 @@ from services.entities.knowledge_entities.knowledge_entities import (
RetrievalModel,
SegmentUpdateArgs,
)
from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeBaseUpdateConfiguration,
RagPipelineDatasetCreateEntity,
)
from services.errors.account import InvalidActionError, NoPermissionError
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
from services.errors.dataset import DatasetNameDuplicateError
@ -59,11 +62,11 @@ from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
from services.rag_pipeline.rag_pipeline_dsl_service import ImportMode, RagPipelineDslService, RagPipelineImportInfo
from services.tag_service import TagService
from services.vector_service import VectorService
from tasks.batch_clean_document_task import batch_clean_document_task
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_index_update_task import deal_dataset_index_update_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
@ -278,47 +281,6 @@ class DatasetService:
db.session.commit()
return dataset
@staticmethod
def create_rag_pipeline_dataset(
tenant_id: str,
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
):
# check if dataset name already exists
if (
db.session.query(Dataset)
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
):
raise DatasetNameDuplicateError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)
dataset = Dataset(
name=rag_pipeline_dataset_create_entity.name,
description=rag_pipeline_dataset_create_entity.description,
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag-pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
)
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
account = cast(Account, current_user)
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=dataset,
)
return {
"id": rag_pipeline_import_info.id,
"dataset_id": dataset.id,
"pipeline_id": rag_pipeline_import_info.pipeline_id,
"status": rag_pipeline_import_info.status,
"imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
"current_dsl_version": rag_pipeline_import_info.current_dsl_version,
"error": rag_pipeline_import_info.error,
}
@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
@ -529,6 +491,130 @@ class DatasetService:
if action:
deal_dataset_vector_index_task.delay(dataset_id, action)
return dataset
@staticmethod
def update_rag_pipeline_dataset_settings(session: Session,
dataset: Dataset,
knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
has_published: bool = False):
if not has_published:
dataset.chunk_structure = knowledge_base_setting.chunk_structure
index_method = knowledge_base_setting.index_method
dataset.indexing_technique = index_method.indexing_technique
if index_method.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=index_method.embedding_setting.embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=index_method.embedding_setting.embedding_model_name,
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
dataset.collection_binding_id = dataset_collection_binding.id
elif index_method.indexing_technique == "economy":
dataset.keyword_number = index_method.economy_setting.keyword_number
else:
raise ValueError("Invalid index method")
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
session.add(dataset)
else:
if dataset.chunk_structure and dataset.chunk_structure != knowledge_base_setting.chunk_structure:
raise ValueError("Chunk structure is not allowed to be updated.")
action = None
if dataset.indexing_technique != knowledge_base_setting.index_method.indexing_technique:
# the indexing technique is being changed
if knowledge_base_setting.index_method.indexing_technique == "economy":
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
elif knowledge_base_setting.index_method.indexing_technique == "high_quality":
action = "add"
# get embedding model setting
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
dataset.collection_binding_id = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
else:
# Normalize both provider names through ModelProviderID so the default plugin id
# is applied consistently when comparing current and requested settings;
# fall back to the existing embedding settings if the new model cannot be resolved
if dataset.indexing_technique == "high_quality":
skip_embedding_update = False
try:
# Handle existing model provider
plugin_model_provider = dataset.embedding_model_provider
plugin_model_provider_str = None
if plugin_model_provider:
plugin_model_provider_str = str(ModelProviderID(plugin_model_provider))
# Handle new model provider from request
new_plugin_model_provider = knowledge_base_setting.index_method.embedding_setting.embedding_provider_name
new_plugin_model_provider_str = None
if new_plugin_model_provider:
new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider))
# Only update the embedding model when the provider or model name differs from the current settings
if (
plugin_model_provider_str != new_plugin_model_provider_str
or knowledge_base_setting.index_method.embedding_setting.embedding_model_name != dataset.embedding_model
):
action = "update"
model_manager = ModelManager()
try:
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
)
except ProviderTokenNotInitError:
# If the embedding model cannot be resolved, keep the existing
# settings and skip the rest of the embedding model update
skip_embedding_update = True
if not skip_embedding_update:
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
)
dataset.collection_binding_id = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
elif dataset.indexing_technique == "economy":
if dataset.keyword_number != knowledge_base_setting.index_method.economy_setting.keyword_number:
dataset.keyword_number = knowledge_base_setting.index_method.economy_setting.keyword_number
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
session.add(dataset)
session.commit()
if action:
deal_dataset_index_update_task.delay(dataset.id, action)
@staticmethod
def delete_dataset(dataset_id, user):

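Taken together, the new update path has two regimes: before the pipeline is published, the settings are written directly onto the dataset; after publication, the chunk structure is frozen and an index change may enqueue deal_dataset_index_update_task. A minimal caller sketch, assuming a Dataset and Pipeline already loaded and a request_payload dict that is purely illustrative:

from sqlalchemy.orm import Session

from extensions.ext_database import db
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration

with Session(db.engine) as session:
    # request_payload is a hypothetical dict parsed from the API request body
    knowledge_base_setting = KnowledgeBaseUpdateConfiguration(**request_payload)
    DatasetService.update_rag_pipeline_dataset_settings(
        session=session,
        dataset=dataset,                      # Dataset fetched earlier in the request
        knowledge_base_setting=knowledge_base_setting,
        has_published=pipeline.is_published,  # Pipeline assumed to be in scope
    )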
View File

@ -4,29 +4,12 @@ from typing import Optional
from flask_login import current_user
from constants import HIDDEN_VALUE
from core import datasource
from core.datasource.__base import datasource_provider
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.helper import encrypter
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
from core.model_runtime.entities.provider_entities import FormType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.impl.datasource import PluginDatasourceManager
from core.provider_manager import ProviderManager
from extensions.ext_database import db
from models.oauth import DatasourceProvider
from models.provider import ProviderType
from services.entities.model_provider_entities import (
CustomConfigurationResponse,
CustomConfigurationStatus,
DefaultModelResponse,
ModelWithProviderEntityResponse,
ProviderResponse,
ProviderWithModelsResponse,
SimpleProviderEntityResponse,
SystemConfigurationResponse,
)
from extensions.ext_database import db
logger = logging.getLogger(__name__)
@ -115,16 +98,26 @@ class DatasourceProviderService:
:param tenant_id: workspace id
:param provider: provider name
:param datasource_name: datasource name
:param plugin_id: plugin id
:return:
"""
# Get all provider configurations of the current workspace
datasource_provider: DatasourceProvider | None = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id).first()
if not datasource_provider:
return None
encrypted_credentials = datasource_provider.encrypted_credentials
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
# Obfuscate provider credentials
copy_credentials = encrypted_credentials.copy()
for key, value in copy_credentials.items():
if key in credential_secret_variables:
copy_credentials[key] = encrypter.obfuscated_token(value)
return copy_credentials
def remove_datasource_credentials(self,
@ -136,11 +129,9 @@ class DatasourceProviderService:
:param tenant_id: workspace id
:param provider: provider name
:param datasource_name: datasource name
:param plugin_id: plugin id
:return:
"""
# Get all provider configurations of the current workspace
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id).first()

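For context, get_datasource_credentials now masks every field that extract_secret_variables reports as secret before returning. A usage sketch under stated assumptions: the provider and plugin_id values are illustrative, and the exact parameter list is partially elided by the hunk (datasource_name may also be accepted):

service = DatasourceProviderService()
credentials = service.get_datasource_credentials(
    tenant_id=tenant_id,
    provider="notion",              # illustrative provider name
    plugin_id="langgenius/notion",  # illustrative plugin id
)
# Returns None when no provider row exists; otherwise secret keys come back
# obfuscated via encrypter.obfuscated_token, e.g. {"integration_token": "sk-***"}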
View File

@ -111,3 +111,12 @@ class KnowledgeConfiguration(BaseModel):
chunk_structure: str
index_method: IndexMethod
retrieval_setting: RetrievalSetting
class KnowledgeBaseUpdateConfiguration(BaseModel):
"""
Knowledge Base Update Configuration.
"""
index_method: IndexMethod
chunk_structure: str
retrieval_setting: RetrievalSetting

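A hedged construction example for the new model. IndexMethod's nested setting classes are not shown in this hunk, so the EmbeddingSetting class name and the RetrievalSetting fields below are assumptions inferred from the attribute access elsewhere in this commit:

knowledge_base_setting = KnowledgeBaseUpdateConfiguration(
    chunk_structure="text_model",  # illustrative value
    index_method=IndexMethod(
        indexing_technique="high_quality",
        embedding_setting=EmbeddingSetting(    # class name assumed
            embedding_provider_name="openai",  # illustrative
            embedding_model_name="text-embedding-3-small",
        ),
    ),
    retrieval_setting=RetrievalSetting(        # fields not shown in this hunk
        search_method="semantic_search",       # illustrative
        top_k=4,
    ),
)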
View File

@ -69,9 +69,9 @@ class PipelineGenerateService:
@classmethod
def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True):
workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER)
return PipelineGenerator.convert_to_event_stream(
PipelineGenerator().single_loop_generate(
pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)

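The fix routes single-loop debugging through PipelineGenerator instead of the app-level WorkflowAppGenerator, matching the pipeline= keyword its single_loop_generate expects. A call sketch; the node id and args shape are illustrative:

events = PipelineGenerateService.generate_single_loop(
    pipeline=pipeline,
    user=account,
    node_id="loop_node_1",  # illustrative node id
    args={"inputs": {}},    # payload shape is illustrative
    streaming=True,
)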
View File

@ -36,7 +36,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
recommended_pipelines_results = []
for pipeline_built_in_template in pipeline_built_in_templates:
pipeline_model: Pipeline | None = pipeline_built_in_template.pipeline
if not pipeline_model:
continue
recommended_pipeline_result = {
"id": pipeline_built_in_template.id,
@ -48,7 +50,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"privacy_policy": pipeline_built_in_template.privacy_policy,
"position": pipeline_built_in_template.position,
}
dataset: Dataset | None = pipeline_model.dataset
if dataset:
recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure
recommended_pipelines_results.append(recommended_pipeline_result)
@ -72,15 +74,19 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
if not pipeline_template:
return None
# get pipeline detail
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first()
if not pipeline or not pipeline.is_public:
return None
dataset: Dataset | None = pipeline.dataset
if not dataset:
return None
return {
"id": pipeline.id,
"name": pipeline.name,
"icon": pipeline.icon,
"mode": pipeline.mode,
"icon": pipeline_template.icon,
"chunk_structure": dataset.chunk_structure,
"export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline),
}

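With the None checks in place, the detail lookup returns None for a missing, private, or dataset-less pipeline rather than raising. A defensive consumption sketch; the entry-point name is assumed, since the method signature is elided by the hunk:

retrieval = DatabasePipelineTemplateRetrieval()
detail = retrieval.get_pipeline_template_detail(template_id)  # method name assumed
if detail is not None:
    print(detail["name"], detail["chunk_structure"])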
View File

@ -46,7 +46,8 @@ from models.workflow import (
WorkflowRun,
WorkflowType,
)
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, PipelineTemplateInfoEntity
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
@ -261,8 +262,7 @@ class RagPipelineService:
session: Session,
pipeline: Pipeline,
account: Account,
knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
) -> Workflow:
draft_workflow_stmt = select(Workflow).where(
Workflow.tenant_id == pipeline.tenant_id,
@ -282,18 +282,25 @@ class RagPipelineService:
graph=draft_workflow.graph,
features=draft_workflow.features,
created_by=account.id,
environment_variables=draft_workflow.environment_variables,
conversation_variables=draft_workflow.conversation_variables,
rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
marked_name="",
marked_comment="",
)
# stage the new workflow version on the session
session.add(workflow)
# trigger app workflow events TODO
# app_published_workflow_was_updated.send(pipeline, published_workflow=workflow)
# update dataset
dataset = pipeline.dataset
if not dataset:
raise ValueError("Dataset not found")
DatasetService.update_rag_pipeline_dataset_settings(
session=session,
dataset=dataset,
knowledge_base_setting=knowledge_base_setting,
has_published=pipeline.is_published
)
# return new workflow
return workflow

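Publishing now persists the knowledge-base settings in the same transaction as the new workflow version. A call sketch, assuming the enclosing method (its name is elided by the hunk) is the publish entry point on RagPipelineService:

with Session(db.engine) as session:
    workflow = rag_pipeline_service.publish_workflow(  # method name assumed
        session=session,
        pipeline=pipeline,
        account=account,
        knowledge_base_setting=knowledge_base_setting,
    )
    session.commit()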
View File

@ -4,13 +4,14 @@ import logging
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Optional, cast
from urllib.parse import urlparse
from uuid import uuid4
import yaml # type: ignore
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from flask_login import current_user
from packaging import version
from pydantic import BaseModel, Field
from sqlalchemy import select
@ -31,7 +32,10 @@ from factories import variable_factory
from models import Account
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
from models.workflow import Workflow
from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeConfiguration,
RagPipelineDatasetCreateEntity,
)
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.rag_pipeline.rag_pipeline import RagPipelineService
@ -540,9 +544,6 @@ class RagPipelineDslService:
# Update existing pipeline
pipeline.name = pipeline_data.get("name", pipeline.name)
pipeline.description = pipeline_data.get("description", pipeline.description)
pipeline.icon_type = icon_type
pipeline.icon = icon
pipeline.icon_background = pipeline_data.get("icon_background", pipeline.icon_background)
pipeline.updated_by = account.id
else:
if account.current_tenant_id is None:
@ -554,12 +555,6 @@ class RagPipelineDslService:
pipeline.tenant_id = account.current_tenant_id
pipeline.name = pipeline_data.get("name", "")
pipeline.description = pipeline_data.get("description", "")
pipeline.icon_type = icon_type
pipeline.icon = icon
pipeline.icon_background = pipeline_data.get("icon_background", "#FFFFFF")
pipeline.enable_site = True
pipeline.enable_api = True
pipeline.use_icon_as_answer_icon = pipeline_data.get("use_icon_as_answer_icon", False)
pipeline.created_by = account.id
pipeline.updated_by = account.id
@ -674,26 +669,6 @@ class RagPipelineDslService:
)
]
@classmethod
def _append_model_config_export_data(cls, export_data: dict, pipeline: Pipeline) -> None:
"""
Append model config export data
:param export_data: export data
:param pipeline: Pipeline instance
"""
app_model_config = pipeline.app_model_config
if not app_model_config:
raise ValueError("Missing app configuration, please check.")
export_data["model_config"] = app_model_config.to_dict()
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
for d in DependenciesAnalysisService.generate_dependencies(
tenant_id=pipeline.tenant_id, dependencies=dependencies
)
]
@classmethod
def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
"""
@ -863,3 +838,46 @@ class RagPipelineDslService:
return pt.decode()
except Exception:
return None
@staticmethod
def create_rag_pipeline_dataset(
tenant_id: str,
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
):
# check if dataset name already exists
if (
db.session.query(Dataset)
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
):
raise ValueError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)
dataset = Dataset(
name=rag_pipeline_dataset_create_entity.name,
description=rag_pipeline_dataset_create_entity.description,
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag-pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
)
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
account = cast(Account, current_user)
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=dataset,
)
return {
"id": rag_pipeline_import_info.id,
"dataset_id": dataset.id,
"pipeline_id": rag_pipeline_import_info.pipeline_id,
"status": rag_pipeline_import_info.status,
"imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
"current_dsl_version": rag_pipeline_import_info.current_dsl_version,
"error": rag_pipeline_import_info.error,
}
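
Finally, a usage sketch for the relocated create_rag_pipeline_dataset. RagPipelineDatasetCreateEntity's field list is inferred from the attribute access above, and the IconInfo entity name and its fields are hypothetical:

result = RagPipelineDslService.create_rag_pipeline_dataset(
    tenant_id=current_user.current_tenant_id,
    rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
        name="product-docs",  # illustrative
        description="Product documentation knowledge base",
        permission="only_me",  # illustrative permission value
        icon_info=IconInfo(icon="book", icon_type="emoji"),  # entity and fields hypothetical
        yaml_content=dsl_yaml,  # DSL previously exported from a pipeline
    ),
)
# result includes id, dataset_id, pipeline_id, status, the imported and
# current DSL versions, and any import error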