This commit is contained in:
jyong
2025-05-28 17:56:04 +08:00
parent 5fc2bc58a9
commit 7f59ffe7af
32 changed files with 680 additions and 202 deletions

View File

@ -15,6 +15,7 @@ from libs.login import login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
def _validate_name(name):
@ -91,7 +92,7 @@ class CreateRagPipelineDatasetApi(Resource):
raise Forbidden()
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args)
try:
import_info = DatasetService.create_rag_pipeline_dataset(
import_info = RagPipelineDslService.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)

View File

@ -40,6 +40,7 @@ from libs.login import current_user, login_required
from models.account import Account
from models.dataset import Pipeline
from models.model import EndUser
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
@ -282,15 +283,18 @@ class PublishedRagPipelineRunApi(Resource):
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("is_preview", type=bool, required=True, location="json", default=False)
parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
try:
response = PipelineGenerateService.generate(
pipeline=pipeline,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED,
streaming=True,
streaming=streaming,
)
return helper.compact_generate_response(response)
@ -459,16 +463,17 @@ class PublishedRagPipelineApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
parser.add_argument("knowledge_base_setting", type=dict, location="json", help="Invalid knowledge base setting.")
args = parser.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
if not args.get("knowledge_base_setting"):
raise ValueError("Missing knowledge base setting.")
knowledge_base_setting_data = args.get("knowledge_base_setting")
if not knowledge_base_setting_data:
raise ValueError("Missing knowledge base setting.")
knowledge_base_setting = KnowledgeBaseUpdateConfiguration(**knowledge_base_setting_data)
rag_pipeline_service = RagPipelineService()
with Session(db.engine) as session:
pipeline = session.merge(pipeline)
@ -476,8 +481,7 @@ class PublishedRagPipelineApi(Resource):
session=session,
pipeline=pipeline,
account=current_user,
marked_name=args.marked_name or "",
marked_comment=args.marked_comment or "",
knowledge_base_setting=knowledge_base_setting,
)
pipeline.is_published = True
pipeline.workflow_id = workflow.id