fix(rag-pipeline-dsl): fix session handling error in DSL import

This commit is contained in:
Novice
2025-09-04 11:03:26 +08:00
parent c40dfe7060
commit 68f4d4b97c
5 changed files with 77 additions and 86 deletions

View File

@ -352,9 +352,10 @@ class RagPipelineService:
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
# update dataset
dataset = pipeline.dataset
if not dataset:
raise ValueError("Dataset not found")
with Session(db.engine) as session:
dataset = pipeline.retrieve_dataset(session=session)
if not dataset:
raise ValueError("Dataset not found")
DatasetService.update_rag_pipeline_dataset_settings(
session=session,
dataset=dataset,
@ -1110,9 +1111,10 @@ class RagPipelineService:
workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError("Workflow not found")
dataset = pipeline.dataset
if not dataset:
raise ValueError("Dataset not found")
with Session(db.engine) as session:
dataset = pipeline.retrieve_dataset(session=session)
if not dataset:
raise ValueError("Dataset not found")
# check template name is exist
template_name = args.get("name")
@ -1136,7 +1138,9 @@ class RagPipelineService:
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
pipeline_customized_template = PipelineCustomizedTemplate(
name=args.get("name"),

View File

@ -30,7 +30,6 @@ from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.tool.entities import ToolNodeData
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import variable_factory
from models import Account
@ -235,10 +234,7 @@ class RagPipelineDslService:
status=ImportStatus.FAILED,
error="Pipeline not found",
)
dataset = pipeline.dataset
if dataset:
self._session.merge(dataset)
dataset_name = dataset.name
dataset = pipeline.retrieve_dataset(session=self._session)
# If major version mismatch, store import info in Redis
if status == ImportStatus.PENDING:
@ -300,7 +296,7 @@ class RagPipelineDslService:
):
raise ValueError("Chunk structure is not compatible with the published pipeline")
if not dataset:
datasets = db.session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
datasets = self._session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
names = [dataset.name for dataset in datasets]
generate_name = generate_incremental_name(names, name)
dataset = Dataset(
@ -321,7 +317,7 @@ class RagPipelineDslService:
)
if knowledge_configuration.indexing_technique == "high_quality":
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
self._session.query(DatasetCollectionBinding)
.filter(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
@ -339,8 +335,8 @@ class RagPipelineDslService:
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
)
db.session.add(dataset_collection_binding)
db.session.commit()
self._session.add(dataset_collection_binding)
self._session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
@ -454,7 +450,7 @@ class RagPipelineDslService:
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.indexing_technique == "high_quality":
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
self._session.query(DatasetCollectionBinding)
.filter(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
@ -472,8 +468,8 @@ class RagPipelineDslService:
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
)
db.session.add(dataset_collection_binding)
db.session.commit()
self._session.add(dataset_collection_binding)
self._session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
@ -538,18 +534,10 @@ class RagPipelineDslService:
account: Account,
dependencies: Optional[list[PluginDependency]] = None,
) -> Pipeline:
"""Create a new app or update an existing one."""
if not account.current_tenant_id:
raise ValueError("Tenant id is required")
"""Create a new app or update an existing one."""
pipeline_data = data.get("rag_pipeline", {})
# Set icon type
icon_type_value = pipeline_data.get("icon_type")
if icon_type_value in ["emoji", "link"]:
icon_type = icon_type_value
else:
icon_type = "emoji"
icon = str(pipeline_data.get("icon", ""))
# Initialize pipeline based on mode
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
@ -609,7 +597,7 @@ class RagPipelineDslService:
CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
)
workflow = (
db.session.query(Workflow)
self._session.query(Workflow)
.filter(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
@ -632,8 +620,8 @@ class RagPipelineDslService:
conversation_variables=conversation_variables,
rag_pipeline_variables=rag_pipeline_variables_list,
)
db.session.add(workflow)
db.session.flush()
self._session.add(workflow)
self._session.flush()
pipeline.workflow_id = workflow.id
else:
workflow.graph = json.dumps(graph)
@ -643,19 +631,18 @@ class RagPipelineDslService:
workflow.conversation_variables = conversation_variables
workflow.rag_pipeline_variables = rag_pipeline_variables_list
# commit db session changes
db.session.commit()
self._session.commit()
return pipeline
@classmethod
def export_rag_pipeline_dsl(cls, pipeline: Pipeline, include_secret: bool = False) -> str:
def export_rag_pipeline_dsl(self, pipeline: Pipeline, include_secret: bool = False) -> str:
"""
Export pipeline
:param pipeline: Pipeline instance
:param include_secret: Whether include secret variable
:return:
"""
dataset = pipeline.dataset
dataset = pipeline.retrieve_dataset(session=self._session)
if not dataset:
raise ValueError("Missing dataset for rag pipeline")
icon_info = dataset.icon_info
@ -672,12 +659,11 @@ class RagPipelineDslService:
},
}
cls._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret)
self._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret)
return yaml.dump(export_data, allow_unicode=True) # type: ignore
@classmethod
def _append_workflow_export_data(cls, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
def _append_workflow_export_data(self, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
"""
Append workflow export data
:param export_data: export data
@ -685,7 +671,7 @@ class RagPipelineDslService:
"""
workflow = (
db.session.query(Workflow)
self._session.query(Workflow)
.filter(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
@ -701,11 +687,11 @@ class RagPipelineDslService:
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
for dataset_id in dataset_ids
]
export_data["workflow"] = workflow_dict
dependencies = cls._extract_dependencies_from_workflow(workflow)
dependencies = self._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
for d in DependenciesAnalysisService.generate_dependencies(
@ -713,19 +699,17 @@ class RagPipelineDslService:
)
]
@classmethod
def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
def _extract_dependencies_from_workflow(self, workflow: Workflow) -> list[str]:
"""
Extract dependencies from workflow
:param workflow: Workflow instance
:return: dependencies list format like ["langgenius/google"]
"""
graph = workflow.graph_dict
dependencies = cls._extract_dependencies_from_workflow_graph(graph)
dependencies = self._extract_dependencies_from_workflow_graph(graph)
return dependencies
@classmethod
def _extract_dependencies_from_workflow_graph(cls, graph: Mapping) -> list[str]:
def _extract_dependencies_from_workflow_graph(self, graph: Mapping) -> list[str]:
"""
Extract dependencies from workflow graph
:param graph: Workflow graph
@ -882,25 +866,22 @@ class RagPipelineDslService:
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
@staticmethod
def _generate_aes_key(tenant_id: str) -> bytes:
def _generate_aes_key(self, tenant_id: str) -> bytes:
"""Generate AES key based on tenant_id"""
return hashlib.sha256(tenant_id.encode()).digest()
@classmethod
def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
def encrypt_dataset_id(self, dataset_id: str, tenant_id: str) -> str:
"""Encrypt dataset_id using AES-CBC mode"""
key = cls._generate_aes_key(tenant_id)
key = self._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size))
return base64.b64encode(ct_bytes).decode()
@classmethod
def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
def decrypt_dataset_id(self, encrypted_data: str, tenant_id: str) -> str | None:
"""AES decryption"""
try:
key = cls._generate_aes_key(tenant_id)
key = self._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
@ -908,39 +889,37 @@ class RagPipelineDslService:
except Exception:
return None
@staticmethod
def create_rag_pipeline_dataset(
self,
tenant_id: str,
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
):
if rag_pipeline_dataset_create_entity.name:
# check if dataset name already exists
if (
db.session.query(Dataset)
self._session.query(Dataset)
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
):
raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")
else:
# generate a random name as Untitled 1 2 3 ...
datasets = db.session.query(Dataset).filter_by(tenant_id=tenant_id).all()
datasets = self._session.query(Dataset).filter_by(tenant_id=tenant_id).all()
names = [dataset.name for dataset in datasets]
rag_pipeline_dataset_create_entity.name = generate_incremental_name(
names,
"Untitled",
)
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
account = cast(Account, current_user)
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=None,
dataset_name=rag_pipeline_dataset_create_entity.name,
icon_info=rag_pipeline_dataset_create_entity.icon_info,
)
account = cast(Account, current_user)
rag_pipeline_import_info: RagPipelineImportInfo = self.import_rag_pipeline(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=None,
dataset_name=rag_pipeline_dataset_create_entity.name,
icon_info=rag_pipeline_dataset_create_entity.icon_info,
)
return {
"id": rag_pipeline_import_info.id,
"dataset_id": rag_pipeline_import_info.dataset_id,