fix(rag-pipeline-dsl): dsl import session error
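RagPipelineDslService previously mixed the module-level db.session with
ad-hoc Session(db.engine) blocks and read pipeline.dataset as a detached
relationship, which caused session errors during DSL import and export.
The service now runs entirely on the Session it is constructed with:
db.session queries become self._session queries, the @classmethod and
@staticmethod helpers become instance methods, and datasets are loaded
through pipeline.retrieve_dataset(session=...) so they stay bound to the
active session. Callers open one session, build the service on it, and
call the instance methods. A minimal sketch of the resulting calling
pattern, assuming `pipeline` is an already-loaded Pipeline and that
Session here is sqlalchemy.orm.Session (the diff itself only shows
Session(db.engine)):

    from sqlalchemy.orm import Session

    from extensions.ext_database import db
    from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

    with Session(db.engine) as session:
        # One session scopes every query the service issues, including the
        # dataset lookup done via pipeline.retrieve_dataset(session=session).
        service = RagPipelineDslService(session)  # `service` is an illustrative name
        dsl = service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)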
@@ -352,9 +352,10 @@ class RagPipelineService:
+        knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)

         # update dataset
-        dataset = pipeline.dataset
-        if not dataset:
-            raise ValueError("Dataset not found")
         with Session(db.engine) as session:
+            dataset = pipeline.retrieve_dataset(session=session)
+            if not dataset:
+                raise ValueError("Dataset not found")
             DatasetService.update_rag_pipeline_dataset_settings(
                 session=session,
                 dataset=dataset,
@@ -1110,9 +1111,10 @@ class RagPipelineService:
         workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
         if not workflow:
             raise ValueError("Workflow not found")
-        dataset = pipeline.dataset
-        if not dataset:
-            raise ValueError("Dataset not found")
+        with Session(db.engine) as session:
+            dataset = pipeline.retrieve_dataset(session=session)
+            if not dataset:
+                raise ValueError("Dataset not found")

         # check template name is exist
         template_name = args.get("name")
@@ -1136,7 +1138,9 @@ class RagPipelineService:

         from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

-        dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
+        with Session(db.engine) as session:
+            rag_pipeline_dsl_service = RagPipelineDslService(session)
+            dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)

         pipeline_customized_template = PipelineCustomizedTemplate(
             name=args.get("name"),
@@ -30,7 +30,6 @@ from core.workflow.nodes.llm.entities import LLMNodeData
 from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
 from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
 from core.workflow.nodes.tool.entities import ToolNodeData
-from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from factories import variable_factory
 from models import Account
@@ -235,10 +234,7 @@ class RagPipelineDslService:
                     status=ImportStatus.FAILED,
                     error="Pipeline not found",
                 )
-            dataset = pipeline.dataset
-            if dataset:
-                self._session.merge(dataset)
-                dataset_name = dataset.name
+            dataset = pipeline.retrieve_dataset(session=self._session)

             # If major version mismatch, store import info in Redis
             if status == ImportStatus.PENDING:
@@ -300,7 +296,7 @@ class RagPipelineDslService:
             ):
                 raise ValueError("Chunk structure is not compatible with the published pipeline")
         if not dataset:
-            datasets = db.session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
+            datasets = self._session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
             names = [dataset.name for dataset in datasets]
             generate_name = generate_incremental_name(names, name)
             dataset = Dataset(
@@ -321,7 +317,7 @@ class RagPipelineDslService:
             )
             if knowledge_configuration.indexing_technique == "high_quality":
                 dataset_collection_binding = (
-                    db.session.query(DatasetCollectionBinding)
+                    self._session.query(DatasetCollectionBinding)
                     .filter(
                         DatasetCollectionBinding.provider_name
                         == knowledge_configuration.embedding_model_provider,
@@ -339,8 +335,8 @@ class RagPipelineDslService:
                         collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
                         type="dataset",
                     )
-                    db.session.add(dataset_collection_binding)
-                    db.session.commit()
+                    self._session.add(dataset_collection_binding)
+                    self._session.commit()
                 dataset_collection_binding_id = dataset_collection_binding.id
                 dataset.collection_binding_id = dataset_collection_binding_id
                 dataset.embedding_model = knowledge_configuration.embedding_model
@@ -454,7 +450,7 @@ class RagPipelineDslService:
             dataset.chunk_structure = knowledge_configuration.chunk_structure
             if knowledge_configuration.indexing_technique == "high_quality":
                 dataset_collection_binding = (
-                    db.session.query(DatasetCollectionBinding)
+                    self._session.query(DatasetCollectionBinding)
                     .filter(
                         DatasetCollectionBinding.provider_name
                         == knowledge_configuration.embedding_model_provider,
@@ -472,8 +468,8 @@ class RagPipelineDslService:
                         collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
                         type="dataset",
                     )
-                    db.session.add(dataset_collection_binding)
-                    db.session.commit()
+                    self._session.add(dataset_collection_binding)
+                    self._session.commit()
                 dataset_collection_binding_id = dataset_collection_binding.id
                 dataset.collection_binding_id = dataset_collection_binding_id
                 dataset.embedding_model = knowledge_configuration.embedding_model
@@ -538,18 +534,10 @@ class RagPipelineDslService:
         account: Account,
         dependencies: Optional[list[PluginDependency]] = None,
     ) -> Pipeline:
-        """Create a new app or update an existing one."""
         if not account.current_tenant_id:
             raise ValueError("Tenant id is required")
+        """Create a new app or update an existing one."""
         pipeline_data = data.get("rag_pipeline", {})
-        # Set icon type
-        icon_type_value = pipeline_data.get("icon_type")
-        if icon_type_value in ["emoji", "link"]:
-            icon_type = icon_type_value
-        else:
-            icon_type = "emoji"
-        icon = str(pipeline_data.get("icon", ""))
-
         # Initialize pipeline based on mode
         workflow_data = data.get("workflow")
         if not workflow_data or not isinstance(workflow_data, dict):
@@ -609,7 +597,7 @@ class RagPipelineDslService:
                 CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
             )
         workflow = (
-            db.session.query(Workflow)
+            self._session.query(Workflow)
             .filter(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
@@ -632,8 +620,8 @@ class RagPipelineDslService:
                 conversation_variables=conversation_variables,
                 rag_pipeline_variables=rag_pipeline_variables_list,
             )
-            db.session.add(workflow)
-            db.session.flush()
+            self._session.add(workflow)
+            self._session.flush()
             pipeline.workflow_id = workflow.id
         else:
             workflow.graph = json.dumps(graph)
@@ -643,19 +631,18 @@ class RagPipelineDslService:
             workflow.conversation_variables = conversation_variables
             workflow.rag_pipeline_variables = rag_pipeline_variables_list
         # commit db session changes
-        db.session.commit()
+        self._session.commit()

         return pipeline

-    @classmethod
-    def export_rag_pipeline_dsl(cls, pipeline: Pipeline, include_secret: bool = False) -> str:
+    def export_rag_pipeline_dsl(self, pipeline: Pipeline, include_secret: bool = False) -> str:
         """
         Export pipeline
         :param pipeline: Pipeline instance
         :param include_secret: Whether include secret variable
         :return:
         """
-        dataset = pipeline.dataset
+        dataset = pipeline.retrieve_dataset(session=self._session)
         if not dataset:
             raise ValueError("Missing dataset for rag pipeline")
         icon_info = dataset.icon_info
@@ -672,12 +659,11 @@ class RagPipelineDslService:
             },
         }

-        cls._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret)
+        self._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret)

         return yaml.dump(export_data, allow_unicode=True)  # type: ignore

-    @classmethod
-    def _append_workflow_export_data(cls, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
+    def _append_workflow_export_data(self, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
         """
         Append workflow export data
         :param export_data: export data
@@ -685,7 +671,7 @@ class RagPipelineDslService:
         """

         workflow = (
-            db.session.query(Workflow)
+            self._session.query(Workflow)
             .filter(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
@@ -701,11 +687,11 @@ class RagPipelineDslService:
             if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
                 dataset_ids = node["data"].get("dataset_ids", [])
                 node["data"]["dataset_ids"] = [
-                    cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
+                    self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
                     for dataset_id in dataset_ids
                 ]
         export_data["workflow"] = workflow_dict
-        dependencies = cls._extract_dependencies_from_workflow(workflow)
+        dependencies = self._extract_dependencies_from_workflow(workflow)
         export_data["dependencies"] = [
             jsonable_encoder(d.model_dump())
             for d in DependenciesAnalysisService.generate_dependencies(
@@ -713,19 +699,17 @@ class RagPipelineDslService:
             )
         ]

-    @classmethod
-    def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
+    def _extract_dependencies_from_workflow(self, workflow: Workflow) -> list[str]:
         """
         Extract dependencies from workflow
         :param workflow: Workflow instance
         :return: dependencies list format like ["langgenius/google"]
         """
         graph = workflow.graph_dict
-        dependencies = cls._extract_dependencies_from_workflow_graph(graph)
+        dependencies = self._extract_dependencies_from_workflow_graph(graph)
         return dependencies

-    @classmethod
-    def _extract_dependencies_from_workflow_graph(cls, graph: Mapping) -> list[str]:
+    def _extract_dependencies_from_workflow_graph(self, graph: Mapping) -> list[str]:
         """
         Extract dependencies from workflow graph
         :param graph: Workflow graph
@@ -882,25 +866,22 @@ class RagPipelineDslService:

         return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)

-    @staticmethod
-    def _generate_aes_key(tenant_id: str) -> bytes:
+    def _generate_aes_key(self, tenant_id: str) -> bytes:
         """Generate AES key based on tenant_id"""
         return hashlib.sha256(tenant_id.encode()).digest()

-    @classmethod
-    def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
+    def encrypt_dataset_id(self, dataset_id: str, tenant_id: str) -> str:
         """Encrypt dataset_id using AES-CBC mode"""
-        key = cls._generate_aes_key(tenant_id)
+        key = self._generate_aes_key(tenant_id)
         iv = key[:16]
         cipher = AES.new(key, AES.MODE_CBC, iv)
         ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size))
         return base64.b64encode(ct_bytes).decode()

-    @classmethod
-    def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
+    def decrypt_dataset_id(self, encrypted_data: str, tenant_id: str) -> str | None:
         """AES decryption"""
         try:
-            key = cls._generate_aes_key(tenant_id)
+            key = self._generate_aes_key(tenant_id)
             iv = key[:16]
             cipher = AES.new(key, AES.MODE_CBC, iv)
             pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
@@ -908,39 +889,37 @@ class RagPipelineDslService:
         except Exception:
             return None

-    @staticmethod
     def create_rag_pipeline_dataset(
+        self,
         tenant_id: str,
         rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
     ):
         if rag_pipeline_dataset_create_entity.name:
             # check if dataset name already exists
             if (
-                db.session.query(Dataset)
+                self._session.query(Dataset)
                 .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
                 .first()
             ):
                 raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")
         else:
             # generate a random name as Untitled 1 2 3 ...
-            datasets = db.session.query(Dataset).filter_by(tenant_id=tenant_id).all()
+            datasets = self._session.query(Dataset).filter_by(tenant_id=tenant_id).all()
             names = [dataset.name for dataset in datasets]
             rag_pipeline_dataset_create_entity.name = generate_incremental_name(
                 names,
                 "Untitled",
             )

-        with Session(db.engine) as session:
-            rag_pipeline_dsl_service = RagPipelineDslService(session)
-            account = cast(Account, current_user)
-            rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
-                account=account,
-                import_mode=ImportMode.YAML_CONTENT.value,
-                yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
-                dataset=None,
-                dataset_name=rag_pipeline_dataset_create_entity.name,
-                icon_info=rag_pipeline_dataset_create_entity.icon_info,
-            )
+        account = cast(Account, current_user)
+        rag_pipeline_import_info: RagPipelineImportInfo = self.import_rag_pipeline(
+            account=account,
+            import_mode=ImportMode.YAML_CONTENT.value,
+            yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
+            dataset=None,
+            dataset_name=rag_pipeline_dataset_create_entity.name,
+            icon_info=rag_pipeline_dataset_create_entity.icon_info,
+        )
         return {
             "id": rag_pipeline_import_info.id,
             "dataset_id": rag_pipeline_import_info.dataset_id,