Merge branch 'feat/r2' into deploy/rag-dev

# Conflicts:
#	docker/docker-compose.middleware.yaml
This commit is contained in:
jyong
2025-06-06 17:15:24 +08:00
4 changed files with 156 additions and 3 deletions

View File

@ -300,6 +300,86 @@ class PublishedRagPipelineRunApi(Resource):
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_rag_pipeline
    def post(self, pipeline: Pipeline, node_id: str):
        """
        Query the run status of a datasource node in the published rag pipeline workflow.

        Expects a JSON body with ``job_id`` and ``datasource_type`` and
        delegates to RagPipelineService.run_datasource_workflow_node_status
        with is_published=True.
        """
        # Narrow the user type first, then authorize: the role of the current
        # user in the ta table must be admin, owner, or editor.
        if not isinstance(current_user, Account):
            raise Forbidden()
        if not current_user.is_editor:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
        parser.add_argument("datasource_type", type=str, required=True, location="json")
        args = parser.parse_args()

        # reqparse already enforces presence; these guards are defensive and
        # use identity comparison per PEP 8 (was `== None`).
        job_id = args.get("job_id")
        if job_id is None:
            raise ValueError("missing job_id")
        datasource_type = args.get("datasource_type")
        if datasource_type is None:
            raise ValueError("missing datasource_type")

        rag_pipeline_service = RagPipelineService()
        result = rag_pipeline_service.run_datasource_workflow_node_status(
            pipeline=pipeline,
            node_id=node_id,
            job_id=job_id,
            account=current_user,
            datasource_type=datasource_type,
            is_published=True,
        )
        return result
class RagPipelineDraftDatasourceNodeRunStatusApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_rag_pipeline
    def post(self, pipeline: Pipeline, node_id: str):
        """
        Query the run status of a datasource node in the draft rag pipeline workflow.

        Expects a JSON body with ``job_id`` and ``datasource_type`` and
        delegates to RagPipelineService.run_datasource_workflow_node_status
        with is_published=False.
        """
        # Narrow the user type first, then authorize: the role of the current
        # user in the ta table must be admin, owner, or editor.
        if not isinstance(current_user, Account):
            raise Forbidden()
        if not current_user.is_editor:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
        parser.add_argument("datasource_type", type=str, required=True, location="json")
        args = parser.parse_args()

        # reqparse already enforces presence; these guards are defensive and
        # use identity comparison per PEP 8 (was `== None`).
        job_id = args.get("job_id")
        if job_id is None:
            raise ValueError("missing job_id")
        datasource_type = args.get("datasource_type")
        if datasource_type is None:
            raise ValueError("missing datasource_type")

        rag_pipeline_service = RagPipelineService()
        result = rag_pipeline_service.run_datasource_workflow_node_status(
            pipeline=pipeline,
            node_id=node_id,
            job_id=job_id,
            account=current_user,
            datasource_type=datasource_type,
            is_published=False,
        )
        return result
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@setup_required
@ -894,6 +974,14 @@ api.add_resource(
RagPipelinePublishedDatasourceNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run",
)
# Poll datasource node run status against the published workflow.
api.add_resource(
RagPipelinePublishedDatasourceNodeRunStatusApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run-status",
)
# Poll datasource node run status against the draft workflow.
api.add_resource(
RagPipelineDraftDatasourceNodeRunStatusApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run-status",
)
api.add_resource(
RagPipelineDrafDatasourceNodeRunApi,

View File

@ -304,4 +304,6 @@ class GetWebsiteCrawlResponse(BaseModel):
Get website crawl response
"""
result: list[WebSiteInfo]
result: Optional[list[WebSiteInfo]] = []
job_id: str = Field(..., description="The job id")
status: str = Field(..., description="The status of the job")

View File

@ -415,6 +415,67 @@ class RagPipelineService:
return workflow_node_execution
def run_datasource_workflow_node_status(
self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, datasource_type: str, is_published: bool
) -> dict:
"""
Run published workflow datasource
"""
if is_published:
# fetch published workflow by app_model
workflow = self.get_published_workflow(pipeline=pipeline)
else:
workflow = self.get_draft_workflow(pipeline=pipeline)
if not workflow:
raise ValueError("Workflow not initialized")
# run draft workflow node
datasource_node_data = None
start_at = time.perf_counter()
datasource_nodes = workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
if datasource_node.get("id") == node_id:
datasource_node_data = datasource_node.get("data", {})
break
if not datasource_node_data:
raise ValueError("Datasource node data not found")
from core.datasource.datasource_manager import DatasourceManager
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
datasource_name=datasource_node_data.get("datasource_name"),
tenant_id=pipeline.tenant_id,
datasource_type=DatasourceProviderType(datasource_type),
)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_real_datasource_credentials(
tenant_id=pipeline.tenant_id,
provider=datasource_node_data.get('provider_name'),
plugin_id=datasource_node_data.get('plugin_id'),
)
if credentials:
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
match datasource_type:
case DatasourceProviderType.WEBSITE_CRAWL:
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
user_id=account.id,
datasource_parameters={"job_id": job_id},
provider_type=datasource_runtime.datasource_provider_type(),
)
return {
"result": [result.model_dump() for result in website_crawl_result.result],
"job_id": website_crawl_result.job_id,
"status": website_crawl_result.status,
"provider_type": datasource_node_data.get("provider_type"),
}
case _:
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
def run_datasource_workflow_node(
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, is_published: bool
) -> dict: