This commit is contained in:
jyong
2025-05-29 23:04:04 +08:00
parent a025db137d
commit e7c48c0b69
12 changed files with 339 additions and 202 deletions

View File

@ -1,5 +1,6 @@
import logging
import yaml
from flask import request
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
@ -12,10 +13,9 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from libs.login import login_required
from models.dataset import Pipeline, PipelineCustomizedTemplate
from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
logger = logging.getLogger(__name__)
@ -84,8 +84,8 @@ class CustomizedPipelineTemplateApi(Resource):
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity(**args)
pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return pipeline_template, 200
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200
@setup_required
@login_required
@ -106,13 +106,41 @@ class CustomizedPipelineTemplateApi(Resource):
)
if not template:
raise ValueError("Customized pipeline template not found.")
pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found.")
dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True)
dsl = yaml.safe_load(template.yaml_content)
return {"data": dsl}, 200
class CustomizedPipelineTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def post(self, pipeline_id: str):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
args = parser.parse_args()
rag_pipeline_service = RagPipelineService()
RagPipelineService.publish_customized_pipeline_template(pipeline_id, args)
return 200
api.add_resource(
PipelineTemplateListApi,