mirror of
https://github.com/langgenius/dify.git
synced 2026-04-23 04:06:13 +08:00
feat: Implement the APIs of downloading evaluation dataset template file and downloading evaluation dataset file/evaluation result file.
This commit is contained in:
@ -2,10 +2,13 @@ import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import request
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@ -15,11 +18,14 @@ from controllers.console.wraps import (
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.model import UploadFile
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.evaluation_service import EvaluationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -158,7 +164,8 @@ def get_evaluation_target(view_func: Callable[P, R]):
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/dataset-template/download")
|
||||
class EvaluationDatasetTemplateDownloadApi(Resource):
|
||||
@console_ns.doc("download_evaluation_dataset_template")
|
||||
@console_ns.response(200, "Template download URL generated successfully")
|
||||
@console_ns.response(200, "Template file streamed as XLSX attachment")
|
||||
@console_ns.response(400, "Invalid target type or excluded app mode")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -169,14 +176,25 @@ class EvaluationDatasetTemplateDownloadApi(Resource):
|
||||
"""
|
||||
Download evaluation dataset template.
|
||||
|
||||
Generates a download URL for the evaluation dataset template
|
||||
based on the target type (app or snippets).
|
||||
Generates an XLSX template based on the target's input parameters
|
||||
and streams it directly as a file attachment.
|
||||
"""
|
||||
# TODO: Implement actual template generation logic
|
||||
# This is a placeholder implementation
|
||||
return {
|
||||
"download_url": f"/api/evaluation/{target_type}/{target.id}/template.csv",
|
||||
}
|
||||
try:
|
||||
xlsx_content, filename = EvaluationService.generate_dataset_template(
|
||||
target=target,
|
||||
target_type=target_type,
|
||||
)
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
encoded_filename = quote(filename)
|
||||
response = Response(
|
||||
xlsx_content,
|
||||
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Length"] = str(len(xlsx_content))
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation")
|
||||
@ -241,18 +259,32 @@ class EvaluationFileDownloadApi(Resource):
|
||||
"""
|
||||
Download evaluation test file or result file.
|
||||
|
||||
Returns file information and download URL for the specified file.
|
||||
Looks up the specified file, verifies it belongs to the same tenant,
|
||||
and returns file info and download URL.
|
||||
"""
|
||||
file_id = str(file_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == file_id,
|
||||
UploadFile.tenant_id == current_tenant_id,
|
||||
)
|
||||
upload_file = session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
|
||||
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
|
||||
|
||||
# TODO: Implement actual file download logic
|
||||
# This is a placeholder implementation
|
||||
return {
|
||||
"created_at": None,
|
||||
"created_by": None,
|
||||
"test_file": None,
|
||||
"result_file": None,
|
||||
"version": None,
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
|
||||
"download_url": download_url,
|
||||
}
|
||||
|
||||
|
||||
|
||||
178
api/services/evaluation_service.py
Normal file
178
api/services/evaluation_service.py
Normal file
@ -0,0 +1,178 @@
|
||||
import io
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
from openpyxl import Workbook
|
||||
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
|
||||
from openpyxl.utils import get_column_letter
|
||||
|
||||
from models.model import App, AppMode
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.snippet_service import SnippetService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EvaluationService:
|
||||
"""
|
||||
Service for evaluation-related operations.
|
||||
|
||||
Provides functionality to generate evaluation dataset templates
|
||||
based on App or Snippet input parameters.
|
||||
"""
|
||||
|
||||
# Excluded app modes that don't support evaluation templates
|
||||
EXCLUDED_APP_MODES = {AppMode.RAG_PIPELINE}
|
||||
|
||||
@classmethod
|
||||
def generate_dataset_template(
|
||||
cls,
|
||||
target: Union[App, CustomizedSnippet],
|
||||
target_type: str,
|
||||
) -> tuple[bytes, str]:
|
||||
"""
|
||||
Generate evaluation dataset template as XLSX bytes.
|
||||
|
||||
Creates an XLSX file with headers based on the evaluation target's input parameters.
|
||||
The first column is index, followed by input parameter columns.
|
||||
|
||||
:param target: App or CustomizedSnippet instance
|
||||
:param target_type: Target type string ("app" or "snippet")
|
||||
:return: Tuple of (xlsx_content_bytes, filename)
|
||||
:raises ValueError: If target type is not supported or app mode is excluded
|
||||
"""
|
||||
# Validate target type
|
||||
if target_type == "app":
|
||||
if not isinstance(target, App):
|
||||
raise ValueError("Invalid target: expected App instance")
|
||||
if AppMode.value_of(target.mode) in cls.EXCLUDED_APP_MODES:
|
||||
raise ValueError(f"App mode '{target.mode}' does not support evaluation templates")
|
||||
input_fields = cls._get_app_input_fields(target)
|
||||
elif target_type == "snippet":
|
||||
if not isinstance(target, CustomizedSnippet):
|
||||
raise ValueError("Invalid target: expected CustomizedSnippet instance")
|
||||
input_fields = cls._get_snippet_input_fields(target)
|
||||
else:
|
||||
raise ValueError(f"Unsupported target type: {target_type}")
|
||||
|
||||
# Generate XLSX template
|
||||
xlsx_content = cls._generate_xlsx_template(input_fields, target.name)
|
||||
|
||||
# Build filename
|
||||
truncated_name = target.name[:10] + "..." if len(target.name) > 10 else target.name
|
||||
filename = f"{truncated_name}-evaluation-dataset.xlsx"
|
||||
|
||||
return xlsx_content, filename
|
||||
|
||||
@classmethod
|
||||
def _get_app_input_fields(cls, app: App) -> list[dict]:
|
||||
"""
|
||||
Get input fields from App's workflow.
|
||||
|
||||
:param app: App instance
|
||||
:return: List of input field definitions
|
||||
"""
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_published_workflow(app_model=app)
|
||||
if not workflow:
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app)
|
||||
|
||||
if not workflow:
|
||||
return []
|
||||
|
||||
# Get user input form from workflow
|
||||
user_input_form = workflow.user_input_form()
|
||||
return user_input_form
|
||||
|
||||
@classmethod
|
||||
def _get_snippet_input_fields(cls, snippet: CustomizedSnippet) -> list[dict]:
|
||||
"""
|
||||
Get input fields from Snippet.
|
||||
|
||||
Tries to get from snippet's own input_fields first,
|
||||
then falls back to workflow's user_input_form.
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:return: List of input field definitions
|
||||
"""
|
||||
# Try snippet's own input_fields first
|
||||
input_fields = snippet.input_fields_list
|
||||
if input_fields:
|
||||
return input_fields
|
||||
|
||||
# Fallback to workflow's user_input_form
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_published_workflow(snippet=snippet)
|
||||
if not workflow:
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
|
||||
if workflow:
|
||||
return workflow.user_input_form()
|
||||
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _generate_xlsx_template(cls, input_fields: list[dict], target_name: str) -> bytes:
|
||||
"""
|
||||
Generate XLSX template file content.
|
||||
|
||||
Creates a workbook with:
|
||||
- First row as header row with "index" and input field names
|
||||
- Styled header with background color and borders
|
||||
- Empty data rows ready for user input
|
||||
|
||||
:param input_fields: List of input field definitions
|
||||
:param target_name: Name of the target (for sheet name)
|
||||
:return: XLSX file content as bytes
|
||||
"""
|
||||
wb = Workbook()
|
||||
ws = wb.active
|
||||
|
||||
sheet_name = "Evaluation Dataset"
|
||||
ws.title = sheet_name
|
||||
|
||||
header_font = Font(bold=True, color="FFFFFF")
|
||||
header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
|
||||
header_alignment = Alignment(horizontal="center", vertical="center")
|
||||
thin_border = Border(
|
||||
left=Side(style="thin"),
|
||||
right=Side(style="thin"),
|
||||
top=Side(style="thin"),
|
||||
bottom=Side(style="thin"),
|
||||
)
|
||||
|
||||
# Build header row
|
||||
headers = ["index"]
|
||||
|
||||
for field in input_fields:
|
||||
field_label = field.get("label") or field.get("variable")
|
||||
headers.append(field_label)
|
||||
|
||||
# Write header row
|
||||
for col_idx, header in enumerate(headers, start=1):
|
||||
cell = ws.cell(row=1, column=col_idx, value=header)
|
||||
cell.font = header_font
|
||||
cell.fill = header_fill
|
||||
cell.alignment = header_alignment
|
||||
cell.border = thin_border
|
||||
|
||||
# Set column widths
|
||||
ws.column_dimensions["A"].width = 10 # index column
|
||||
for col_idx in range(2, len(headers) + 1):
|
||||
ws.column_dimensions[get_column_letter(col_idx)].width = 20
|
||||
|
||||
# Add one empty row with row number for user reference
|
||||
for col_idx in range(1, len(headers) + 1):
|
||||
cell = ws.cell(row=2, column=col_idx, value="")
|
||||
cell.border = thin_border
|
||||
if col_idx == 1:
|
||||
cell.value = 1
|
||||
cell.alignment = Alignment(horizontal="center")
|
||||
|
||||
# Save to bytes
|
||||
output = io.BytesIO()
|
||||
wb.save(output)
|
||||
output.seek(0)
|
||||
|
||||
return output.getvalue()
|
||||
Reference in New Issue
Block a user