This commit is contained in:
jyong
2025-05-16 17:22:17 +08:00
parent 9e72afee3c
commit 8bea88c8cc
11 changed files with 80 additions and 41 deletions

View File

@@ -38,7 +38,7 @@ class PipelineTemplateListApi(Resource):
@account_initialization_required
@enterprise_license_required
def get(self):
type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"])
type = request.args.get("type", default="built-in", type=str)
language = request.args.get("language", default="en-US", type=str)
# get pipeline templates
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
@@ -107,7 +107,7 @@ class CustomizedPipelineTemplateApi(Resource):
pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found.")
dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True)
return {"data": dsl}, 200

View File

@@ -90,11 +90,10 @@ class DraftRagPipelineApi(Resource):
if "application/json" in content_type:
parser = reqparse.RequestParser()
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
parser.add_argument("hash", type=str, required=False, location="json")
parser.add_argument("environment_variables", type=list, required=False, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json")
parser.add_argument("pipeline_variables", type=dict, required=False, location="json")
parser.add_argument("rag_pipeline_variables", type=dict, required=False, location="json")
args = parser.parse_args()
elif "text/plain" in content_type:
try:
@@ -111,7 +110,7 @@ class DraftRagPipelineApi(Resource):
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
"pipeline_variables": data.get("pipeline_variables"),
"rag_pipeline_variables": data.get("rag_pipeline_variables"),
}
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
@@ -130,21 +129,20 @@ class DraftRagPipelineApi(Resource):
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
pipeline_variables_list = args.get("pipeline_variables") or {}
pipeline_variables = {
rag_pipeline_variables_list = args.get("rag_pipeline_variables") or {}
rag_pipeline_variables = {
k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v]
for k, v in pipeline_variables_list.items()
for k, v in rag_pipeline_variables_list.items()
}
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline,
graph=args["graph"],
features=args["features"],
unique_hash=args.get("hash"),
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
pipeline_variables=pipeline_variables,
rag_pipeline_variables=rag_pipeline_variables,
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
@@ -476,7 +474,7 @@ class RagPipelineConfigApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
def get(self, pipeline_id):
return {
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
}
@@ -792,5 +790,5 @@
)
api.add_resource(
DatasourceListApi,
"/rag/pipelines/datasources",
"/rag/pipelines/datasource-plugins",
)