mirror of
https://github.com/langgenius/dify.git
synced 2026-06-02 15:06:32 +08:00
Compare commits
2 Commits
feat/evalu
...
cursor/ref
| Author | SHA1 | Date | |
|---|---|---|---|
| 6f582c1b2f | |||
| f734c35443 |
@ -20,11 +20,11 @@
|
||||
```typescript
|
||||
// ❌ WRONG: Don't mock base components
|
||||
vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
|
||||
vi.mock('@langgenius/dify-ui/button', () => ({ children }: any) => <button>{children}</button>)
|
||||
vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
|
||||
|
||||
// ✅ CORRECT: Import and use real base components
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { Button } from '@langgenius/dify-ui/button'
|
||||
import Button from '@/app/components/base/button'
|
||||
// They will render normally in tests
|
||||
```
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_publisher,trigger_refresh_executor,retention\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
|
||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc
|
||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||
|
||||
1
.github/workflows/autofix.yml
vendored
1
.github/workflows/autofix.yml
vendored
@ -120,6 +120,7 @@ jobs:
|
||||
- name: ESLint autofix
|
||||
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd web
|
||||
vp exec eslint --concurrency=2 --prune-suppressions --quiet || true
|
||||
|
||||
- if: github.event_name != 'merge_group'
|
||||
|
||||
18
.github/workflows/pyrefly-diff-comment.yml
vendored
18
.github/workflows/pyrefly-diff-comment.yml
vendored
@ -76,11 +76,13 @@ jobs:
|
||||
diff += '\\n\\n... (truncated) ...';
|
||||
}
|
||||
|
||||
if (diff.trim()) {
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>',
|
||||
});
|
||||
}
|
||||
const body = diff.trim()
|
||||
? '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>'
|
||||
: '### Pyrefly Diff\nNo changes detected.';
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
|
||||
34
.github/workflows/web-tests.yml
vendored
34
.github/workflows/web-tests.yml
vendored
@ -89,37 +89,3 @@ jobs:
|
||||
flags: web
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
|
||||
dify-ui-test:
|
||||
name: dify-ui Tests
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./packages/dify-ui
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Install Chromium for Browser Mode
|
||||
run: vp exec playwright install --with-deps chromium
|
||||
|
||||
- name: Run dify-ui tests
|
||||
run: vp test run --coverage --silent=passed-only
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
with:
|
||||
directory: packages/dify-ui/coverage
|
||||
flags: dify-ui
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
|
||||
@ -77,7 +77,7 @@ if $web_modified; then
|
||||
fi
|
||||
|
||||
cd ./web || exit 1
|
||||
pnpm exec vp staged
|
||||
vp staged
|
||||
|
||||
if $web_ts_modified; then
|
||||
echo "Running TypeScript type-check:tsgo"
|
||||
|
||||
15
.vscode/launch.json.template
vendored
15
.vscode/launch.json.template
vendored
@ -2,10 +2,21 @@
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: API (gevent)",
|
||||
"name": "Python: Flask API",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/api/app.py",
|
||||
"module": "flask",
|
||||
"env": {
|
||||
"FLASK_APP": "app.py",
|
||||
"FLASK_ENV": "development"
|
||||
},
|
||||
"args": [
|
||||
"run",
|
||||
"--host=0.0.0.0",
|
||||
"--port=5001",
|
||||
"--no-debugger",
|
||||
"--no-reload"
|
||||
],
|
||||
"jinja": true,
|
||||
"justMyCode": true,
|
||||
"cwd": "${workspaceFolder}/api",
|
||||
|
||||
@ -33,9 +33,6 @@ TRIGGER_URL=http://localhost:5001
|
||||
# The time in seconds after the signature is rejected
|
||||
FILES_ACCESS_TIMEOUT=300
|
||||
|
||||
# Collaboration mode toggle
|
||||
ENABLE_COLLABORATION_MODE=false
|
||||
|
||||
# Access token expiration time in minutes
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
|
||||
|
||||
@ -106,6 +106,3 @@ msg = "Use Pydantic payload/query models instead of reqparse."
|
||||
|
||||
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"]
|
||||
msg = "Use Pydantic payload/query models instead of reqparse."
|
||||
|
||||
[lint.isort]
|
||||
known-first-party = ["graphon"]
|
||||
18
api/.vscode/launch.json.example
vendored
18
api/.vscode/launch.json.example
vendored
@ -3,21 +3,29 @@
|
||||
"compounds": [
|
||||
{
|
||||
"name": "Launch Flask and Celery",
|
||||
"configurations": ["Python: API (gevent)", "Python: Celery"]
|
||||
"configurations": ["Python: Flask", "Python: Celery"]
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: API (gevent)",
|
||||
"consoleName": "API",
|
||||
"name": "Python: Flask",
|
||||
"consoleName": "Flask",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"python": "${workspaceFolder}/.venv/bin/python",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"envFile": ".env",
|
||||
"program": "${workspaceFolder}/app.py",
|
||||
"module": "flask",
|
||||
"justMyCode": true,
|
||||
"jinja": true
|
||||
"jinja": true,
|
||||
"env": {
|
||||
"FLASK_APP": "app.py",
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
"args": [
|
||||
"run",
|
||||
"--port=5001"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Python: Celery",
|
||||
|
||||
29
api/app.py
29
api/app.py
@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
@ -10,35 +9,17 @@ if TYPE_CHECKING:
|
||||
celery: Celery
|
||||
|
||||
|
||||
HOST = "0.0.0.0"
|
||||
PORT = 5001
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_db_command() -> bool:
|
||||
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def log_startup_banner(host: str, port: int) -> None:
|
||||
debugger_attached = sys.gettrace() is not None
|
||||
logger.info("Serving Dify API via gevent WebSocket server")
|
||||
logger.info("Bound to http://%s:%s", host, port)
|
||||
logger.info("Debugger attached: %s", "on" if debugger_attached else "off")
|
||||
logger.info("Press CTRL+C to quit")
|
||||
|
||||
|
||||
# create app
|
||||
flask_app = None
|
||||
socketio_app = None
|
||||
|
||||
if is_db_command():
|
||||
from app_factory import create_migrations_app
|
||||
|
||||
app = create_migrations_app()
|
||||
socketio_app = app
|
||||
flask_app = app
|
||||
else:
|
||||
# Gunicorn and Celery handle monkey patching automatically in production by
|
||||
# specifying the `gevent` worker class. Manual monkey patching is not required here.
|
||||
@ -49,14 +30,8 @@ else:
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
socketio_app, flask_app = create_app()
|
||||
app = flask_app
|
||||
app = create_app()
|
||||
celery = cast("Celery", app.extensions["celery"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
from gevent import pywsgi
|
||||
from geventwebsocket.handler import WebSocketHandler # type: ignore[reportMissingTypeStubs]
|
||||
|
||||
log_startup_banner(HOST, PORT)
|
||||
server = pywsgi.WSGIServer((HOST, PORT), socketio_app, handler_class=WebSocketHandler)
|
||||
server.serve_forever()
|
||||
app.run(host="0.0.0.0", port=5001)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import socketio # type: ignore[reportMissingTypeStubs]
|
||||
from flask import request
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
@ -11,7 +10,6 @@ from contexts.wrapper import RecyclableContextVar
|
||||
from controllers.console.error import UnauthorizedAndForceLogout
|
||||
from core.logging.context import init_request_context
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_socketio import sio
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
@ -124,18 +122,14 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
return dify_app
|
||||
|
||||
|
||||
def create_app() -> tuple[socketio.WSGIApp, DifyApp]:
|
||||
def create_app() -> DifyApp:
|
||||
start_time = time.perf_counter()
|
||||
app = create_flask_app_with_configs()
|
||||
initialize_extensions(app)
|
||||
|
||||
sio.app = app
|
||||
socketio_app = socketio.WSGIApp(sio, app)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if dify_config.DEBUG:
|
||||
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
|
||||
return socketio_app, app
|
||||
return app
|
||||
|
||||
|
||||
def initialize_extensions(app: DifyApp):
|
||||
|
||||
@ -2,7 +2,6 @@ import base64
|
||||
import secrets
|
||||
|
||||
import click
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
@ -44,11 +43,10 @@ def reset_password(email, new_password, password_confirm):
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
with Session(db.engine) as session:
|
||||
account = session.merge(account)
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
session.commit()
|
||||
account = db.session.merge(account)
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
|
||||
@ -79,10 +77,9 @@ def reset_email(email, new_email, email_confirm):
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.merge(account)
|
||||
account.email = normalized_new_email
|
||||
session.commit()
|
||||
account = db.session.merge(account)
|
||||
account.email = normalized_new_email
|
||||
db.session.commit()
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
|
||||
|
||||
|
||||
@ -1274,13 +1274,6 @@ class PositionConfig(BaseSettings):
|
||||
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
|
||||
|
||||
|
||||
class CollaborationConfig(BaseSettings):
|
||||
ENABLE_COLLABORATION_MODE: bool = Field(
|
||||
description="Whether to enable collaboration mode features across the workspace",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class LoginConfig(BaseSettings):
|
||||
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
|
||||
description="whether to enable email code login",
|
||||
@ -1373,32 +1366,6 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class EvaluationConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for evaluation runtime
|
||||
"""
|
||||
|
||||
EVALUATION_FRAMEWORK: str = Field(
|
||||
description="Evaluation framework to use (ragas/deepeval/none)",
|
||||
default="none",
|
||||
)
|
||||
|
||||
EVALUATION_MAX_CONCURRENT_RUNS: PositiveInt = Field(
|
||||
description="Maximum number of concurrent evaluation runs per tenant",
|
||||
default=3,
|
||||
)
|
||||
|
||||
EVALUATION_MAX_DATASET_ROWS: PositiveInt = Field(
|
||||
description="Maximum number of rows allowed in an evaluation dataset",
|
||||
default=500,
|
||||
)
|
||||
|
||||
EVALUATION_TASK_TIMEOUT: PositiveInt = Field(
|
||||
description="Timeout in seconds for a single evaluation task",
|
||||
default=3600,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
@ -1411,7 +1378,6 @@ class FeatureConfig(
|
||||
MarketplaceConfig,
|
||||
DataSetConfig,
|
||||
EndpointConfig,
|
||||
EvaluationConfig,
|
||||
FileAccessConfig,
|
||||
FileUploadConfig,
|
||||
HttpConfig,
|
||||
@ -1433,7 +1399,6 @@ class FeatureConfig(
|
||||
WorkflowConfig,
|
||||
WorkflowNodeExecutionConfig,
|
||||
WorkspaceConfig,
|
||||
CollaborationConfig,
|
||||
LoginConfig,
|
||||
AccountConfig,
|
||||
SwaggerUIConfig,
|
||||
|
||||
@ -1 +0,0 @@
|
||||
CURRENT_APP_DSL_VERSION = "0.6.0"
|
||||
@ -2,9 +2,9 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import BaseModel, ConfigDict, computed_field
|
||||
|
||||
from graphon.file import helpers as file_helpers
|
||||
from models.model import IconType
|
||||
|
||||
type JSONValue = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||
|
||||
@ -65,7 +65,6 @@ from .app import (
|
||||
statistic,
|
||||
workflow,
|
||||
workflow_app_log,
|
||||
workflow_comment,
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
@ -108,9 +107,6 @@ from .datasets.rag_pipeline import (
|
||||
rag_pipeline_workflow,
|
||||
)
|
||||
|
||||
# Import evaluation controllers
|
||||
from .evaluation import evaluation
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
banner,
|
||||
@ -121,13 +117,6 @@ from .explore import (
|
||||
trial,
|
||||
)
|
||||
|
||||
# Import snippet controllers
|
||||
from .snippets import snippet_workflow, snippet_workflow_draft_variable
|
||||
from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport]
|
||||
|
||||
# Import snippet controllers
|
||||
from .snippets import snippet_workflow, snippet_workflow_draft_variable
|
||||
|
||||
# Import tag controllers
|
||||
from .tag import tags
|
||||
|
||||
@ -141,7 +130,6 @@ from .workspace import (
|
||||
model_providers,
|
||||
models,
|
||||
plugin,
|
||||
snippets,
|
||||
tool_providers,
|
||||
trigger_providers,
|
||||
workspace,
|
||||
@ -179,7 +167,6 @@ __all__ = [
|
||||
"datasource_content_preview",
|
||||
"email_register",
|
||||
"endpoint",
|
||||
"evaluation",
|
||||
"extension",
|
||||
"external",
|
||||
"feature",
|
||||
@ -214,13 +201,6 @@ __all__ = [
|
||||
"saved_message",
|
||||
"setup",
|
||||
"site",
|
||||
"snippet_workflow",
|
||||
"snippet_workflow_draft_variable",
|
||||
"snippets",
|
||||
"socketio_workflow",
|
||||
"snippet_workflow",
|
||||
"snippet_workflow_draft_variable",
|
||||
"snippets",
|
||||
"spec",
|
||||
"statistic",
|
||||
"tags",
|
||||
@ -231,7 +211,6 @@ __all__ = [
|
||||
"website",
|
||||
"workflow",
|
||||
"workflow_app_log",
|
||||
"workflow_comment",
|
||||
"workflow_draft_variable",
|
||||
"workflow_run",
|
||||
"workflow_statistic",
|
||||
|
||||
@ -34,7 +34,6 @@ from libs.helper import build_icon_url
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
from models.workflow import resolve_workflow_kind
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -330,8 +329,6 @@ class AppPartial(ResponseModel):
|
||||
create_user_name: str | None = None
|
||||
author_name: str | None = None
|
||||
has_draft_trigger: bool | None = None
|
||||
workflow_type: str | None = None
|
||||
workflow_kind: str | None = None
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
@ -366,8 +363,6 @@ class AppDetail(ResponseModel):
|
||||
updated_by: str | None = None
|
||||
updated_at: int | None = None
|
||||
access_mode: str | None = None
|
||||
workflow_type: str | None = None
|
||||
workflow_kind: str | None = None
|
||||
tags: list[Tag] = Field(default_factory=list)
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@ -510,25 +505,6 @@ class AppListApi(Resource):
|
||||
for app in app_pagination.items:
|
||||
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
||||
|
||||
workflow_ids = [str(app.workflow_id) for app in app_pagination.items if app.workflow_id]
|
||||
workflow_info_map: dict[str, tuple[str, str]] = {}
|
||||
if workflow_ids:
|
||||
rows = db.session.execute(
|
||||
select(Workflow.id, Workflow.type, Workflow.kind).where(Workflow.id.in_(workflow_ids))
|
||||
).all()
|
||||
workflow_info_map = {
|
||||
str(row.id): (
|
||||
row.type.value if hasattr(row.type, "value") else str(row.type),
|
||||
resolve_workflow_kind(row.kind).value,
|
||||
)
|
||||
for row in rows
|
||||
}
|
||||
|
||||
for app in app_pagination.items:
|
||||
workflow_info = workflow_info_map.get(str(app.workflow_id)) if app.workflow_id else None
|
||||
app.workflow_type = workflow_info[0] if workflow_info else None
|
||||
app.workflow_kind = workflow_info[1] if workflow_info else None
|
||||
|
||||
pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
|
||||
return pagination_model.model_dump(mode="json"), 200
|
||||
|
||||
@ -575,18 +551,6 @@ class AppApi(Resource):
|
||||
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
|
||||
app_model.access_mode = app_setting.access_mode
|
||||
|
||||
if app_model.workflow_id:
|
||||
row = db.session.execute(
|
||||
select(Workflow.type, Workflow.kind).where(Workflow.id == app_model.workflow_id)
|
||||
).first()
|
||||
app_model.workflow_type = (
|
||||
(row.type.value if hasattr(row.type, "value") else str(row.type)) if row else None
|
||||
)
|
||||
app_model.workflow_kind = resolve_workflow_kind(row.kind).value if row else None
|
||||
else:
|
||||
app_model.workflow_type = None
|
||||
app_model.workflow_kind = None
|
||||
|
||||
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
|
||||
return response_model.model_dump(mode="json")
|
||||
|
||||
|
||||
@ -2,37 +2,20 @@ from typing import Literal
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import selectinload
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
Conversation as ConversationResponse,
|
||||
)
|
||||
from fields.conversation_fields import (
|
||||
ConversationDetail as ConversationDetailResponse,
|
||||
)
|
||||
from fields.conversation_fields import (
|
||||
ConversationMessageDetail as ConversationMessageDetailResponse,
|
||||
)
|
||||
from fields.conversation_fields import (
|
||||
ConversationPagination as ConversationPaginationResponse,
|
||||
)
|
||||
from fields.conversation_fields import (
|
||||
ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse,
|
||||
)
|
||||
from fields.conversation_fields import (
|
||||
ResultResponse,
|
||||
)
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
from models.model import AppMode
|
||||
@ -79,16 +62,267 @@ console_ns.schema_model(
|
||||
ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
CompletionConversationQuery,
|
||||
ChatConversationQuery,
|
||||
ConversationResponse,
|
||||
ConversationPaginationResponse,
|
||||
ConversationMessageDetailResponse,
|
||||
ConversationWithSummaryPaginationResponse,
|
||||
ConversationDetailResponse,
|
||||
ResultResponse,
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model(
|
||||
"SimpleAccount",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"email": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
feedback_stat_model = console_ns.model(
|
||||
"FeedbackStat",
|
||||
{
|
||||
"like": fields.Integer,
|
||||
"dislike": fields.Integer,
|
||||
},
|
||||
)
|
||||
|
||||
status_count_model = console_ns.model(
|
||||
"StatusCount",
|
||||
{
|
||||
"success": fields.Integer,
|
||||
"failed": fields.Integer,
|
||||
"partial_success": fields.Integer,
|
||||
"paused": fields.Integer,
|
||||
},
|
||||
)
|
||||
|
||||
message_file_model = console_ns.model(
|
||||
"MessageFile",
|
||||
{
|
||||
"id": fields.String,
|
||||
"filename": fields.String,
|
||||
"type": fields.String,
|
||||
"url": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"size": fields.Integer,
|
||||
"transfer_method": fields.String,
|
||||
"belongs_to": fields.String(default="user"),
|
||||
"upload_file_id": fields.String(default=None),
|
||||
},
|
||||
)
|
||||
|
||||
agent_thought_model = console_ns.model(
|
||||
"AgentThought",
|
||||
{
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"files": fields.List(fields.String),
|
||||
},
|
||||
)
|
||||
|
||||
simple_model_config_model = console_ns.model(
|
||||
"SimpleModelConfig",
|
||||
{
|
||||
"model": fields.Raw(attribute="model_dict"),
|
||||
"pre_prompt": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
model_config_model = console_ns.model(
|
||||
"ModelConfig",
|
||||
{
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"model": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"pre_prompt": fields.String,
|
||||
"agent_mode": fields.Raw,
|
||||
},
|
||||
)
|
||||
|
||||
# Models that depend on simple_account_model
|
||||
feedback_model = console_ns.model(
|
||||
"Feedback",
|
||||
{
|
||||
"rating": fields.String,
|
||||
"content": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
},
|
||||
)
|
||||
|
||||
annotation_model = console_ns.model(
|
||||
"Annotation",
|
||||
{
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"content": fields.String,
|
||||
"account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
annotation_hit_history_model = console_ns.model(
|
||||
"AnnotationHitHistory",
|
||||
{
|
||||
"annotation_id": fields.String(attribute="id"),
|
||||
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class MessageTextField(fields.Raw):
|
||||
def format(self, value):
|
||||
return value[0]["text"] if value else ""
|
||||
|
||||
|
||||
# Simple message detail model
|
||||
simple_message_detail_model = console_ns.model(
|
||||
"SimpleMessageDetail",
|
||||
{
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": MessageTextField,
|
||||
"answer": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
# Message detail model that depends on multiple models
|
||||
message_detail_model = console_ns.model(
|
||||
"MessageDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": fields.Raw,
|
||||
"message_tokens": fields.Integer,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"answer_tokens": fields.Integer,
|
||||
"provider_response_latency": fields.Float,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"feedbacks": fields.List(fields.Nested(feedback_model)),
|
||||
"workflow_run_id": fields.String,
|
||||
"annotation": fields.Nested(annotation_model, allow_null=True),
|
||||
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
# Conversation models
|
||||
conversation_fields_model = console_ns.model(
|
||||
"Conversation",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_end_user_session_id": fields.String(),
|
||||
"from_account_id": fields.String,
|
||||
"from_account_name": fields.String,
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotation": fields.Nested(annotation_model, allow_null=True),
|
||||
"model_config": fields.Nested(simple_model_config_model),
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"message": fields.Nested(simple_message_detail_model, attribute="first_message"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_pagination_model = console_ns.model(
|
||||
"ConversationPagination",
|
||||
{
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(conversation_fields_model), attribute="items"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_message_detail_model = console_ns.model(
|
||||
"ConversationMessageDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"model_config": fields.Nested(model_config_model),
|
||||
"message": fields.Nested(message_detail_model, attribute="first_message"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_with_summary_model = console_ns.model(
|
||||
"ConversationWithSummary",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_end_user_session_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"from_account_name": fields.String,
|
||||
"name": fields.String,
|
||||
"summary": fields.String(attribute="summary_or_query"),
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"model_config": fields.Nested(simple_model_config_model),
|
||||
"message_count": fields.Integer,
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"status_count": fields.Nested(status_count_model),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_with_summary_pagination_model = console_ns.model(
|
||||
"ConversationWithSummaryPagination",
|
||||
{
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(conversation_with_summary_model), attribute="items"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_detail_model = console_ns.model(
|
||||
"ConversationDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"introduction": fields.String,
|
||||
"model_config": fields.Nested(model_config_model),
|
||||
"message_count": fields.Integer,
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -98,12 +332,13 @@ class CompletionConversationApi(Resource):
|
||||
@console_ns.doc(description="Get completion conversations with pagination and filtering")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[CompletionConversationQuery.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[ConversationPaginationResponse.__name__])
|
||||
@console_ns.response(200, "Success", conversation_pagination_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@marshal_with(conversation_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -159,9 +394,7 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
|
||||
|
||||
return ConversationPaginationResponse.model_validate(conversations, from_attributes=True).model_dump(
|
||||
mode="json"
|
||||
)
|
||||
return conversations
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
|
||||
@ -169,19 +402,19 @@ class CompletionConversationDetailApi(Resource):
|
||||
@console_ns.doc("get_completion_conversation")
|
||||
@console_ns.doc(description="Get completion conversation details with messages")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@console_ns.response(200, "Success", console_ns.models[ConversationMessageDetailResponse.__name__])
|
||||
@console_ns.response(200, "Success", conversation_message_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@marshal_with(conversation_message_detail_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id):
|
||||
conversation_id = str(conversation_id)
|
||||
return ConversationMessageDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id), from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
return _get_conversation(app_model, conversation_id)
|
||||
|
||||
@console_ns.doc("delete_completion_conversation")
|
||||
@console_ns.doc(description="Delete a completion conversation")
|
||||
@ -203,7 +436,7 @@ class CompletionConversationDetailApi(Resource):
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-conversations")
|
||||
@ -212,12 +445,13 @@ class ChatConversationApi(Resource):
|
||||
@console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[ChatConversationQuery.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[ConversationWithSummaryPaginationResponse.__name__])
|
||||
@console_ns.response(200, "Success", conversation_with_summary_pagination_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(conversation_with_summary_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -312,9 +546,7 @@ class ChatConversationApi(Resource):
|
||||
|
||||
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
|
||||
|
||||
return ConversationWithSummaryPaginationResponse.model_validate(conversations, from_attributes=True).model_dump(
|
||||
mode="json"
|
||||
)
|
||||
return conversations
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
|
||||
@ -322,19 +554,19 @@ class ChatConversationDetailApi(Resource):
|
||||
@console_ns.doc("get_chat_conversation")
|
||||
@console_ns.doc(description="Get chat conversation details")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@console_ns.response(200, "Success", console_ns.models[ConversationDetailResponse.__name__])
|
||||
@console_ns.response(200, "Success", conversation_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(conversation_detail_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id):
|
||||
conversation_id = str(conversation_id)
|
||||
return ConversationDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id), from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
return _get_conversation(app_model, conversation_id)
|
||||
|
||||
@console_ns.doc("delete_chat_conversation")
|
||||
@console_ns.doc(description="Delete a chat conversation")
|
||||
@ -356,7 +588,7 @@ class ChatConversationDetailApi(Resource):
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
|
||||
@ -45,7 +45,7 @@ class ConversationVariableResponse(ResponseModel):
|
||||
def _normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type())
|
||||
return str(exposed_type().value)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, func, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
@ -39,7 +40,6 @@ from fields.conversation_fields import (
|
||||
format_files_contained,
|
||||
to_timestamp,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
|
||||
@ -1,13 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from graphon.enums import NodeType
|
||||
from graphon.file import File
|
||||
from graphon.file import helpers as file_helpers
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
@ -15,7 +14,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DefaultBlockConfigQuery
|
||||
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.workflow_run import workflow_run_node_execution_model
|
||||
@ -40,7 +39,6 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.online_user_fields import online_user_list_fields
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -48,8 +46,7 @@ from libs.helper import TimestampField, uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow, WorkflowKind
|
||||
from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX
|
||||
from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -60,7 +57,6 @@ _file_access_controller = DatabaseFileAccessController()
|
||||
LISTENING_RETRY_IN = 2000
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
|
||||
MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS = 50
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
@ -154,31 +150,6 @@ class ConvertToWorkflowPayload(BaseModel):
|
||||
icon_background: str | None = None
|
||||
|
||||
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
keyword: str | None = Field(default=None, max_length=255)
|
||||
|
||||
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class WorkflowTypeConvertQuery(BaseModel):
|
||||
target_type: Literal["workflow", "evaluation"]
|
||||
|
||||
|
||||
class WorkflowFeaturesPayload(BaseModel):
|
||||
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
|
||||
|
||||
|
||||
class WorkflowOnlineUsersQuery(BaseModel):
|
||||
app_ids: str = Field(..., description="Comma-separated app IDs")
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||
node_id: str
|
||||
|
||||
@ -202,9 +173,6 @@ reg(DefaultBlockConfigQuery)
|
||||
reg(ConvertToWorkflowPayload)
|
||||
reg(WorkflowListQuery)
|
||||
reg(WorkflowUpdatePayload)
|
||||
reg(WorkflowTypeConvertQuery)
|
||||
reg(WorkflowFeaturesPayload)
|
||||
reg(WorkflowOnlineUsersQuery)
|
||||
reg(DraftWorkflowTriggerRunPayload)
|
||||
reg(DraftWorkflowTriggerRunAllPayload)
|
||||
|
||||
@ -877,54 +845,6 @@ class PublishedWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/publish/evaluation")
|
||||
class EvaluationPublishedWorkflowApi(Resource):
|
||||
@console_ns.doc("publish_evaluation_workflow")
|
||||
@console_ns.doc(description="Publish draft workflow as evaluation workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
|
||||
@console_ns.response(200, "Evaluation workflow published successfully")
|
||||
@console_ns.response(400, "Invalid workflow or unsupported node type")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Publish draft workflow as evaluation workflow.
|
||||
|
||||
Evaluation workflows cannot include trigger or human-input nodes.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
workflow = workflow_service.publish_evaluation_workflow(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
marked_name=args.marked_name or "",
|
||||
marked_comment=args.marked_comment or "",
|
||||
)
|
||||
|
||||
# Keep workflow_id aligned with the latest published workflow.
|
||||
app_model_in_session = session.get(App, app_model.id)
|
||||
if app_model_in_session:
|
||||
app_model_in_session.workflow_id = workflow.id
|
||||
app_model_in_session.updated_by = current_user.id
|
||||
app_model_in_session.updated_at = naive_utc_now()
|
||||
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"created_at": workflow_created_at,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
|
||||
class DefaultBlockConfigsApi(Resource):
|
||||
@console_ns.doc("get_default_block_configs")
|
||||
@ -1011,32 +931,6 @@ class ConvertToWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/features")
|
||||
class WorkflowFeaturesApi(Resource):
|
||||
"""Update draft workflow features."""
|
||||
|
||||
@console_ns.expect(console_ns.models[WorkflowFeaturesPayload.__name__])
|
||||
@console_ns.doc("update_workflow_features")
|
||||
@console_ns.doc(description="Update draft workflow features")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Workflow features updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowFeaturesPayload.model_validate(console_ns.payload or {})
|
||||
features = args.features
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
|
||||
@ -1122,52 +1016,6 @@ class DraftWorkflowRestoreApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/convert-type")
|
||||
class WorkflowTypeConvertApi(Resource):
|
||||
@console_ns.doc("convert_published_workflow_type")
|
||||
@console_ns.doc(description="Convert current effective published workflow type in-place")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowTypeConvertQuery.__name__])
|
||||
@console_ns.response(200, "Workflow type converted successfully")
|
||||
@console_ns.response(400, "Invalid workflow type or unsupported workflow graph")
|
||||
@console_ns.response(404, "Workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = WorkflowTypeConvertQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
target_type = WorkflowKind.EVALUATION if args.target_type == "evaluation" else WorkflowKind.STANDARD
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
workflow = workflow_service.convert_published_workflow_type(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
target_type=target_type,
|
||||
account=current_user,
|
||||
)
|
||||
except WorkflowNotFoundError as exc:
|
||||
raise NotFound(str(exc)) from exc
|
||||
except IsDraftWorkflowError as exc:
|
||||
raise BadRequest(str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise BadRequest(str(exc)) from exc
|
||||
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"workflow_id": workflow.id,
|
||||
"type": workflow.type.value,
|
||||
"kind": workflow.kind_or_standard,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
|
||||
class WorkflowByIdApi(Resource):
|
||||
@console_ns.doc("update_workflow_by_id")
|
||||
@ -1492,62 +1340,3 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
"status": "error",
|
||||
}
|
||||
), 400
|
||||
|
||||
|
||||
@console_ns.route("/apps/workflows/online-users")
|
||||
class WorkflowOnlineUsersApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__])
|
||||
@console_ns.doc("get_workflow_online_users")
|
||||
@console_ns.doc(description="Get workflow online users")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(online_user_list_fields)
|
||||
def get(self):
|
||||
args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
app_ids = list(dict.fromkeys(app_id.strip() for app_id in args.app_ids.split(",") if app_id.strip()))
|
||||
if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS:
|
||||
raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS} app_ids are allowed per request.")
|
||||
|
||||
if not app_ids:
|
||||
return {"data": []}
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
accessible_app_ids = workflow_service.get_accessible_app_ids(app_ids, current_tenant_id)
|
||||
|
||||
results = []
|
||||
for app_id in app_ids:
|
||||
if app_id not in accessible_app_ids:
|
||||
continue
|
||||
|
||||
users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}")
|
||||
|
||||
users = []
|
||||
for _, user_info_json in users_json.items():
|
||||
try:
|
||||
user_info = json.loads(user_info_json)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not isinstance(user_info, dict):
|
||||
continue
|
||||
|
||||
avatar = user_info.get("avatar")
|
||||
if isinstance(avatar, str) and avatar and not avatar.startswith(("http://", "https://")):
|
||||
try:
|
||||
user_info["avatar"] = file_helpers.get_signed_file_url(avatar)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to sign workflow online user avatar; using original value. "
|
||||
"app_id=%s avatar=%s error=%s",
|
||||
app_id,
|
||||
avatar,
|
||||
exc,
|
||||
)
|
||||
|
||||
users.append(user_info)
|
||||
results.append({"app_id": app_id, "users": users})
|
||||
|
||||
return {"data": results}
|
||||
|
||||
@ -1,335 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.member_fields import AccountWithRole
|
||||
from fields.workflow_comment_fields import (
|
||||
workflow_comment_basic_fields,
|
||||
workflow_comment_create_fields,
|
||||
workflow_comment_detail_fields,
|
||||
workflow_comment_reply_create_fields,
|
||||
workflow_comment_reply_update_fields,
|
||||
workflow_comment_resolve_fields,
|
||||
workflow_comment_update_fields,
|
||||
)
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowCommentCreatePayload(BaseModel):
|
||||
content: str = Field(..., description="Comment content")
|
||||
position_x: float = Field(..., description="Comment X position")
|
||||
position_y: float = Field(..., description="Comment Y position")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Comment content")
|
||||
position_x: float | None = Field(default=None, description="Comment X position")
|
||||
position_y: float | None = Field(default=None, description="Comment Y position")
|
||||
mentioned_user_ids: list[str] | None = Field(
|
||||
default=None,
|
||||
description="Mentioned user IDs. Omit to keep existing mentions.",
|
||||
)
|
||||
|
||||
|
||||
class WorkflowCommentReplyPayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentMentionUsersPayload(BaseModel):
|
||||
users: list[AccountWithRole]
|
||||
|
||||
|
||||
for model in (
|
||||
WorkflowCommentCreatePayload,
|
||||
WorkflowCommentUpdatePayload,
|
||||
WorkflowCommentReplyPayload,
|
||||
):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
register_schema_models(console_ns, AccountWithRole, WorkflowCommentMentionUsersPayload)
|
||||
|
||||
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
|
||||
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
|
||||
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
|
||||
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
|
||||
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
|
||||
workflow_comment_reply_create_model = console_ns.model(
|
||||
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
|
||||
)
|
||||
workflow_comment_reply_update_model = console_ns.model(
|
||||
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
|
||||
class WorkflowCommentListApi(Resource):
|
||||
"""API for listing and creating workflow comments."""
|
||||
|
||||
@console_ns.doc("list_workflow_comments")
|
||||
@console_ns.doc(description="Get all comments for a workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_basic_model, envelope="data")
|
||||
def get(self, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return comments
|
||||
|
||||
@console_ns.doc("create_workflow_comment")
|
||||
@console_ns.doc(description="Create a new workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
|
||||
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_create_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
|
||||
class WorkflowCommentDetailApi(Resource):
|
||||
"""API for managing individual workflow comments."""
|
||||
|
||||
@console_ns.doc("get_workflow_comment")
|
||||
@console_ns.doc(description="Get a specific workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_detail_model)
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
@console_ns.doc("update_workflow_comment")
|
||||
@console_ns.doc(description="Update a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_update_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.update_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@console_ns.doc("delete_workflow_comment")
|
||||
@console_ns.doc(description="Delete a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(204, "Comment deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, comment_id: str):
|
||||
"""Delete a workflow comment."""
|
||||
WorkflowCommentService.delete_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
|
||||
class WorkflowCommentResolveApi(Resource):
|
||||
"""API for resolving and reopening workflow comments."""
|
||||
|
||||
@console_ns.doc("resolve_workflow_comment")
|
||||
@console_ns.doc(description="Resolve a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_resolve_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
comment = WorkflowCommentService.resolve_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
|
||||
class WorkflowCommentReplyApi(Resource):
|
||||
"""API for managing comment replies."""
|
||||
|
||||
@console_ns.doc("create_workflow_comment_reply")
|
||||
@console_ns.doc(description="Add a reply to a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__])
|
||||
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_create_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_reply(
|
||||
comment_id=comment_id,
|
||||
content=payload.content,
|
||||
created_by=current_user.id,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
|
||||
class WorkflowCommentReplyDetailApi(Resource):
|
||||
"""API for managing individual comment replies."""
|
||||
|
||||
@console_ns.doc("update_workflow_comment_reply")
|
||||
@console_ns.doc(description="Update a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__])
|
||||
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_update_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
reply = WorkflowCommentService.update_reply(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
reply_id=reply_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return reply
|
||||
|
||||
@console_ns.doc("delete_workflow_comment_reply")
|
||||
@console_ns.doc(description="Delete a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.response(204, "Reply deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Delete a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
WorkflowCommentService.delete_reply(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
reply_id=reply_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
|
||||
class WorkflowCommentMentionUsersApi(Resource):
|
||||
"""API for getting mentionable users for workflow comments."""
|
||||
|
||||
@console_ns.doc("workflow_comment_mention_users")
|
||||
@console_ns.doc(description="Get all users in current tenant for mentions")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(
|
||||
200, "Mentionable users retrieved successfully", console_ns.models[WorkflowCommentMentionUsersPayload.__name__]
|
||||
)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
"""Get all users in current tenant for mentions."""
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = WorkflowCommentMentionUsersPayload(users=users)
|
||||
return response.model_dump(mode="json"), 200
|
||||
@ -5,6 +5,10 @@ from typing import Any, TypedDict
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from graphon.file import helpers as file_helpers
|
||||
from graphon.variables.segment_group import SegmentGroup
|
||||
from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from graphon.variables.types import SegmentType
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@ -18,13 +22,8 @@ from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from graphon.file import helpers as file_helpers
|
||||
from graphon.variables.segment_group import SegmentGroup
|
||||
from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from graphon.variables.types import SegmentType
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
@ -46,16 +45,6 @@ class WorkflowDraftVariableUpdatePayload(BaseModel):
|
||||
value: Any | None = Field(default=None, description="Variable value")
|
||||
|
||||
|
||||
class ConversationVariableUpdatePayload(BaseModel):
|
||||
conversation_variables: list[dict[str, Any]] = Field(
|
||||
..., description="Conversation variables for the draft workflow"
|
||||
)
|
||||
|
||||
|
||||
class EnvironmentVariableUpdatePayload(BaseModel):
|
||||
environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowDraftVariableListQuery.__name__,
|
||||
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
@ -64,14 +53,6 @@ console_ns.schema_model(
|
||||
WorkflowDraftVariableUpdatePayload.__name__,
|
||||
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ConversationVariableUpdatePayload.__name__,
|
||||
ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
EnvironmentVariableUpdatePayload.__name__,
|
||||
EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment):
|
||||
@ -102,7 +83,7 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
|
||||
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
||||
value_type = workflow_draft_var.value_type
|
||||
return str(value_type.exposed_type())
|
||||
return value_type.exposed_type().value
|
||||
|
||||
|
||||
class FullContentDict(TypedDict):
|
||||
@ -122,7 +103,7 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
|
||||
|
||||
result: FullContentDict = {
|
||||
"size_bytes": variable_file.size,
|
||||
"value_type": str(variable_file.value_type.exposed_type()),
|
||||
"value_type": variable_file.value_type.exposed_type().value,
|
||||
"length": variable_file.length,
|
||||
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
|
||||
}
|
||||
@ -529,34 +510,6 @@ class ConversationVariableCollectionApi(Resource):
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
|
||||
@console_ns.expect(console_ns.models[ConversationVariableUpdatePayload.__name__])
|
||||
@console_ns.doc("update_conversation_variables")
|
||||
@console_ns.doc(description="Update conversation variables for workflow draft")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Conversation variables updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def post(self, app_model: App):
|
||||
payload = ConversationVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
conversation_variables_list = payload.conversation_variables
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_conversation_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@ -598,7 +551,7 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": str(v.value_type.exposed_type()),
|
||||
"value_type": v.value_type.exposed_type().value,
|
||||
"value": v.value,
|
||||
# Do not track edited for env vars.
|
||||
"edited": False,
|
||||
@ -608,31 +561,3 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
@console_ns.expect(console_ns.models[EnvironmentVariableUpdatePayload.__name__])
|
||||
@console_ns.doc("update_environment_variables")
|
||||
@console_ns.doc(description="Update environment variables for workflow draft")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Environment variables updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
payload = EnvironmentVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
environment_variables_list = payload.environment_variables
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_environment_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
import base64
|
||||
import json
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
@ -12,7 +10,6 @@ from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
@ -80,39 +77,3 @@ class PartnerTenants(Resource):
|
||||
raise BadRequest("Invalid partner information")
|
||||
|
||||
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
|
||||
|
||||
|
||||
_DEBUG_KEY = "billing:debug"
|
||||
_DEBUG_TTL = timedelta(days=7)
|
||||
|
||||
|
||||
class DebugDataPayload(BaseModel):
|
||||
type: str = Field(..., min_length=1, description="Data type key")
|
||||
data: str = Field(..., min_length=1, description="Data value to append")
|
||||
|
||||
|
||||
@console_ns.route("/billing/debug/data")
|
||||
class DebugData(Resource):
|
||||
def post(self):
|
||||
body = DebugDataPayload.model_validate(request.get_json(force=True))
|
||||
item = json.dumps({
|
||||
"type": body.type,
|
||||
"data": body.data,
|
||||
"createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
|
||||
})
|
||||
redis_client.lpush(_DEBUG_KEY, item)
|
||||
redis_client.expire(_DEBUG_KEY, _DEBUG_TTL)
|
||||
return {"result": "ok"}, 201
|
||||
|
||||
def get(self):
|
||||
recent = request.args.get("recent", 10, type=int)
|
||||
items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1)
|
||||
return {
|
||||
"data": [
|
||||
json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items
|
||||
]
|
||||
}
|
||||
|
||||
def delete(self):
|
||||
redis_client.delete(_DEBUG_KEY)
|
||||
return {"result": "ok"}
|
||||
|
||||
@ -1,14 +1,11 @@
|
||||
import json
|
||||
from typing import Any, cast
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
@ -25,7 +22,6 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -34,7 +30,6 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from fields.app_fields import app_detail_kernel_fields, related_app_list
|
||||
from fields.dataset_fields import (
|
||||
content_fields,
|
||||
@ -55,19 +50,12 @@ from fields.dataset_fields import (
|
||||
)
|
||||
from fields.document_fields import document_status_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, EvaluationRun, EvaluationTargetType, UploadFile
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermission, DatasetPermissionEnum
|
||||
from models.enums import ApiTokenType, SegmentStatus
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.api_token_service import ApiTokenCache
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
from services.errors.evaluation import (
|
||||
EvaluationDatasetInvalidError,
|
||||
EvaluationFrameworkNotConfiguredError,
|
||||
EvaluationMaxConcurrentRunsError,
|
||||
EvaluationNotFoundError,
|
||||
)
|
||||
from services.evaluation_service import EvaluationService
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
|
||||
@ -995,432 +983,3 @@ class DatasetAutoDisableLogApi(Resource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
|
||||
|
||||
|
||||
# ---- Knowledge Base Retrieval Evaluation ----
|
||||
|
||||
|
||||
def _serialize_dataset_evaluation_run(run: EvaluationRun) -> dict[str, Any]:
|
||||
return {
|
||||
"id": run.id,
|
||||
"tenant_id": run.tenant_id,
|
||||
"target_type": run.target_type,
|
||||
"target_id": run.target_id,
|
||||
"evaluation_config_id": run.evaluation_config_id,
|
||||
"status": run.status,
|
||||
"dataset_file_id": run.dataset_file_id,
|
||||
"result_file_id": run.result_file_id,
|
||||
"total_items": run.total_items,
|
||||
"completed_items": run.completed_items,
|
||||
"failed_items": run.failed_items,
|
||||
"progress": run.progress,
|
||||
"metrics_summary": json.loads(run.metrics_summary) if run.metrics_summary else {},
|
||||
"error": run.error,
|
||||
"created_by": run.created_by,
|
||||
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
|
||||
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
|
||||
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_dataset_evaluation_run_item(item: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"id": item.id,
|
||||
"item_index": item.item_index,
|
||||
"inputs": item.inputs_dict,
|
||||
"expected_output": item.expected_output,
|
||||
"actual_output": item.actual_output,
|
||||
"metrics": item.metrics_list,
|
||||
"judgment": item.judgment_dict,
|
||||
"metadata": item.metadata_dict,
|
||||
"error": item.error,
|
||||
"overall_score": item.overall_score,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/template/download")
|
||||
class DatasetEvaluationTemplateDownloadApi(Resource):
|
||||
@console_ns.doc("download_dataset_evaluation_template")
|
||||
@console_ns.response(200, "Template file streamed as XLSX attachment")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
"""Download evaluation dataset template for knowledge base retrieval."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
xlsx_content, filename = EvaluationService.generate_retrieval_dataset_template()
|
||||
encoded_filename = quote(filename)
|
||||
response = Response(
|
||||
xlsx_content,
|
||||
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Length"] = str(len(xlsx_content))
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation")
|
||||
class DatasetEvaluationDetailApi(Resource):
|
||||
@console_ns.doc("get_dataset_evaluation_config")
|
||||
@console_ns.response(200, "Evaluation configuration retrieved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
"""Get evaluation configuration for the knowledge base."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.get_evaluation_config(
|
||||
session, current_tenant_id, "dataset", dataset_id_str
|
||||
)
|
||||
|
||||
if config is None:
|
||||
return {
|
||||
"evaluation_model": None,
|
||||
"evaluation_model_provider": None,
|
||||
"default_metrics": None,
|
||||
"customized_metrics": None,
|
||||
"judgment_config": None,
|
||||
}
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"default_metrics": config.default_metrics_list,
|
||||
"customized_metrics": config.customized_metrics_dict,
|
||||
"judgment_config": config.judgment_config_dict,
|
||||
}
|
||||
|
||||
@console_ns.doc("save_dataset_evaluation_config")
|
||||
@console_ns.response(200, "Evaluation configuration saved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, dataset_id):
|
||||
"""Save evaluation configuration for the knowledge base."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
body = request.get_json(force=True)
|
||||
try:
|
||||
config_data = EvaluationConfigData.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.save_evaluation_config(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type="dataset",
|
||||
target_id=dataset_id_str,
|
||||
account_id=str(current_user.id),
|
||||
data=config_data,
|
||||
)
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"default_metrics": config.default_metrics_list,
|
||||
"customized_metrics": config.customized_metrics_dict,
|
||||
"judgment_config": config.judgment_config_dict,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/run")
|
||||
class DatasetEvaluationRunApi(Resource):
|
||||
@console_ns.doc("start_dataset_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run started")
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
"""Start an evaluation run for the knowledge base retrieval."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
body = request.get_json(force=True)
|
||||
if not body:
|
||||
raise BadRequest("Request body is required.")
|
||||
|
||||
try:
|
||||
run_request = EvaluationRunRequest.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
upload_file = (
|
||||
db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=current_tenant_id).first()
|
||||
)
|
||||
if not upload_file:
|
||||
raise NotFound("Dataset file not found.")
|
||||
|
||||
try:
|
||||
dataset_content = storage.load_once(upload_file.key)
|
||||
except Exception:
|
||||
raise BadRequest("Failed to read dataset file.")
|
||||
|
||||
if not dataset_content:
|
||||
raise BadRequest("Dataset file is empty.")
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
evaluation_run = EvaluationService.start_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=EvaluationTargetType.KNOWLEDGE_BASE,
|
||||
target_id=dataset_id_str,
|
||||
account_id=str(current_user.id),
|
||||
dataset_file_content=dataset_content,
|
||||
run_request=run_request,
|
||||
)
|
||||
return _serialize_dataset_evaluation_run(evaluation_run), 200
|
||||
except EvaluationFrameworkNotConfiguredError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except EvaluationMaxConcurrentRunsError as e:
|
||||
return {"message": str(e.description)}, 429
|
||||
except EvaluationDatasetInvalidError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/logs")
|
||||
class DatasetEvaluationLogsApi(Resource):
|
||||
@console_ns.doc("get_dataset_evaluation_logs")
|
||||
@console_ns.response(200, "Evaluation logs retrieved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
"""Get evaluation run history for the knowledge base."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 20, type=int)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
runs, total = EvaluationService.get_evaluation_runs(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type="dataset",
|
||||
target_id=dataset_id_str,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return {
|
||||
"data": [_serialize_dataset_evaluation_run(run) for run in runs],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>")
|
||||
class DatasetEvaluationRunDetailApi(Resource):
|
||||
@console_ns.doc("get_dataset_evaluation_run_detail")
|
||||
@console_ns.response(200, "Evaluation run detail retrieved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset or run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, run_id):
|
||||
"""Get evaluation run detail including per-item results."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
run_id_str = str(run_id)
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 50, type=int)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.get_evaluation_run_detail(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id_str,
|
||||
)
|
||||
items, total_items = EvaluationService.get_evaluation_run_items(
|
||||
session=session,
|
||||
run_id=run_id_str,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return {
|
||||
"run": _serialize_dataset_evaluation_run(run),
|
||||
"items": {
|
||||
"data": [_serialize_dataset_evaluation_run_item(item) for item in items],
|
||||
"total": total_items,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
},
|
||||
}
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>/cancel")
|
||||
class DatasetEvaluationRunCancelApi(Resource):
|
||||
@console_ns.doc("cancel_dataset_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run cancelled")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset or run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id, run_id):
|
||||
"""Cancel a running knowledge base evaluation."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
run_id_str = str(run_id)
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.cancel_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id_str,
|
||||
)
|
||||
return _serialize_dataset_evaluation_run(run)
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/metrics")
|
||||
class DatasetEvaluationMetricsApi(Resource):
|
||||
@console_ns.doc("get_dataset_evaluation_metrics")
|
||||
@console_ns.response(200, "Available retrieval metrics retrieved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
"""Get available evaluation metrics for knowledge base retrieval."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
return {
|
||||
"metrics": EvaluationService.get_supported_metrics(EvaluationCategory.KNOWLEDGE_BASE)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/files/<uuid:file_id>")
|
||||
class DatasetEvaluationFileDownloadApi(Resource):
|
||||
@console_ns.doc("download_dataset_evaluation_file")
|
||||
@console_ns.response(200, "File download URL generated")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset or file not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, file_id):
|
||||
"""Download evaluation test file or result file for the knowledge base."""
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
file_id_str = str(file_id)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == file_id_str,
|
||||
UploadFile.tenant_id == current_tenant_id,
|
||||
)
|
||||
upload_file = session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found.")
|
||||
|
||||
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
|
||||
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
|
||||
"download_url": download_url,
|
||||
}
|
||||
|
||||
@ -1 +0,0 @@
|
||||
# Evaluation controller module
|
||||
@ -1,871 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, ParamSpec, TypeVar, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.workflow import WorkflowListQuery
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from fields.member_fields import simple_account_fields
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, Dataset
|
||||
from models.model import UploadFile
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.errors.evaluation import (
|
||||
EvaluationDatasetInvalidError,
|
||||
EvaluationFrameworkNotConfiguredError,
|
||||
EvaluationMaxConcurrentRunsError,
|
||||
EvaluationNotFoundError,
|
||||
)
|
||||
from services.evaluation_service import EvaluationService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.evaluation import EvaluationRun, EvaluationRunItem
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
# Valid evaluation target types
|
||||
EVALUATE_TARGET_TYPES = {"app", "snippets"}
|
||||
|
||||
|
||||
class VersionQuery(BaseModel):
|
||||
"""Query parameters for version endpoint."""
|
||||
|
||||
version: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
VersionQuery,
|
||||
)
|
||||
|
||||
|
||||
# Response field definitions
|
||||
file_info_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
}
|
||||
|
||||
evaluation_log_fields = {
|
||||
"created_at": TimestampField,
|
||||
"created_by": fields.String,
|
||||
"test_file": fields.Nested(
|
||||
console_ns.model(
|
||||
"EvaluationTestFile",
|
||||
file_info_fields,
|
||||
)
|
||||
),
|
||||
"result_file": fields.Nested(
|
||||
console_ns.model(
|
||||
"EvaluationResultFile",
|
||||
file_info_fields,
|
||||
),
|
||||
allow_null=True,
|
||||
),
|
||||
"version": fields.String,
|
||||
}
|
||||
|
||||
evaluation_log_list_model = console_ns.model(
|
||||
"EvaluationLogList",
|
||||
{
|
||||
"data": fields.List(fields.Nested(console_ns.model("EvaluationLog", evaluation_log_fields))),
|
||||
},
|
||||
)
|
||||
|
||||
evaluation_default_metric_node_info_fields = {
|
||||
"node_id": fields.String,
|
||||
"type": fields.String,
|
||||
"title": fields.String,
|
||||
}
|
||||
evaluation_default_metric_item_fields = {
|
||||
"metric": fields.String,
|
||||
"value_type": fields.String,
|
||||
"node_info_list": fields.List(
|
||||
fields.Nested(
|
||||
console_ns.model("EvaluationDefaultMetricNodeInfo", evaluation_default_metric_node_info_fields),
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
customized_metrics_fields = {
|
||||
"evaluation_workflow_id": fields.String,
|
||||
"input_fields": fields.Raw,
|
||||
"output_fields": fields.Raw,
|
||||
}
|
||||
|
||||
judgment_condition_fields = {
|
||||
"variable_selector": fields.List(fields.String),
|
||||
"comparison_operator": fields.String,
|
||||
"value": fields.String,
|
||||
}
|
||||
|
||||
judgment_config_fields = {
|
||||
"logical_operator": fields.String,
|
||||
"conditions": fields.List(fields.Nested(console_ns.model("JudgmentCondition", judgment_condition_fields))),
|
||||
}
|
||||
|
||||
evaluation_detail_fields = {
|
||||
"evaluation_model": fields.String,
|
||||
"evaluation_model_provider": fields.String,
|
||||
"default_metrics": fields.List(
|
||||
fields.Nested(console_ns.model("EvaluationDefaultMetricItem_Detail", evaluation_default_metric_item_fields)),
|
||||
allow_null=True,
|
||||
),
|
||||
"customized_metrics": fields.Nested(
|
||||
console_ns.model("EvaluationCustomizedMetrics", customized_metrics_fields),
|
||||
allow_null=True,
|
||||
),
|
||||
"judgment_config": fields.Nested(
|
||||
console_ns.model("EvaluationJudgmentConfig", judgment_config_fields),
|
||||
allow_null=True,
|
||||
),
|
||||
}
|
||||
|
||||
evaluation_detail_model = console_ns.model("EvaluationDetail", evaluation_detail_fields)
|
||||
|
||||
available_evaluation_workflow_list_fields = {
|
||||
"id": fields.String,
|
||||
"app_id": fields.String,
|
||||
"app_name": fields.String,
|
||||
"type": fields.String,
|
||||
"kind": fields.String,
|
||||
"version": fields.String,
|
||||
"marked_name": fields.String,
|
||||
"marked_comment": fields.String,
|
||||
"hash": fields.String,
|
||||
"created_by": fields.Nested(simple_account_fields),
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.Nested(simple_account_fields, allow_null=True),
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
available_evaluation_workflow_pagination_fields = {
|
||||
"items": fields.List(fields.Nested(available_evaluation_workflow_list_fields)),
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
}
|
||||
|
||||
available_evaluation_workflow_pagination_model = console_ns.model(
|
||||
"AvailableEvaluationWorkflowPagination",
|
||||
available_evaluation_workflow_pagination_fields,
|
||||
)
|
||||
|
||||
evaluation_default_metrics_response_model = console_ns.model(
|
||||
"EvaluationDefaultMetricsResponse",
|
||||
{
|
||||
"default_metrics": fields.List(
|
||||
fields.Nested(console_ns.model("EvaluationDefaultMetricItem", evaluation_default_metric_item_fields)),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_evaluation_target(view_func: Callable[P, R]):
|
||||
"""
|
||||
Decorator to resolve polymorphic evaluation target (app or snippet).
|
||||
|
||||
Validates the target_type parameter and fetches the corresponding
|
||||
model (App or CustomizedSnippet) with tenant isolation.
|
||||
"""
|
||||
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
target_type = kwargs.get("evaluate_target_type")
|
||||
target_id = kwargs.get("evaluate_target_id")
|
||||
|
||||
if target_type not in EVALUATE_TARGET_TYPES:
|
||||
raise NotFound(f"Invalid evaluation target type: {target_type}")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
target_id = str(target_id)
|
||||
|
||||
# Remove path parameters
|
||||
del kwargs["evaluate_target_type"]
|
||||
del kwargs["evaluate_target_id"]
|
||||
|
||||
target: Union[App, CustomizedSnippet, Dataset] | None = None
|
||||
|
||||
if target_type == "app":
|
||||
target = db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first()
|
||||
elif target_type == "snippets":
|
||||
target = (
|
||||
db.session.query(CustomizedSnippet)
|
||||
.where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
elif target_type == "knowledge":
|
||||
target = (db.session.query(Dataset)
|
||||
.where(Dataset.id == target_id, Dataset.tenant_id == current_tenant_id)
|
||||
.first())
|
||||
|
||||
if not target:
|
||||
raise NotFound(f"{str(target_type)} not found")
|
||||
|
||||
kwargs["target"] = target
|
||||
kwargs["target_type"] = target_type
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/dataset-template/download")
|
||||
class EvaluationDatasetTemplateDownloadApi(Resource):
|
||||
@console_ns.doc("download_evaluation_dataset_template")
|
||||
@console_ns.response(200, "Template file streamed as XLSX attachment")
|
||||
@console_ns.response(400, "Invalid target type or excluded app mode")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Download evaluation dataset template.
|
||||
|
||||
Generates an XLSX template based on the target's input parameters
|
||||
and streams it directly as a file attachment.
|
||||
"""
|
||||
try:
|
||||
xlsx_content, filename = EvaluationService.generate_dataset_template(
|
||||
target=target,
|
||||
target_type=target_type,
|
||||
)
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
encoded_filename = quote(filename)
|
||||
response = Response(
|
||||
xlsx_content,
|
||||
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Length"] = str(len(xlsx_content))
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation")
|
||||
class EvaluationDetailApi(Resource):
|
||||
@console_ns.doc("get_evaluation_detail")
|
||||
@console_ns.response(200, "Evaluation details retrieved successfully", evaluation_detail_model)
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get evaluation configuration for the target.
|
||||
|
||||
Returns evaluation configuration including model settings,
|
||||
metrics config, and judgement conditions.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.get_evaluation_config(session, current_tenant_id, target_type, str(target.id))
|
||||
|
||||
if config is None:
|
||||
return {
|
||||
"evaluation_model": None,
|
||||
"evaluation_model_provider": None,
|
||||
"default_metrics": None,
|
||||
"customized_metrics": None,
|
||||
"judgment_config": None,
|
||||
}
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"default_metrics": config.default_metrics_list,
|
||||
"customized_metrics": config.customized_metrics_dict,
|
||||
"judgment_config": config.judgment_config_dict,
|
||||
}
|
||||
|
||||
@console_ns.doc("save_evaluation_detail")
|
||||
@console_ns.response(200, "Evaluation configuration saved successfully")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def put(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Save evaluation configuration for the target.
|
||||
"""
|
||||
current_account, current_tenant_id = current_account_with_tenant()
|
||||
body = request.get_json(force=True)
|
||||
|
||||
try:
|
||||
config_data = EvaluationConfigData.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.save_evaluation_config(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
account_id=str(current_account.id),
|
||||
data=config_data,
|
||||
)
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"default_metrics": config.default_metrics_list,
|
||||
"customized_metrics": config.customized_metrics_dict,
|
||||
"judgment_config": config.judgment_config_dict,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/logs")
|
||||
class EvaluationLogsApi(Resource):
|
||||
@console_ns.doc("get_evaluation_logs")
|
||||
@console_ns.response(200, "Evaluation logs retrieved successfully")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get evaluation run history for the target.
|
||||
|
||||
Returns a paginated list of evaluation runs.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 20, type=int)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
runs, total = EvaluationService.get_evaluation_runs(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return {
|
||||
"data": [_serialize_evaluation_run(run) for run in runs],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/run")
|
||||
class EvaluationRunApi(Resource):
|
||||
@console_ns.doc("start_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run started")
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet, Dataset], target_type: str):
|
||||
"""
|
||||
Start an evaluation run.
|
||||
|
||||
Expects JSON body with:
|
||||
- file_id: uploaded dataset file ID
|
||||
- evaluation_model: evaluation model name
|
||||
- evaluation_model_provider: evaluation model provider
|
||||
- default_metrics: list of default metric objects
|
||||
- customized_metrics: customized metrics object (optional)
|
||||
- judgment_config: judgment conditions config (optional)
|
||||
"""
|
||||
current_account, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
body = request.get_json(force=True)
|
||||
if not body:
|
||||
raise BadRequest("Request body is required.")
|
||||
|
||||
# Validate and parse request body
|
||||
try:
|
||||
run_request = EvaluationRunRequest.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
# Load dataset file
|
||||
upload_file = (
|
||||
db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=current_tenant_id).first()
|
||||
)
|
||||
if not upload_file:
|
||||
raise NotFound("Dataset file not found.")
|
||||
|
||||
try:
|
||||
dataset_content = storage.load_once(upload_file.key)
|
||||
except Exception:
|
||||
raise BadRequest("Failed to read dataset file.")
|
||||
|
||||
if not dataset_content:
|
||||
raise BadRequest("Dataset file is empty.")
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
evaluation_run = EvaluationService.start_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
account_id=str(current_account.id),
|
||||
dataset_file_content=dataset_content,
|
||||
run_request=run_request,
|
||||
)
|
||||
return _serialize_evaluation_run(evaluation_run), 200
|
||||
except EvaluationFrameworkNotConfiguredError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except EvaluationMaxConcurrentRunsError as e:
|
||||
return {"message": str(e.description)}, 429
|
||||
except EvaluationDatasetInvalidError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>")
|
||||
class EvaluationRunDetailApi(Resource):
|
||||
@console_ns.doc("get_evaluation_run_detail")
|
||||
@console_ns.response(200, "Evaluation run detail retrieved")
|
||||
@console_ns.response(404, "Run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
|
||||
"""
|
||||
Get evaluation run detail including items.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
run_id = str(run_id)
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 50, type=int)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.get_evaluation_run_detail(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
items, total_items = EvaluationService.get_evaluation_run_items(
|
||||
session=session,
|
||||
run_id=run_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return {
|
||||
"run": _serialize_evaluation_run(run),
|
||||
"items": {
|
||||
"data": [_serialize_evaluation_run_item(item) for item in items],
|
||||
"total": total_items,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
},
|
||||
}
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>/cancel")
|
||||
class EvaluationRunCancelApi(Resource):
|
||||
@console_ns.doc("cancel_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run cancelled")
|
||||
@console_ns.response(404, "Run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
|
||||
"""Cancel a running evaluation."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
run_id = str(run_id)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.cancel_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
return _serialize_evaluation_run(run)
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/metrics")
|
||||
class EvaluationMetricsApi(Resource):
|
||||
@console_ns.doc("get_evaluation_metrics")
|
||||
@console_ns.response(200, "Available metrics retrieved")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get available evaluation metrics for the current framework.
|
||||
"""
|
||||
result = {}
|
||||
for category in EvaluationCategory:
|
||||
result[category.value] = EvaluationService.get_supported_metrics(category)
|
||||
return {"metrics": result}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/default-metrics")
|
||||
class EvaluationDefaultMetricsApi(Resource):
|
||||
@console_ns.doc(
|
||||
"get_evaluation_default_metrics_with_nodes",
|
||||
description=(
|
||||
"List default metrics supported by the current evaluation framework with matching nodes "
|
||||
"from the target's published workflow only (draft is ignored)."
|
||||
),
|
||||
)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Default metrics and node candidates for the published workflow",
|
||||
evaluation_default_metrics_response_model,
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
default_metrics = EvaluationService.get_default_metrics_with_nodes_for_published_target(
|
||||
target=target,
|
||||
target_type=target_type,
|
||||
)
|
||||
return {"default_metrics": [m.model_dump() for m in default_metrics]}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/node-info")
|
||||
class EvaluationNodeInfoApi(Resource):
|
||||
@console_ns.doc("get_evaluation_node_info")
|
||||
@console_ns.response(200, "Node info grouped by metric")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""Return workflow/snippet node info grouped by requested metrics.
|
||||
|
||||
Request body (JSON):
|
||||
- metrics: list[str] | None – metric names to query; omit or pass
|
||||
an empty list to get all nodes under key ``"all"``.
|
||||
|
||||
Response:
|
||||
``{metric_or_all: [{"node_id": ..., "type": ..., "title": ...}, ...]}``
|
||||
"""
|
||||
body = request.get_json(silent=True) or {}
|
||||
metrics: list[str] | None = body.get("metrics") or None
|
||||
|
||||
result = EvaluationService.get_nodes_for_metrics(
|
||||
target=target,
|
||||
target_type=target_type,
|
||||
metrics=metrics,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/evaluation/available-metrics")
|
||||
class EvaluationAvailableMetricsApi(Resource):
|
||||
@console_ns.doc("get_available_evaluation_metrics")
|
||||
@console_ns.response(200, "Available metrics list")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""Return the centrally-defined list of evaluation metrics."""
|
||||
return {"metrics": EvaluationService.get_available_metrics()}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/files/<uuid:file_id>")
|
||||
class EvaluationFileDownloadApi(Resource):
|
||||
@console_ns.doc("download_evaluation_file")
|
||||
@console_ns.response(200, "File download URL generated successfully")
|
||||
@console_ns.response(404, "Target or file not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str, file_id: str):
|
||||
"""
|
||||
Download evaluation test file or result file.
|
||||
|
||||
Looks up the specified file, verifies it belongs to the same tenant,
|
||||
and returns file info and download URL.
|
||||
"""
|
||||
file_id = str(file_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == file_id,
|
||||
UploadFile.tenant_id == current_tenant_id,
|
||||
)
|
||||
upload_file = session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
|
||||
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
|
||||
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
|
||||
"download_url": download_url,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/version")
|
||||
class EvaluationVersionApi(Resource):
|
||||
@console_ns.doc("get_evaluation_version_detail")
|
||||
@console_ns.expect(console_ns.models.get(VersionQuery.__name__))
|
||||
@console_ns.response(200, "Version details retrieved successfully")
|
||||
@console_ns.response(404, "Target or version not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get evaluation target version details.
|
||||
|
||||
Returns the workflow graph for the specified version.
|
||||
"""
|
||||
version = request.args.get("version")
|
||||
|
||||
if not version:
|
||||
return {"message": "version parameter is required"}, 400
|
||||
|
||||
graph = {}
|
||||
if target_type == "snippets" and isinstance(target, CustomizedSnippet):
|
||||
graph = target.graph_dict
|
||||
|
||||
return {
|
||||
"graph": graph,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/available-evaluation-workflows")
|
||||
class AvailableEvaluationWorkflowsApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
|
||||
@console_ns.doc("list_available_evaluation_workflows")
|
||||
@console_ns.doc(description="List published evaluation workflows in the current workspace (all apps)")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Available evaluation workflows retrieved",
|
||||
available_evaluation_workflow_pagination_model,
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self):
|
||||
"""List published evaluation-type workflows for the current tenant (cross-app)."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
page = args.page
|
||||
limit = args.limit
|
||||
user_id = args.user_id
|
||||
named_only = args.named_only
|
||||
keyword = args.keyword
|
||||
|
||||
if user_id and user_id != current_user.id:
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
workflows, has_more = workflow_service.list_published_evaluation_workflows(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
page=page,
|
||||
limit=limit,
|
||||
user_id=user_id,
|
||||
named_only=named_only,
|
||||
keyword=keyword,
|
||||
)
|
||||
|
||||
app_ids = {w.app_id for w in workflows}
|
||||
if app_ids:
|
||||
apps = session.scalars(select(App).where(App.id.in_(app_ids))).all()
|
||||
app_names = {a.id: a.name for a in apps}
|
||||
else:
|
||||
app_names = {}
|
||||
|
||||
items = []
|
||||
for wf in workflows:
|
||||
items.append(
|
||||
{
|
||||
"id": wf.id,
|
||||
"app_id": wf.app_id,
|
||||
"app_name": app_names.get(wf.app_id, ""),
|
||||
"type": wf.type.value,
|
||||
"kind": wf.kind_or_standard,
|
||||
"version": wf.version,
|
||||
"marked_name": wf.marked_name,
|
||||
"marked_comment": wf.marked_comment,
|
||||
"hash": wf.unique_hash,
|
||||
"created_by": wf.created_by_account,
|
||||
"created_at": wf.created_at,
|
||||
"updated_by": wf.updated_by_account,
|
||||
"updated_at": wf.updated_at,
|
||||
}
|
||||
)
|
||||
|
||||
return (
|
||||
marshal(
|
||||
{"items": items, "page": page, "limit": limit, "has_more": has_more},
|
||||
available_evaluation_workflow_pagination_fields,
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/evaluation-workflows/<string:workflow_id>/associated-targets")
|
||||
class EvaluationWorkflowAssociatedTargetsApi(Resource):
|
||||
@console_ns.doc("list_evaluation_workflow_associated_targets")
|
||||
@console_ns.doc(
|
||||
description="List targets (apps / snippets / knowledge bases) that use the given workflow as customized metrics"
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, workflow_id: str):
|
||||
"""Return all evaluation targets that reference this workflow as customized metrics."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
configs = EvaluationService.list_targets_by_customized_workflow(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
customized_workflow_id=workflow_id,
|
||||
)
|
||||
|
||||
target_ids_by_type: dict[str, list[str]] = {}
|
||||
for cfg in configs:
|
||||
target_ids_by_type.setdefault(cfg.target_type, []).append(cfg.target_id)
|
||||
|
||||
app_names: dict[str, str] = {}
|
||||
if "app" in target_ids_by_type:
|
||||
apps = session.scalars(select(App).where(App.id.in_(target_ids_by_type["app"]))).all()
|
||||
app_names = {a.id: a.name for a in apps}
|
||||
|
||||
snippet_names: dict[str, str] = {}
|
||||
if "snippets" in target_ids_by_type:
|
||||
snippets = session.scalars(
|
||||
select(CustomizedSnippet).where(CustomizedSnippet.id.in_(target_ids_by_type["snippets"]))
|
||||
).all()
|
||||
snippet_names = {s.id: s.name for s in snippets}
|
||||
|
||||
dataset_names: dict[str, str] = {}
|
||||
if "knowledge_base" in target_ids_by_type:
|
||||
datasets = session.scalars(
|
||||
select(Dataset).where(Dataset.id.in_(target_ids_by_type["knowledge_base"]))
|
||||
).all()
|
||||
dataset_names = {d.id: d.name for d in datasets}
|
||||
|
||||
items = []
|
||||
for cfg in configs:
|
||||
name = ""
|
||||
if cfg.target_type == "app":
|
||||
name = app_names.get(cfg.target_id, "")
|
||||
elif cfg.target_type == "snippets":
|
||||
name = snippet_names.get(cfg.target_id, "")
|
||||
elif cfg.target_type == "knowledge_base":
|
||||
name = dataset_names.get(cfg.target_id, "")
|
||||
|
||||
items.append(
|
||||
{
|
||||
"target_type": cfg.target_type,
|
||||
"target_id": cfg.target_id,
|
||||
"target_name": name,
|
||||
}
|
||||
)
|
||||
|
||||
return {"items": items}, 200
|
||||
|
||||
|
||||
# ---- Serialization Helpers ----
|
||||
|
||||
|
||||
def _serialize_evaluation_run(run: EvaluationRun) -> dict[str, object]:
|
||||
return {
|
||||
"id": run.id,
|
||||
"tenant_id": run.tenant_id,
|
||||
"target_type": run.target_type,
|
||||
"target_id": run.target_id,
|
||||
"evaluation_config_id": run.evaluation_config_id,
|
||||
"status": run.status,
|
||||
"dataset_file_id": run.dataset_file_id,
|
||||
"result_file_id": run.result_file_id,
|
||||
"total_items": run.total_items,
|
||||
"completed_items": run.completed_items,
|
||||
"failed_items": run.failed_items,
|
||||
"progress": run.progress,
|
||||
"metrics_summary": run.metrics_summary_dict,
|
||||
"error": run.error,
|
||||
"created_by": run.created_by,
|
||||
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
|
||||
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
|
||||
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_evaluation_run_item(item: EvaluationRunItem) -> dict[str, object]:
|
||||
return {
|
||||
"id": item.id,
|
||||
"item_index": item.item_index,
|
||||
"inputs": item.inputs_dict,
|
||||
"expected_output": item.expected_output,
|
||||
"actual_output": item.actual_output,
|
||||
"metrics": item.metrics_list,
|
||||
"judgment": item.judgment_dict,
|
||||
"metadata": item.metadata_dict,
|
||||
"error": item.error,
|
||||
"overall_score": item.overall_score,
|
||||
}
|
||||
@ -4,6 +4,7 @@ from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
from sqlalchemy import and_, select
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
@ -14,7 +15,6 @@ from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
|
||||
@ -1,142 +0,0 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class SnippetListQuery(BaseModel):
|
||||
"""Query parameters for listing snippets."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
keyword: str | None = None
|
||||
is_published: bool | None = Field(default=None, description="Filter by published status")
|
||||
creators: list[str] | None = Field(default=None, description="Filter by creator account IDs")
|
||||
|
||||
@field_validator("creators", mode="before")
|
||||
@classmethod
|
||||
def parse_creators(cls, value: object) -> list[str] | None:
|
||||
"""Normalize creators filter from query string or list input."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return [creator.strip() for creator in value.split(",") if creator.strip()] or None
|
||||
if isinstance(value, list):
|
||||
return [str(creator).strip() for creator in value if str(creator).strip()] or None
|
||||
return None
|
||||
|
||||
|
||||
class IconInfo(BaseModel):
|
||||
"""Icon information model."""
|
||||
|
||||
icon: str | None = None
|
||||
icon_type: Literal["emoji", "image"] | None = None
|
||||
icon_background: str | None = None
|
||||
icon_url: str | None = None
|
||||
|
||||
|
||||
class InputFieldDefinition(BaseModel):
|
||||
"""Input field definition for snippet parameters."""
|
||||
|
||||
default: str | None = None
|
||||
hint: bool | None = None
|
||||
label: str | None = None
|
||||
max_length: int | None = None
|
||||
options: list[str] | None = None
|
||||
placeholder: str | None = None
|
||||
required: bool | None = None
|
||||
type: str | None = None # e.g., "text-input"
|
||||
|
||||
|
||||
class CreateSnippetPayload(BaseModel):
|
||||
"""Payload for creating a new snippet."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
description: str | None = Field(default=None, max_length=2000)
|
||||
type: Literal["node", "group"] = "node"
|
||||
icon_info: IconInfo | None = None
|
||||
graph: dict[str, Any] | None = None
|
||||
input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)
|
||||
|
||||
|
||||
class UpdateSnippetPayload(BaseModel):
|
||||
"""Payload for updating a snippet."""
|
||||
|
||||
name: str | None = Field(default=None, min_length=1, max_length=255)
|
||||
description: str | None = Field(default=None, max_length=2000)
|
||||
icon_info: IconInfo | None = None
|
||||
|
||||
|
||||
class SnippetDraftSyncPayload(BaseModel):
|
||||
"""Payload for syncing snippet draft workflow."""
|
||||
|
||||
graph: dict[str, Any]
|
||||
hash: str | None = None
|
||||
conversation_variables: list[dict[str, Any]] | None = Field(
|
||||
default=None,
|
||||
description="Ignored. Snippet workflows do not persist conversation variables.",
|
||||
)
|
||||
input_fields: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetWorkflowListQuery(BaseModel):
|
||||
"""Query parameters for listing snippet published workflows."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
|
||||
|
||||
class WorkflowRunQuery(BaseModel):
|
||||
"""Query parameters for workflow runs."""
|
||||
|
||||
last_id: str | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SnippetDraftRunPayload(BaseModel):
|
||||
"""Payload for running snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetDraftNodeRunPayload(BaseModel):
|
||||
"""Payload for running a single node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any]
|
||||
query: str = ""
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetIterationNodeRunPayload(BaseModel):
|
||||
"""Payload for running an iteration node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SnippetLoopNodeRunPayload(BaseModel):
|
||||
"""Payload for running a loop node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class PublishWorkflowPayload(BaseModel):
|
||||
"""Payload for publishing snippet workflow."""
|
||||
|
||||
knowledge_base_setting: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SnippetImportPayload(BaseModel):
|
||||
"""Payload for importing snippet from DSL."""
|
||||
|
||||
mode: str = Field(..., description="Import mode: yaml-content or yaml-url")
|
||||
yaml_content: str | None = Field(default=None, description="YAML content (required for yaml-content mode)")
|
||||
yaml_url: str | None = Field(default=None, description="YAML URL (required for yaml-url mode)")
|
||||
name: str | None = Field(default=None, description="Override snippet name")
|
||||
description: str | None = Field(default=None, description="Override snippet description")
|
||||
snippet_id: str | None = Field(default=None, description="Snippet ID to update (optional)")
|
||||
|
||||
|
||||
class IncludeSecretQuery(BaseModel):
|
||||
"""Query parameter for including secret variables in export."""
|
||||
|
||||
include_secret: str = Field(default="false", description="Whether to include secret variables")
|
||||
@ -1,579 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.workflow import workflow_model, workflow_pagination_model
|
||||
from controllers.console.app.workflow_run import (
|
||||
workflow_run_detail_model,
|
||||
workflow_run_node_execution_list_model,
|
||||
workflow_run_node_execution_model,
|
||||
workflow_run_pagination_model,
|
||||
)
|
||||
from controllers.console.snippets.payloads import (
|
||||
PublishWorkflowPayload,
|
||||
SnippetDraftNodeRunPayload,
|
||||
SnippetDraftRunPayload,
|
||||
SnippetDraftSyncPayload,
|
||||
SnippetIterationNodeRunPayload,
|
||||
SnippetLoopNodeRunPayload,
|
||||
SnippetWorkflowListQuery,
|
||||
WorkflowRunQuery,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.snippet_generate_service import SnippetGenerateService
|
||||
from services.snippet_service import SnippetService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
# Register Pydantic models with Swagger
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SnippetDraftSyncPayload,
|
||||
SnippetDraftNodeRunPayload,
|
||||
SnippetDraftRunPayload,
|
||||
SnippetIterationNodeRunPayload,
|
||||
SnippetLoopNodeRunPayload,
|
||||
SnippetWorkflowListQuery,
|
||||
WorkflowRunQuery,
|
||||
PublishWorkflowPayload,
|
||||
)
|
||||
|
||||
|
||||
snippet_workflow_model = console_ns.clone("SnippetWorkflow", workflow_model, {
|
||||
"input_fields": fields.Raw(default=[]),
|
||||
})
|
||||
|
||||
|
||||
class SnippetNotFoundError(Exception):
|
||||
"""Snippet not found error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def get_snippet(view_func: Callable[P, R]):
|
||||
"""Decorator to fetch and validate snippet access."""
|
||||
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
if not kwargs.get("snippet_id"):
|
||||
raise ValueError("missing snippet_id in path parameters")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet_id = str(kwargs.get("snippet_id"))
|
||||
del kwargs["snippet_id"]
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=snippet_id,
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
kwargs["snippet"] = snippet
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
|
||||
class SnippetDraftWorkflowApi(Resource):
|
||||
@console_ns.doc("get_snippet_draft_workflow")
|
||||
@console_ns.response(200, "Draft workflow retrieved successfully", snippet_workflow_model)
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
@marshal_with(snippet_workflow_model)
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get draft workflow for snippet."""
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
|
||||
if not workflow:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
db.session.expunge(workflow)
|
||||
workflow.conversation_variables = []
|
||||
workflow.input_fields = snippet.input_fields_list
|
||||
return workflow
|
||||
|
||||
@console_ns.doc("sync_snippet_draft_workflow")
|
||||
@console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
|
||||
@console_ns.response(200, "Draft workflow synced successfully")
|
||||
@console_ns.response(400, "Hash mismatch")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet):
|
||||
"""Sync draft workflow for snippet."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.sync_draft_workflow(
|
||||
snippet=snippet,
|
||||
graph=payload.graph,
|
||||
unique_hash=payload.hash,
|
||||
account=current_user,
|
||||
input_fields=payload.input_fields,
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"hash": workflow.unique_hash,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
|
||||
class SnippetDraftConfigApi(Resource):
|
||||
@console_ns.doc("get_snippet_draft_config")
|
||||
@console_ns.response(200, "Draft config retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get snippet draft workflow configuration limits."""
|
||||
return {
|
||||
"parallel_depth_limit": 3,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
|
||||
class SnippetPublishedWorkflowApi(Resource):
|
||||
@console_ns.doc("get_snippet_published_workflow")
|
||||
@console_ns.response(200, "Published workflow retrieved successfully", snippet_workflow_model)
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
@marshal_with(snippet_workflow_model)
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get published workflow for snippet."""
|
||||
if not snippet.is_published:
|
||||
return None
|
||||
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_published_workflow(snippet=snippet)
|
||||
|
||||
if workflow:
|
||||
workflow.input_fields = snippet.input_fields_list
|
||||
|
||||
return workflow
|
||||
|
||||
@console_ns.doc("publish_snippet_workflow")
|
||||
@console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
|
||||
@console_ns.response(200, "Workflow published successfully")
|
||||
@console_ns.response(400, "No draft workflow found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet):
|
||||
"""Publish snippet workflow."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
snippet_service = SnippetService()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
snippet = session.merge(snippet)
|
||||
try:
|
||||
workflow = snippet_service.publish_workflow(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
account=current_user,
|
||||
)
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
session.commit()
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"created_at": workflow_created_at,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
|
||||
class SnippetDefaultBlockConfigsApi(Resource):
|
||||
@console_ns.doc("get_snippet_default_block_configs")
|
||||
@console_ns.response(200, "Default block configs retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get default block configurations for snippet workflow."""
|
||||
snippet_service = SnippetService()
|
||||
return snippet_service.get_default_block_configs()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows")
|
||||
class SnippetPublishedAllWorkflowApi(Resource):
|
||||
@console_ns.expect(console_ns.models[SnippetWorkflowListQuery.__name__])
|
||||
@console_ns.doc("get_all_snippet_published_workflows")
|
||||
@console_ns.doc(description="Get all published workflows for a snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID"})
|
||||
@console_ns.response(200, "Published workflows retrieved successfully", workflow_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get all published workflow versions for snippet."""
|
||||
args = SnippetWorkflowListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
snippet_service = SnippetService()
|
||||
with Session(db.engine) as session:
|
||||
workflows, has_more = snippet_service.get_all_published_workflows(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
serialized_workflows = marshal(workflows, workflow_model)
|
||||
|
||||
return {
|
||||
"items": serialized_workflows,
|
||||
"page": args.page,
|
||||
"limit": args.limit,
|
||||
"has_more": has_more,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
|
||||
class SnippetWorkflowRunsApi(Resource):
|
||||
@console_ns.doc("list_snippet_workflow_runs")
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@marshal_with(workflow_run_pagination_model)
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""List workflow runs for snippet."""
|
||||
query = WorkflowRunQuery.model_validate(
|
||||
{
|
||||
"last_id": request.args.get("last_id"),
|
||||
"limit": request.args.get("limit", type=int, default=20),
|
||||
}
|
||||
)
|
||||
args = {
|
||||
"last_id": query.last_id,
|
||||
"limit": query.limit,
|
||||
}
|
||||
|
||||
snippet_service = SnippetService()
|
||||
result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
|
||||
class SnippetWorkflowRunDetailApi(Resource):
|
||||
@console_ns.doc("get_snippet_workflow_run_detail")
|
||||
@console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@marshal_with(workflow_run_detail_model)
|
||||
def get(self, snippet: CustomizedSnippet, run_id):
|
||||
"""Get workflow run detail for snippet."""
|
||||
run_id = str(run_id)
|
||||
|
||||
snippet_service = SnippetService()
|
||||
workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
|
||||
|
||||
if not workflow_run:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
return workflow_run
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||
class SnippetWorkflowRunNodeExecutionsApi(Resource):
|
||||
@console_ns.doc("list_snippet_workflow_run_node_executions")
|
||||
@console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@marshal_with(workflow_run_node_execution_list_model)
|
||||
def get(self, snippet: CustomizedSnippet, run_id):
|
||||
"""List node executions for a workflow run."""
|
||||
run_id = str(run_id)
|
||||
|
||||
snippet_service = SnippetService()
|
||||
node_executions = snippet_service.get_snippet_workflow_run_node_executions(
|
||||
snippet=snippet,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
return {"data": node_executions}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
class SnippetDraftNodeRunApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_node")
|
||||
@console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__))
|
||||
@console_ns.response(200, "Node run completed successfully", workflow_run_node_execution_model)
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@marshal_with(workflow_run_node_execution_model)
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a single node in snippet draft workflow.
|
||||
|
||||
Executes a specific node with provided inputs for single-step debugging.
|
||||
Returns the node execution result including status, outputs, and timing.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
user_inputs = payload.inputs
|
||||
|
||||
# Get draft workflow for file parsing
|
||||
snippet_service = SnippetService()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not draft_workflow:
|
||||
raise NotFound("Draft workflow not found")
|
||||
|
||||
files = SnippetGenerateService.parse_files(draft_workflow, payload.files)
|
||||
|
||||
workflow_node_execution = SnippetGenerateService.run_draft_node(
|
||||
snippet=snippet,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
account=current_user,
|
||||
query=payload.query,
|
||||
files=files,
|
||||
)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/last-run")
|
||||
class SnippetDraftNodeLastRunApi(Resource):
|
||||
@console_ns.doc("get_snippet_draft_node_last_run")
|
||||
@console_ns.doc(description="Get last run result for a node in snippet draft workflow")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model)
|
||||
@console_ns.response(404, "Snippet, draft workflow, or node last run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@marshal_with(workflow_run_node_execution_model)
|
||||
def get(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Get the last run result for a specific node in snippet draft workflow.
|
||||
|
||||
Returns the most recent execution record for the given node,
|
||||
including status, inputs, outputs, and timing information.
|
||||
"""
|
||||
snippet_service = SnippetService()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not draft_workflow:
|
||||
raise NotFound("Draft workflow not found")
|
||||
|
||||
node_exec = snippet_service.get_snippet_node_last_run(
|
||||
snippet=snippet,
|
||||
workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
)
|
||||
if node_exec is None:
|
||||
raise NotFound("Node last run not found")
|
||||
|
||||
return node_exec
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
class SnippetDraftRunIterationNodeApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_iteration_node")
|
||||
@console_ns.doc(description="Run draft workflow iteration node for snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__))
|
||||
@console_ns.response(200, "Iteration node run started successfully (SSE stream)")
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a draft workflow iteration node for snippet.
|
||||
|
||||
Iteration nodes execute their internal sub-graph multiple times over an input list.
|
||||
Returns an SSE event stream with iteration progress and results.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = SnippetGenerateService.generate_single_iteration(
|
||||
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
class SnippetDraftRunLoopNodeApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_loop_node")
|
||||
@console_ns.doc(description="Run draft workflow loop node for snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__))
|
||||
@console_ns.response(200, "Loop node run started successfully (SSE stream)")
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a draft workflow loop node for snippet.
|
||||
|
||||
Loop nodes execute their internal sub-graph repeatedly until a condition is met.
|
||||
Returns an SSE event stream with loop progress and results.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
response = SnippetGenerateService.generate_single_loop(
|
||||
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/run")
|
||||
class SnippetDraftWorkflowRunApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_workflow")
|
||||
@console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__))
|
||||
@console_ns.response(200, "Draft workflow run started successfully (SSE stream)")
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet):
|
||||
"""
|
||||
Run draft workflow for snippet.
|
||||
|
||||
Executes the snippet's draft workflow with the provided inputs
|
||||
and returns an SSE event stream with execution progress and results.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = SnippetGenerateService.generate(
|
||||
snippet=snippet,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
class SnippetWorkflowTaskStopApi(Resource):
|
||||
@console_ns.doc("stop_snippet_workflow_task")
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, task_id: str):
|
||||
"""
|
||||
Stop a running snippet workflow task.
|
||||
|
||||
Uses both the legacy stop flag mechanism and the graph engine
|
||||
command channel for backward compatibility.
|
||||
"""
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
@ -1,319 +0,0 @@
|
||||
"""
|
||||
Snippet draft workflow variable APIs.
|
||||
|
||||
Mirrors console app routes under /apps/.../workflows/draft/variables for snippet scope,
|
||||
using CustomizedSnippet.id as WorkflowDraftVariable.app_id (same invariant as snippet execution).
|
||||
|
||||
Snippet workflows do not expose system variables (`node_id == sys`) or conversation variables
|
||||
(`node_id == conversation`): paginated list queries exclude those rows; single-variable GET/PATCH/DELETE/reset
|
||||
reject them; `GET .../system-variables` and `GET .../conversation-variables` return empty lists for API parity.
|
||||
Other routes mirror `workflow_draft_variable` app APIs under `/snippets/...`.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, marshal, marshal_with
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
WorkflowDraftVariableListQuery,
|
||||
WorkflowDraftVariableUpdatePayload,
|
||||
_ensure_variable_access,
|
||||
_file_access_controller,
|
||||
validate_node_id,
|
||||
workflow_draft_variable_list_model,
|
||||
workflow_draft_variable_list_without_value_model,
|
||||
workflow_draft_variable_model,
|
||||
)
|
||||
from controllers.console.snippets.snippet_workflow import get_snippet
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from graphon.variables.types import SegmentType
|
||||
from libs.login import current_user, login_required
|
||||
from models.snippet import CustomizedSnippet
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.snippet_service import SnippetService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS: frozenset[str] = frozenset(
|
||||
{SYSTEM_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID}
|
||||
)
|
||||
|
||||
|
||||
def _ensure_snippet_draft_variable_row_allowed(
|
||||
*,
|
||||
variable: WorkflowDraftVariable,
|
||||
variable_id: str,
|
||||
) -> None:
|
||||
"""Snippet scope only supports canvas-node draft variables; treat sys/conversation rows as not found."""
|
||||
if variable.node_id in _SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
|
||||
def _snippet_draft_var_prerequisite(f: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Setup, auth, snippet resolution, and tenant edit permission (same stack as snippet workflow APIs)."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables")
|
||||
class SnippetWorkflowVariableCollectionApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
|
||||
@console_ns.doc("get_snippet_workflow_variables")
|
||||
@console_ns.doc(description="List draft workflow variables without values (paginated, snippet scope)")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow variables retrieved successfully",
|
||||
workflow_draft_variable_list_without_value_model,
|
||||
)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
snippet_service = SnippetService()
|
||||
if snippet_service.get_draft_workflow(snippet=snippet) is None:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=session)
|
||||
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
app_id=snippet.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
user_id=current_user.id,
|
||||
exclude_node_ids=_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
|
||||
@console_ns.doc("delete_snippet_workflow_variables")
|
||||
@console_ns.doc(description="Delete all draft workflow variables for the current user (snippet scope)")
|
||||
@console_ns.response(204, "Workflow variables deleted successfully")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def delete(self, snippet: CustomizedSnippet) -> Response:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
draft_var_srv.delete_user_workflow_variables(snippet.id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
class SnippetNodeVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_snippet_node_variables")
|
||||
@console_ns.doc(description="Get variables for a specific node (snippet draft workflow)")
|
||||
@console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, snippet: CustomizedSnippet, node_id: str) -> WorkflowDraftVariableList:
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=session)
|
||||
node_vars = draft_var_srv.list_node_variables(snippet.id, node_id, user_id=current_user.id)
|
||||
|
||||
return node_vars
|
||||
|
||||
@console_ns.doc("delete_snippet_node_variables")
|
||||
@console_ns.doc(description="Delete all variables for a specific node (snippet draft workflow)")
|
||||
@console_ns.response(204, "Node variables deleted successfully")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def delete(self, snippet: CustomizedSnippet, node_id: str) -> Response:
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(snippet.id, node_id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
class SnippetVariableApi(Resource):
|
||||
@console_ns.doc("get_snippet_workflow_variable")
|
||||
@console_ns.doc(description="Get a specific draft workflow variable (snippet scope)")
|
||||
@console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
return variable
|
||||
|
||||
@console_ns.doc("update_snippet_workflow_variable")
|
||||
@console_ns.doc(description="Update a draft workflow variable (snippet scope)")
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def patch(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
|
||||
new_name = args_model.name
|
||||
raw_value = args_model.value
|
||||
if new_name is None and raw_value is None:
|
||||
return variable
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=snippet.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=snippet.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
@console_ns.doc("delete_snippet_workflow_variable")
|
||||
@console_ns.doc(description="Delete a draft workflow variable (snippet scope)")
|
||||
@console_ns.response(204, "Variable deleted successfully")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def delete(self, snippet: CustomizedSnippet, variable_id: str) -> Response:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
class SnippetVariableResetApi(Resource):
|
||||
@console_ns.doc("reset_snippet_workflow_variable")
|
||||
@console_ns.doc(description="Reset a draft workflow variable to its default value (snippet scope)")
|
||||
@console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(204, "Variable reset (no content)")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def put(self, snippet: CustomizedSnippet, variable_id: str) -> Response | Any:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
snippet_service = SnippetService()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, snippet_id={snippet.id}",
|
||||
)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
if resetted is None:
|
||||
return Response("", 204)
|
||||
return marshal(resetted, workflow_draft_variable_model)
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/conversation-variables")
|
||||
class SnippetConversationVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_snippet_conversation_variables")
|
||||
@console_ns.doc(
|
||||
description="Conversation variables are not used in snippet workflows; returns an empty list for API parity"
|
||||
)
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
return WorkflowDraftVariableList(variables=[])
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/system-variables")
|
||||
class SnippetSystemVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_snippet_system_variables")
|
||||
@console_ns.doc(
|
||||
description="System variables are not used in snippet workflows; returns an empty list for API parity"
|
||||
)
|
||||
@console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
return WorkflowDraftVariableList(variables=[])
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/environment-variables")
|
||||
class SnippetEnvironmentVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_snippet_environment_variables")
|
||||
@console_ns.doc(description="Get environment variables from snippet draft workflow graph")
|
||||
@console_ns.response(200, "Environment variables retrieved successfully")
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def get(self, snippet: CustomizedSnippet) -> dict[str, list[dict[str, Any]]]:
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if workflow is None:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
env_vars_list: list[dict[str, Any]] = []
|
||||
for v in workflow.environment_variables:
|
||||
env_vars_list.append(
|
||||
{
|
||||
"id": v.id,
|
||||
"type": "env",
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": v.value_type.exposed_type().value,
|
||||
"value": v.value,
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
"editable": True,
|
||||
}
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
@ -1 +0,0 @@
|
||||
|
||||
@ -1,108 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from flask import Request as FlaskRequest
|
||||
|
||||
from extensions.ext_socketio import sio
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository
|
||||
from services.account_service import AccountService
|
||||
from services.workflow_collaboration_service import WorkflowCollaborationService
|
||||
|
||||
repository = WorkflowCollaborationRepository()
|
||||
collaboration_service = WorkflowCollaborationService(repository, sio)
|
||||
|
||||
|
||||
def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]:
|
||||
return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event))
|
||||
|
||||
|
||||
@_sio_on("connect")
|
||||
def socket_connect(sid, environ, auth):
|
||||
"""
|
||||
WebSocket connect event, do authentication here.
|
||||
"""
|
||||
try:
|
||||
request_environ = FlaskRequest(environ)
|
||||
token = extract_access_token(request_environ)
|
||||
except Exception:
|
||||
logging.exception("Failed to extract token")
|
||||
token = None
|
||||
|
||||
if not token:
|
||||
logging.warning("Socket connect rejected: missing token (sid=%s)", sid)
|
||||
return False
|
||||
|
||||
try:
|
||||
decoded = PassportService().verify(token)
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid)
|
||||
return False
|
||||
|
||||
with sio.app.app_context():
|
||||
user = AccountService.load_logged_in_account(account_id=user_id)
|
||||
if not user:
|
||||
logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
if not user.has_edit_permission:
|
||||
logging.warning("Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
|
||||
collaboration_service.save_socket_identity(sid, user)
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logging.exception("Socket authentication failed")
|
||||
return False
|
||||
|
||||
|
||||
@_sio_on("user_connect")
|
||||
def handle_user_connect(sid, data):
|
||||
"""
|
||||
Handle user connect event. Each session (tab) is treated as an independent collaborator.
|
||||
"""
|
||||
workflow_id = data.get("workflow_id")
|
||||
if not workflow_id:
|
||||
return {"msg": "workflow_id is required"}, 400
|
||||
|
||||
result = collaboration_service.authorize_and_join_workflow_room(workflow_id, sid)
|
||||
if not result:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
user_id, is_leader = result
|
||||
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
|
||||
|
||||
|
||||
@_sio_on("disconnect")
|
||||
def handle_disconnect(sid):
|
||||
"""
|
||||
Handle session disconnect event. Remove the specific session from online users.
|
||||
"""
|
||||
collaboration_service.disconnect_session(sid)
|
||||
|
||||
|
||||
@_sio_on("collaboration_event")
|
||||
def handle_collaboration_event(sid, data):
|
||||
"""
|
||||
Handle general collaboration events, include:
|
||||
1. mouse_move
|
||||
2. vars_and_features_update
|
||||
3. sync_request (ask leader to update graph)
|
||||
4. app_state_update
|
||||
5. mcp_server_update
|
||||
6. workflow_update
|
||||
7. comments_update
|
||||
8. node_panel_presence
|
||||
"""
|
||||
return collaboration_service.relay_collaboration_event(sid, data)
|
||||
|
||||
|
||||
@_sio_on("graph_event")
|
||||
def handle_graph_event(sid, data):
|
||||
"""
|
||||
Handle graph events - simple broadcast relay.
|
||||
"""
|
||||
return collaboration_service.relay_graph_event(sid, data)
|
||||
@ -39,7 +39,6 @@ from controllers.console.wraps import (
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import Account as AccountResponse
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@ -76,10 +75,6 @@ class AccountAvatarPayload(BaseModel):
|
||||
avatar: str
|
||||
|
||||
|
||||
class AccountAvatarQuery(BaseModel):
|
||||
avatar: str = Field(..., description="Avatar file ID")
|
||||
|
||||
|
||||
class AccountInterfaceLanguagePayload(BaseModel):
|
||||
interface_language: str
|
||||
|
||||
@ -165,7 +160,6 @@ def reg(cls: type[BaseModel]):
|
||||
reg(AccountInitPayload)
|
||||
reg(AccountNamePayload)
|
||||
reg(AccountAvatarPayload)
|
||||
reg(AccountAvatarQuery)
|
||||
reg(AccountInterfaceLanguagePayload)
|
||||
reg(AccountInterfaceThemePayload)
|
||||
reg(AccountTimezonePayload)
|
||||
@ -315,18 +309,6 @@ class AccountNameApi(Resource):
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AccountAvatarQuery.__name__])
|
||||
@console_ns.doc("get_account_avatar")
|
||||
@console_ns.doc(description="Get account avatar url")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(args.avatar)
|
||||
return {"avatar_url": avatar_url}
|
||||
|
||||
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@ -1,380 +0,0 @@
|
||||
import logging
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, marshal
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.snippets.payloads import (
|
||||
CreateSnippetPayload,
|
||||
IncludeSecretQuery,
|
||||
SnippetImportPayload,
|
||||
SnippetListQuery,
|
||||
UpdateSnippetPayload,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.snippet import SnippetType
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.snippet_dsl_service import SnippetDslService
|
||||
from services.snippet_service import SnippetService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Register Pydantic models with Swagger
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SnippetListQuery,
|
||||
CreateSnippetPayload,
|
||||
UpdateSnippetPayload,
|
||||
SnippetImportPayload,
|
||||
IncludeSecretQuery,
|
||||
)
|
||||
|
||||
# Create namespace models for marshaling
|
||||
snippet_model = console_ns.model("Snippet", snippet_fields)
|
||||
snippet_list_model = console_ns.model("SnippetList", snippet_list_fields)
|
||||
snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets")
|
||||
class CustomizedSnippetsApi(Resource):
|
||||
@console_ns.doc("list_customized_snippets")
|
||||
@console_ns.expect(console_ns.models.get(SnippetListQuery.__name__))
|
||||
@console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""List customized snippets with pagination and search."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
query_params = request.args.to_dict()
|
||||
query = SnippetListQuery.model_validate(query_params)
|
||||
|
||||
snippets, total, has_more = SnippetService.get_snippets(
|
||||
tenant_id=current_tenant_id,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
keyword=query.keyword,
|
||||
is_published=query.is_published,
|
||||
creators=query.creators,
|
||||
)
|
||||
|
||||
return {
|
||||
"data": marshal(snippets, snippet_list_fields),
|
||||
"page": query.page,
|
||||
"limit": query.limit,
|
||||
"total": total,
|
||||
"has_more": has_more,
|
||||
}, 200
|
||||
|
||||
@console_ns.doc("create_customized_snippet")
|
||||
@console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
|
||||
@console_ns.response(201, "Snippet created successfully", snippet_model)
|
||||
@console_ns.response(400, "Invalid request or name already exists")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
"""Create a new customized snippet."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
payload = CreateSnippetPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
snippet_type = SnippetType(payload.type)
|
||||
except ValueError:
|
||||
snippet_type = SnippetType.NODE
|
||||
|
||||
try:
|
||||
snippet = SnippetService.create_snippet(
|
||||
tenant_id=current_tenant_id,
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
snippet_type=snippet_type,
|
||||
icon_info=payload.icon_info.model_dump() if payload.icon_info else None,
|
||||
input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None,
|
||||
account=current_user,
|
||||
)
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return marshal(snippet, snippet_fields), 201
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>")
|
||||
class CustomizedSnippetDetailApi(Resource):
|
||||
@console_ns.doc("get_customized_snippet")
|
||||
@console_ns.response(200, "Snippet retrieved successfully", snippet_model)
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, snippet_id: str):
|
||||
"""Get customized snippet details."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
return marshal(snippet, snippet_fields), 200
|
||||
|
||||
@console_ns.doc("update_customized_snippet")
|
||||
@console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
|
||||
@console_ns.response(200, "Snippet updated successfully", snippet_model)
|
||||
@console_ns.response(400, "Invalid request or name already exists")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def patch(self, snippet_id: str):
|
||||
"""Update customized snippet."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
payload = UpdateSnippetPayload.model_validate(console_ns.payload or {})
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
|
||||
if "icon_info" in update_data and update_data["icon_info"] is not None:
|
||||
update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None
|
||||
|
||||
if not update_data:
|
||||
return {"message": "No valid fields to update"}, 400
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
snippet = session.merge(snippet)
|
||||
snippet = SnippetService.update_snippet(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
account_id=current_user.id,
|
||||
data=update_data,
|
||||
)
|
||||
session.commit()
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return marshal(snippet, snippet_fields), 200
|
||||
|
||||
@console_ns.doc("delete_customized_snippet")
|
||||
@console_ns.response(204, "Snippet deleted successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, snippet_id: str):
|
||||
"""Delete customized snippet."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
snippet = session.merge(snippet)
|
||||
SnippetService.delete_snippet(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/export")
|
||||
class CustomizedSnippetExportApi(Resource):
|
||||
@console_ns.doc("export_customized_snippet")
|
||||
@console_ns.doc(description="Export snippet configuration as DSL")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID to export"})
|
||||
@console_ns.response(200, "Snippet exported successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, snippet_id: str):
|
||||
"""Export snippet as DSL."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
# Get include_secret parameter
|
||||
query = IncludeSecretQuery.model_validate(request.args.to_dict())
|
||||
|
||||
with Session(db.engine) as session:
|
||||
export_service = SnippetDslService(session)
|
||||
result = export_service.export_snippet_dsl(snippet=snippet, include_secret=query.include_secret == "true")
|
||||
|
||||
# Set filename with .snippet extension
|
||||
filename = f"{snippet.name}.snippet"
|
||||
encoded_filename = quote(filename)
|
||||
|
||||
response = Response(
|
||||
result,
|
||||
mimetype="application/x-yaml",
|
||||
)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Type"] = "application/x-yaml"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/imports")
|
||||
class CustomizedSnippetImportApi(Resource):
|
||||
@console_ns.doc("import_customized_snippet")
|
||||
@console_ns.doc(description="Import snippet from DSL")
|
||||
@console_ns.expect(console_ns.models.get(SnippetImportPayload.__name__))
|
||||
@console_ns.response(200, "Snippet imported successfully")
|
||||
@console_ns.response(202, "Import pending confirmation")
|
||||
@console_ns.response(400, "Import failed")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
"""Import snippet from DSL."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = SnippetImportPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = SnippetDslService(session)
|
||||
result = import_service.import_snippet(
|
||||
account=current_user,
|
||||
import_mode=payload.mode,
|
||||
yaml_content=payload.yaml_content,
|
||||
yaml_url=payload.yaml_url,
|
||||
snippet_id=payload.snippet_id,
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/imports/<string:import_id>/confirm")
|
||||
class CustomizedSnippetImportConfirmApi(Resource):
|
||||
@console_ns.doc("confirm_snippet_import")
|
||||
@console_ns.doc(description="Confirm a pending snippet import")
|
||||
@console_ns.doc(params={"import_id": "Import ID to confirm"})
|
||||
@console_ns.response(200, "Import confirmed successfully")
|
||||
@console_ns.response(400, "Import failed")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, import_id: str):
|
||||
"""Confirm a pending snippet import."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = SnippetDslService(session)
|
||||
result = import_service.confirm_import(import_id=import_id, account=current_user)
|
||||
session.commit()
|
||||
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/check-dependencies")
|
||||
class CustomizedSnippetCheckDependenciesApi(Resource):
|
||||
@console_ns.doc("check_snippet_dependencies")
|
||||
@console_ns.doc(description="Check dependencies for a snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID"})
|
||||
@console_ns.response(200, "Dependencies checked successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, snippet_id: str):
|
||||
"""Check dependencies for a snippet."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = SnippetDslService(session)
|
||||
result = import_service.check_dependencies(snippet=snippet)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/use-count/increment")
|
||||
class CustomizedSnippetUseCountIncrementApi(Resource):
|
||||
@console_ns.doc("increment_snippet_use_count")
|
||||
@console_ns.doc(description="Increment snippet use count by 1")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID"})
|
||||
@console_ns.response(200, "Use count incremented successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, snippet_id: str):
|
||||
"""Increment snippet use count when it is inserted into a workflow."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
snippet = session.merge(snippet)
|
||||
SnippetService.increment_use_count(session=session, snippet=snippet)
|
||||
session.commit()
|
||||
session.refresh(snippet)
|
||||
|
||||
return {"result": "success", "use_count": snippet.use_count}, 200
|
||||
@ -84,10 +84,10 @@ class ConversationVariableResponse(ResponseModel):
|
||||
def normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type())
|
||||
return str(exposed_type().value)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return str(SegmentType(value).exposed_type())
|
||||
return str(SegmentType(value).exposed_type().value)
|
||||
except ValueError:
|
||||
return value
|
||||
try:
|
||||
|
||||
@ -6,6 +6,9 @@ from typing import Literal
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
@ -35,9 +38,6 @@ from extensions.ext_redis import redis_client
|
||||
from fields.base import ResponseModel
|
||||
from fields.end_user_fields import SimpleEndUser
|
||||
from fields.member_fields import SimpleAccount
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
|
||||
@ -4,6 +4,20 @@ import uuid
|
||||
from decimal import Decimal
|
||||
from typing import Union, cast
|
||||
|
||||
from graphon.file import file_manager
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
@ -29,20 +43,6 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from graphon.file import file_manager
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import Conversation, Message, MessageAgentThought, MessageFile
|
||||
|
||||
|
||||
@ -300,9 +300,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
# update prompt tool
|
||||
for prompt_tool in prompt_messages_tools:
|
||||
tool_instance = tool_instances.get(prompt_tool.name)
|
||||
if tool_instance:
|
||||
self.update_prompt_message_tool(tool_instance, prompt_tool)
|
||||
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
|
||||
@ -1,13 +1,14 @@
|
||||
from typing import cast
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
|
||||
|
||||
class ModelConfigConverter:
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import re
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from models.model import AppModelConfigDict
|
||||
|
||||
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
|
||||
|
||||
@ -9,6 +9,12 @@ from datetime import datetime
|
||||
from threading import Thread
|
||||
from typing import Any, Union
|
||||
|
||||
from graphon.entities.pause_reason import HumanInputRequired
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.nodes import BuiltinNodeTypes
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@ -71,12 +77,6 @@ from core.repositories.human_input_repository import HumanInputFormRepositoryImp
|
||||
from core.workflow.file_reference import resolve_file_record_id
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from extensions.ext_database import db
|
||||
from graphon.entities.pause_reason import HumanInputRequired
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.nodes import BuiltinNodeTypes
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
@ -16,9 +19,6 @@ from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationError
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from models.model import App, Conversation, Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -2,6 +2,9 @@ from collections.abc import Generator, Mapping, Sequence
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from typing import TYPE_CHECKING, Any, Union, final
|
||||
|
||||
from graphon.enums import NodeType
|
||||
from graphon.file import File, FileUploadConfig
|
||||
from graphon.variables.input_entities import VariableEntityType
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.draft_variable_saver import (
|
||||
@ -13,9 +16,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from graphon.enums import NodeType
|
||||
from graphon.file import File, FileUploadConfig
|
||||
from graphon.variables.input_entities import VariableEntityType
|
||||
from libs.orjson import orjson_dumps
|
||||
from models import Account, EndUser
|
||||
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
|
||||
|
||||
@ -14,7 +14,7 @@ from graphon.runtime import GraphRuntimeState
|
||||
from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
@ -54,25 +54,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerator(BaseAppGenerator):
|
||||
@staticmethod
|
||||
def _ensure_snippet_start_node_in_worker(*, session: Session, workflow: Workflow) -> Workflow:
|
||||
"""Re-apply snippet virtual Start injection after worker reloads workflow from DB."""
|
||||
if workflow.kind_or_standard != "snippet":
|
||||
return workflow
|
||||
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.snippet_generate_service import SnippetGenerateService
|
||||
|
||||
snippet = session.scalar(
|
||||
select(CustomizedSnippet).where(
|
||||
CustomizedSnippet.id == workflow.app_id,
|
||||
CustomizedSnippet.tenant_id == workflow.tenant_id,
|
||||
)
|
||||
)
|
||||
if snippet is None:
|
||||
return workflow
|
||||
return SnippetGenerateService.ensure_start_node_for_worker(workflow, snippet)
|
||||
|
||||
@staticmethod
|
||||
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
|
||||
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
|
||||
@ -576,8 +557,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
if workflow is None:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
workflow = self._ensure_snippet_start_node_in_worker(session=session, workflow=workflow)
|
||||
|
||||
# Determine system_user_id based on invocation source
|
||||
is_external_api_call = application_generate_entity.invoke_from in {
|
||||
InvokeFrom.WEB_APP,
|
||||
|
||||
@ -9,10 +9,11 @@ scope updates that matter to chat applications.
|
||||
|
||||
import logging
|
||||
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent
|
||||
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID
|
||||
from services.conversation_variable_updater import ConversationVariableUpdater
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -4,6 +4,13 @@ from collections.abc import Generator
|
||||
from threading import Thread
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@ -53,13 +60,6 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
|
||||
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from core.tools.signature import sign_tool_file
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.file import helpers as file_helpers
|
||||
|
||||
from core.tools.signature import sign_tool_file
|
||||
from models.model import MessageFile, UploadFile
|
||||
|
||||
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
|
||||
|
||||
@ -9,17 +9,17 @@ import urllib.parse
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
|
||||
from graphon.file.runtime import set_workflow_file_runtime
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.ssrf_proxy import graphon_ssrf_proxy
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.workflow.file_reference import parse_file_reference
|
||||
from extensions.ext_storage import storage
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.file.protocols import WorkflowFileRuntimeProtocol
|
||||
from graphon.file.runtime import set_workflow_file_runtime
|
||||
from graphon.http.protocols import HttpResponseProtocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphon.file import File
|
||||
@ -44,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
return dify_config.MULTIMODAL_SEND_FORMAT
|
||||
|
||||
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
|
||||
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
|
||||
return ssrf_proxy.get(url, follow_redirects=follow_redirects)
|
||||
|
||||
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
|
||||
return storage.load(path, stream=stream)
|
||||
|
||||
@ -12,6 +12,10 @@ from contextvars import Token
|
||||
from dataclasses import dataclass
|
||||
from typing import cast, final, override
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphNodeEventBase
|
||||
from graphon.nodes.base.node import Node
|
||||
from opentelemetry import context as context_api
|
||||
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
|
||||
|
||||
@ -24,10 +28,6 @@ from extensions.otel.parser import (
|
||||
ToolNodeOTelParser,
|
||||
)
|
||||
from extensions.otel.runtime import is_instrument_flag_enabled
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphNodeEventBase
|
||||
from graphon.nodes.base.node import Node
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -350,7 +350,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
execution.total_tokens = runtime_state.total_tokens
|
||||
execution.total_steps = runtime_state.node_run_steps
|
||||
execution.outputs = execution.outputs or runtime_state.outputs
|
||||
execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count)
|
||||
execution.exceptions_count = runtime_state.exceptions_count
|
||||
|
||||
def _update_node_execution(
|
||||
self,
|
||||
|
||||
@ -352,11 +352,11 @@ class DatasourceManager:
|
||||
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
|
||||
|
||||
file_info = File(
|
||||
file_id=upload_file.id,
|
||||
id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
file_type=FileType.CUSTOM,
|
||||
type=FileType.CUSTOM,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
reference=build_file_reference(record_id=str(upload_file.id)),
|
||||
|
||||
@ -8,6 +8,16 @@ from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
CredentialFormSchema,
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from graphon.model_runtime.runtime import ModelRuntime
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -24,16 +34,6 @@ from core.entities.provider_entities import (
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
CredentialFormSchema,
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.base.ai_model import AIModel
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from graphon.model_runtime.runtime import ModelRuntime
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.engine import db
|
||||
from models.enums import CredentialSourceType
|
||||
@ -318,28 +318,34 @@ class ProviderConfiguration(BaseModel):
|
||||
else [],
|
||||
)
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""):
|
||||
def validate_provider_credentials(
|
||||
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
|
||||
):
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param credentials: provider credentials
|
||||
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
|
||||
:param session: optional database session
|
||||
:return:
|
||||
"""
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.provider_credential_schema.credential_form_schemas
|
||||
if self.provider.provider_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
if credential_id:
|
||||
with Session(db.engine) as session:
|
||||
def _validate(s: Session):
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.provider_credential_schema.credential_form_schemas
|
||||
if self.provider.provider_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
if credential_id:
|
||||
try:
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.id == credential_id,
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
credential_record = s.execute(stmt).scalar_one_or_none()
|
||||
# fix origin data
|
||||
if credential_record and credential_record.encrypted_config:
|
||||
if not credential_record.encrypted_config.startswith("{"):
|
||||
original_credentials = {"openai_api_key": credential_record.encrypted_config}
|
||||
@ -350,23 +356,31 @@ class ProviderConfiguration(BaseModel):
|
||||
except JSONDecodeError:
|
||||
original_credentials = {}
|
||||
|
||||
for key, value in credentials.items():
|
||||
# encrypt credentials
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.provider_credentials_validate(
|
||||
provider=self.provider.provider, credentials=credentials
|
||||
)
|
||||
|
||||
for key, value in validated_credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.provider_credentials_validate(
|
||||
provider=self.provider.provider, credentials=credentials
|
||||
)
|
||||
return validated_credentials
|
||||
|
||||
for key, value in validated_credentials.items():
|
||||
if key in provider_credential_secret_variables and isinstance(value, str):
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return validated_credentials
|
||||
if session:
|
||||
return _validate(session)
|
||||
else:
|
||||
with Session(db.engine) as new_session:
|
||||
return _validate(new_session)
|
||||
|
||||
def _generate_provider_credential_name(self, session) -> str:
|
||||
"""
|
||||
@ -443,16 +457,14 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credential_name: credential name
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as pre_session:
|
||||
with Session(db.engine) as session:
|
||||
if credential_name:
|
||||
if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session):
|
||||
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
|
||||
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
||||
else:
|
||||
credential_name = self._generate_provider_credential_name(pre_session)
|
||||
credential_name = self._generate_provider_credential_name(session)
|
||||
|
||||
credentials = self.validate_provider_credentials(credentials=credentials)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
|
||||
provider_record = self._get_provider_record(session)
|
||||
try:
|
||||
new_record = ProviderCredential(
|
||||
@ -465,6 +477,7 @@ class ProviderConfiguration(BaseModel):
|
||||
session.flush()
|
||||
|
||||
if not provider_record:
|
||||
# If provider record does not exist, create it
|
||||
provider_record = Provider(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
@ -517,15 +530,15 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credential_name: credential name
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as pre_session:
|
||||
with Session(db.engine) as session:
|
||||
if credential_name and self._check_provider_credential_name_exists(
|
||||
credential_name=credential_name, session=pre_session, exclude_id=credential_id
|
||||
credential_name=credential_name, session=session, exclude_id=credential_id
|
||||
):
|
||||
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
||||
|
||||
credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
credentials = self.validate_provider_credentials(
|
||||
credentials=credentials, credential_id=credential_id, session=session
|
||||
)
|
||||
provider_record = self._get_provider_record(session)
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == credential_id,
|
||||
@ -533,10 +546,12 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
|
||||
# Get the credential record to update
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
raise ValueError("Credential record not found.")
|
||||
try:
|
||||
# Update credential
|
||||
credential_record.encrypted_config = json.dumps(credentials)
|
||||
credential_record.updated_at = naive_utc_now()
|
||||
if credential_name:
|
||||
@ -864,6 +879,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
credential_id: str = "",
|
||||
session: Session | None = None,
|
||||
):
|
||||
"""
|
||||
Validate custom model credentials.
|
||||
@ -874,14 +890,16 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
|
||||
:return:
|
||||
"""
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.model_credential_schema.credential_form_schemas
|
||||
if self.provider.model_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
if credential_id:
|
||||
with Session(db.engine) as session:
|
||||
def _validate(s: Session):
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.model_credential_schema.credential_form_schemas
|
||||
if self.provider.model_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
if credential_id:
|
||||
try:
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
@ -890,7 +908,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
credential_record = s.execute(stmt).scalar_one_or_none()
|
||||
original_credentials = (
|
||||
json.loads(credential_record.encrypted_config)
|
||||
if credential_record and credential_record.encrypted_config
|
||||
@ -899,23 +917,31 @@ class ProviderConfiguration(BaseModel):
|
||||
except JSONDecodeError:
|
||||
original_credentials = {}
|
||||
|
||||
for key, value in credentials.items():
|
||||
# decrypt credentials
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
|
||||
for key, value in validated_credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
return validated_credentials
|
||||
|
||||
for key, value in validated_credentials.items():
|
||||
if key in provider_credential_secret_variables and isinstance(value, str):
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return validated_credentials
|
||||
if session:
|
||||
return _validate(session)
|
||||
else:
|
||||
with Session(db.engine) as new_session:
|
||||
return _validate(new_session)
|
||||
|
||||
def create_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
|
||||
@ -928,22 +954,20 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credentials: model credentials dict
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as pre_session:
|
||||
with Session(db.engine) as session:
|
||||
if credential_name:
|
||||
if self._check_custom_model_credential_name_exists(
|
||||
model=model, model_type=model_type, credential_name=credential_name, session=pre_session
|
||||
model=model, model_type=model_type, credential_name=credential_name, session=session
|
||||
):
|
||||
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
|
||||
else:
|
||||
credential_name = self._generate_custom_model_credential_name(
|
||||
model=model, model_type=model_type, session=pre_session
|
||||
model=model, model_type=model_type, session=session
|
||||
)
|
||||
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# validate custom model config
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type, model=model, credentials=credentials, session=session
|
||||
)
|
||||
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
|
||||
|
||||
try:
|
||||
@ -958,6 +982,7 @@ class ProviderConfiguration(BaseModel):
|
||||
session.add(credential)
|
||||
session.flush()
|
||||
|
||||
# save provider model
|
||||
if not provider_model_record:
|
||||
provider_model_record = ProviderModel(
|
||||
tenant_id=self.tenant_id,
|
||||
@ -999,24 +1024,23 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as pre_session:
|
||||
with Session(db.engine) as session:
|
||||
if credential_name and self._check_custom_model_credential_name_exists(
|
||||
model=model,
|
||||
model_type=model_type,
|
||||
credential_name=credential_name,
|
||||
session=pre_session,
|
||||
session=session,
|
||||
exclude_id=credential_id,
|
||||
):
|
||||
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
|
||||
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# validate custom model config
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
credential_id=credential_id,
|
||||
session=session,
|
||||
)
|
||||
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
|
||||
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
@ -1031,6 +1055,7 @@ class ProviderConfiguration(BaseModel):
|
||||
raise ValueError("Credential record not found.")
|
||||
|
||||
try:
|
||||
# Update credential
|
||||
credential_record.encrypted_config = json.dumps(credentials)
|
||||
credential_record.updated_at = naive_utc_now()
|
||||
if credential_name:
|
||||
|
||||
@ -1,279 +0,0 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
CustomizedMetrics,
|
||||
EvaluationCategory,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
EvaluationMetric,
|
||||
NodeInfo,
|
||||
)
|
||||
from graphon.node_events.base import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseEvaluationInstance(ABC):
|
||||
"""Abstract base class for evaluation framework adapters."""
|
||||
|
||||
@abstractmethod
|
||||
def evaluate_llm(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate LLM outputs using the configured framework."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def evaluate_retrieval(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate retrieval quality using the configured framework."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def evaluate_agent(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate agent outputs using the configured framework."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
|
||||
"""Return the list of supported metric names for a given evaluation category."""
|
||||
...
|
||||
|
||||
def evaluate_with_customized_workflow(
|
||||
self,
|
||||
node_run_result_mapping_list: list[dict[str, NodeRunResult]],
|
||||
customized_metrics: CustomizedMetrics,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate using a published workflow as the evaluator.
|
||||
|
||||
The evaluator workflow's output variables are treated as metrics:
|
||||
each output variable name becomes a metric name, and its value
|
||||
becomes the score.
|
||||
|
||||
Args:
|
||||
node_run_result_mapping_list: One mapping per test-data item,
|
||||
where each mapping is ``{node_id: NodeRunResult}`` from the
|
||||
target execution.
|
||||
customized_metrics: Contains ``evaluation_workflow_id`` (the
|
||||
published evaluator workflow) and ``input_fields`` (value
|
||||
sources for the evaluator's input variables).
|
||||
tenant_id: Tenant scope.
|
||||
|
||||
Returns:
|
||||
A list of ``EvaluationItemResult`` with metrics extracted from
|
||||
the evaluator workflow's output variables.
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.evaluation.runners import get_service_account_for_app
|
||||
from models.engine import db
|
||||
from models.model import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
workflow_id = customized_metrics.evaluation_workflow_id
|
||||
if not workflow_id:
|
||||
raise ValueError("customized_metrics must contain 'evaluation_workflow_id' for customized evaluator")
|
||||
|
||||
# Load the evaluator workflow resources using a dedicated session
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
app = session.query(App).filter_by(id=workflow_id, tenant_id=tenant_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"Evaluation workflow app {workflow_id} not found in tenant {tenant_id}")
|
||||
service_account = get_service_account_for_app(session, workflow_id)
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
published_workflow = workflow_service.get_published_workflow(app_model=app)
|
||||
if not published_workflow:
|
||||
raise ValueError(f"No published workflow found for evaluation app {workflow_id}")
|
||||
|
||||
eval_results: list[EvaluationItemResult] = []
|
||||
for idx, node_run_result_mapping in enumerate(node_run_result_mapping_list):
|
||||
try:
|
||||
workflow_inputs = self._build_workflow_inputs(
|
||||
customized_metrics.input_fields,
|
||||
node_run_result_mapping,
|
||||
)
|
||||
|
||||
generator = WorkflowAppGenerator()
|
||||
response: Mapping[str, Any] = generator.generate(
|
||||
app_model=app,
|
||||
workflow=published_workflow,
|
||||
user=service_account,
|
||||
args={"inputs": workflow_inputs},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
metrics = self._extract_workflow_metrics(response, workflow_id)
|
||||
eval_results.append(
|
||||
EvaluationItemResult(
|
||||
index=idx,
|
||||
metrics=metrics,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Customized evaluator failed for item %d with workflow %s",
|
||||
idx,
|
||||
workflow_id,
|
||||
)
|
||||
eval_results.append(EvaluationItemResult(index=idx))
|
||||
|
||||
return eval_results
|
||||
|
||||
@staticmethod
|
||||
def _build_workflow_inputs(
|
||||
input_fields: dict[str, Any],
|
||||
node_run_result_mapping: dict[str, NodeRunResult],
|
||||
) -> dict[str, Any]:
|
||||
"""Build customized workflow inputs by resolving value sources.
|
||||
|
||||
Each entry in ``input_fields`` maps a workflow input variable name
|
||||
to its value source, which can be:
|
||||
|
||||
- **Constant**: a plain string without ``{{#…#}}`` used as-is.
|
||||
- **Expression**: a string containing one or more
|
||||
``{{#node_id.output_key#}}`` selectors (same format as
|
||||
``VariableTemplateParser``) resolved from
|
||||
``node_run_result_mapping``.
|
||||
|
||||
"""
|
||||
from graphon.nodes.base.variable_template_parser import REGEX as VARIABLE_REGEX
|
||||
|
||||
workflow_inputs: dict[str, Any] = {}
|
||||
|
||||
for field_name, value_source in input_fields.items():
|
||||
if not isinstance(value_source, str):
|
||||
# Non-string values (numbers, bools, dicts) are used directly.
|
||||
workflow_inputs[field_name] = value_source
|
||||
continue
|
||||
|
||||
# Check if the entire value is a single expression.
|
||||
full_match = VARIABLE_REGEX.fullmatch(value_source)
|
||||
if full_match:
|
||||
workflow_inputs[field_name] = resolve_variable_selector(
|
||||
full_match.group(1),
|
||||
node_run_result_mapping,
|
||||
)
|
||||
elif VARIABLE_REGEX.search(value_source):
|
||||
# Mixed template: interpolate all expressions as strings.
|
||||
workflow_inputs[field_name] = VARIABLE_REGEX.sub(
|
||||
lambda m: str(resolve_variable_selector(m.group(1), node_run_result_mapping)),
|
||||
value_source,
|
||||
)
|
||||
else:
|
||||
# Plain constant — no expression markers.
|
||||
workflow_inputs[field_name] = value_source
|
||||
|
||||
return workflow_inputs
|
||||
|
||||
@staticmethod
|
||||
def _extract_workflow_metrics(
|
||||
response: Mapping[str, object],
|
||||
evaluation_workflow_id: str,
|
||||
) -> list[EvaluationMetric]:
|
||||
"""Extract evaluation metrics from workflow output variables.
|
||||
|
||||
Each metric's ``node_info`` is set with *evaluation_workflow_id* as
|
||||
the ``node_id``, so that judgment conditions can reference customized
|
||||
metrics via ``variable_selector: [evaluation_workflow_id, metric_name]``.
|
||||
"""
|
||||
metrics: list[EvaluationMetric] = []
|
||||
node_info = NodeInfo(node_id=evaluation_workflow_id, type="customized", title="customized")
|
||||
|
||||
data = response.get("data")
|
||||
if not isinstance(data, Mapping):
|
||||
logger.warning("Unexpected workflow response format: missing 'data' dict")
|
||||
return metrics
|
||||
|
||||
outputs = data.get("outputs")
|
||||
if not isinstance(outputs, dict):
|
||||
logger.warning("Unexpected workflow response format: 'outputs' is not a dict")
|
||||
return metrics
|
||||
|
||||
for key, raw_value in outputs.items():
|
||||
if not isinstance(key, str):
|
||||
continue
|
||||
metrics.append(EvaluationMetric(name=key, value=raw_value, node_info=node_info))
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def resolve_variable_selector(
|
||||
selector_raw: str,
|
||||
node_run_result_mapping: dict[str, NodeRunResult],
|
||||
) -> object:
|
||||
"""
|
||||
Resolve a ``#node_id.output_key#`` selector against node run results.
|
||||
"""
|
||||
#
|
||||
cleaned = selector_raw.strip("#")
|
||||
parts = cleaned.split(".")
|
||||
|
||||
if len(parts) < 2:
|
||||
logger.warning(
|
||||
"Selector '%s' must have at least node_id.output_key",
|
||||
selector_raw,
|
||||
)
|
||||
return ""
|
||||
|
||||
node_id = parts[0]
|
||||
output_path = parts[1:]
|
||||
|
||||
node_result = node_run_result_mapping.get(node_id)
|
||||
if not node_result or not node_result.outputs:
|
||||
logger.warning(
|
||||
"Selector '%s': node '%s' not found or has no outputs",
|
||||
selector_raw,
|
||||
node_id,
|
||||
)
|
||||
return ""
|
||||
|
||||
# Traverse the output path to support nested keys.
|
||||
current: object = node_result.outputs
|
||||
for key in output_path:
|
||||
if isinstance(current, Mapping):
|
||||
next_val = current.get(key)
|
||||
if next_val is None:
|
||||
logger.warning(
|
||||
"Selector '%s': key '%s' not found in node '%s' outputs",
|
||||
selector_raw,
|
||||
key,
|
||||
node_id,
|
||||
)
|
||||
return ""
|
||||
current = next_val
|
||||
else:
|
||||
logger.warning(
|
||||
"Selector '%s': cannot traverse into non-dict value at key '%s'",
|
||||
selector_raw,
|
||||
key,
|
||||
)
|
||||
return ""
|
||||
|
||||
return current if current is not None else ""
|
||||
@ -1,27 +0,0 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class EvaluationFrameworkEnum(StrEnum):
|
||||
RAGAS = "ragas"
|
||||
DEEPEVAL = "deepeval"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
class BaseEvaluationConfig(BaseModel):
|
||||
"""Base configuration for evaluation frameworks."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RagasConfig(BaseEvaluationConfig):
|
||||
"""RAGAS-specific configuration."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DeepEvalConfig(BaseEvaluationConfig):
|
||||
"""DeepEval-specific configuration."""
|
||||
|
||||
pass
|
||||
@ -1,226 +0,0 @@
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.evaluation.entities.judgment_entity import JudgmentConfig, JudgmentResult
|
||||
|
||||
|
||||
class EvaluationCategory(StrEnum):
|
||||
LLM = "llm"
|
||||
RETRIEVAL = "knowledge_retrieval"
|
||||
AGENT = "agent"
|
||||
WORKFLOW = "workflow"
|
||||
SNIPPET = "snippet"
|
||||
KNOWLEDGE_BASE = "knowledge_base"
|
||||
|
||||
|
||||
class EvaluationMetricName(StrEnum):
|
||||
"""Canonical metric names shared across all evaluation frameworks.
|
||||
|
||||
Each framework maps these names to its own internal implementation.
|
||||
A framework that does not support a given metric should log a warning
|
||||
and skip it rather than raising an error.
|
||||
|
||||
── LLM / general text-quality metrics ──────────────────────────────────
|
||||
FAITHFULNESS
|
||||
Measures whether every claim in the model's response is grounded in
|
||||
the provided retrieved context. A high score means the answer
|
||||
contains no hallucinated content — each statement can be traced back
|
||||
to a passage in the context.
|
||||
Required fields: user_input, response, retrieved_contexts.
|
||||
|
||||
ANSWER_RELEVANCY
|
||||
Measures how well the model's response addresses the user's question.
|
||||
A high score means the answer stays on-topic; a low score indicates
|
||||
irrelevant content or a failure to answer the actual question.
|
||||
Required fields: user_input, response.
|
||||
|
||||
ANSWER_CORRECTNESS
|
||||
Measures the factual accuracy and completeness of the model's answer
|
||||
relative to a ground-truth reference. It combines semantic similarity
|
||||
with key-fact coverage, so both meaning and content matter.
|
||||
Required fields: user_input, response, reference (expected_output).
|
||||
|
||||
SEMANTIC_SIMILARITY
|
||||
Measures the cosine similarity between the model's response and the
|
||||
reference answer in an embedding space. It evaluates whether the two
|
||||
texts convey the same meaning, independent of factual correctness.
|
||||
Required fields: response, reference (expected_output).
|
||||
|
||||
── Retrieval-quality metrics ────────────────────────────────────────────
|
||||
CONTEXT_PRECISION
|
||||
Measures the proportion of retrieved context chunks that are actually
|
||||
relevant to the question (precision). A high score means the retrieval
|
||||
pipeline returns little noise.
|
||||
Required fields: user_input, reference, retrieved_contexts.
|
||||
|
||||
CONTEXT_RECALL
|
||||
Measures the proportion of ground-truth information that is covered by
|
||||
the retrieved context chunks (recall). A high score means the retrieval
|
||||
pipeline does not miss important supporting evidence.
|
||||
Required fields: user_input, reference, retrieved_contexts.
|
||||
|
||||
CONTEXT_RELEVANCE
|
||||
Measures how relevant each individual retrieved chunk is to the query.
|
||||
Similar to CONTEXT_PRECISION but evaluated at the chunk level rather
|
||||
than against a reference answer.
|
||||
Required fields: user_input, retrieved_contexts.
|
||||
|
||||
── Agent-quality metrics ────────────────────────────────────────────────
|
||||
TOOL_CORRECTNESS
|
||||
Measures the correctness of the tool calls made by the agent during
|
||||
task execution — both the choice of tool and the arguments passed.
|
||||
A high score means the agent's tool-use strategy matches the expected
|
||||
behavior.
|
||||
Required fields: actual tool calls vs. expected tool calls.
|
||||
|
||||
TASK_COMPLETION
|
||||
Measures whether the agent ultimately achieves the user's stated goal.
|
||||
It evaluates the reasoning chain, intermediate steps, and final output
|
||||
holistically; a high score means the task was fully accomplished.
|
||||
Required fields: user_input, actual_output.
|
||||
"""
|
||||
|
||||
# LLM / general text-quality metrics
|
||||
FAITHFULNESS = "faithfulness"
|
||||
ANSWER_RELEVANCY = "answer_relevancy"
|
||||
ANSWER_CORRECTNESS = "answer_correctness"
|
||||
SEMANTIC_SIMILARITY = "semantic_similarity"
|
||||
|
||||
# Retrieval-quality metrics
|
||||
CONTEXT_PRECISION = "context_precision"
|
||||
CONTEXT_RECALL = "context_recall"
|
||||
CONTEXT_RELEVANCE = "context_relevance"
|
||||
|
||||
# Agent-quality metrics
|
||||
TOOL_CORRECTNESS = "tool_correctness"
|
||||
TASK_COMPLETION = "task_completion"
|
||||
|
||||
|
||||
# Per-category canonical metric lists used by get_supported_metrics().
|
||||
LLM_METRIC_NAMES: list[EvaluationMetricName] = [
|
||||
EvaluationMetricName.FAITHFULNESS, # Every claim is grounded in context; no hallucinations
|
||||
EvaluationMetricName.ANSWER_RELEVANCY, # Response stays on-topic and addresses the question
|
||||
EvaluationMetricName.ANSWER_CORRECTNESS, # Factual accuracy and completeness vs. reference
|
||||
EvaluationMetricName.SEMANTIC_SIMILARITY, # Semantic closeness to the reference answer
|
||||
]
|
||||
|
||||
RETRIEVAL_METRIC_NAMES: list[EvaluationMetricName] = [
|
||||
EvaluationMetricName.CONTEXT_PRECISION, # Fraction of retrieved chunks that are relevant (precision)
|
||||
EvaluationMetricName.CONTEXT_RECALL, # Fraction of ground-truth info covered by retrieval (recall)
|
||||
EvaluationMetricName.CONTEXT_RELEVANCE, # Per-chunk relevance to the query
|
||||
]
|
||||
|
||||
AGENT_METRIC_NAMES: list[EvaluationMetricName] = [
|
||||
EvaluationMetricName.TOOL_CORRECTNESS, # Correct tool selection and arguments
|
||||
EvaluationMetricName.TASK_COMPLETION, # Whether the agent fully achieves the user's goal
|
||||
]
|
||||
|
||||
WORKFLOW_METRIC_NAMES: list[EvaluationMetricName] = [
|
||||
EvaluationMetricName.FAITHFULNESS,
|
||||
EvaluationMetricName.ANSWER_RELEVANCY,
|
||||
EvaluationMetricName.ANSWER_CORRECTNESS,
|
||||
]
|
||||
|
||||
METRIC_NODE_TYPE_MAPPING: dict[str, str] = {
|
||||
**{m.value: "llm" for m in LLM_METRIC_NAMES},
|
||||
**{m.value: "knowledge-retrieval" for m in RETRIEVAL_METRIC_NAMES},
|
||||
**{m.value: "agent" for m in AGENT_METRIC_NAMES},
|
||||
}
|
||||
|
||||
METRIC_VALUE_TYPE_MAPPING: dict[str, str] = {
|
||||
EvaluationMetricName.FAITHFULNESS: "number",
|
||||
EvaluationMetricName.ANSWER_RELEVANCY: "number",
|
||||
EvaluationMetricName.ANSWER_CORRECTNESS: "number",
|
||||
EvaluationMetricName.SEMANTIC_SIMILARITY: "number",
|
||||
EvaluationMetricName.CONTEXT_PRECISION: "number",
|
||||
EvaluationMetricName.CONTEXT_RECALL: "number",
|
||||
EvaluationMetricName.CONTEXT_RELEVANCE: "number",
|
||||
EvaluationMetricName.TOOL_CORRECTNESS: "number",
|
||||
EvaluationMetricName.TASK_COMPLETION: "number",
|
||||
}
|
||||
|
||||
|
||||
class NodeInfo(BaseModel):
|
||||
node_id: str
|
||||
type: str
|
||||
title: str
|
||||
|
||||
|
||||
class EvaluationMetric(BaseModel):
|
||||
name: str
|
||||
value: Any
|
||||
details: dict[str, Any] = Field(default_factory=dict)
|
||||
node_info: NodeInfo | None = None
|
||||
|
||||
|
||||
class EvaluationItemInput(BaseModel):
|
||||
index: int
|
||||
inputs: dict[str, Any]
|
||||
output: str
|
||||
expected_output: str | None = None
|
||||
context: list[str] | None = None
|
||||
|
||||
|
||||
class EvaluationDatasetInput(BaseModel):
|
||||
index: int
|
||||
inputs: dict[str, Any]
|
||||
expected_output: str | None = None
|
||||
|
||||
|
||||
class EvaluationItemResult(BaseModel):
|
||||
index: int
|
||||
actual_output: str | None = None
|
||||
metrics: list[EvaluationMetric] = Field(default_factory=list)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
judgment: JudgmentResult = Field(default_factory=JudgmentResult)
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class DefaultMetric(BaseModel):
|
||||
metric: str
|
||||
value_type: str = ""
|
||||
node_info_list: list[NodeInfo]
|
||||
|
||||
|
||||
class CustomizedMetricOutputField(BaseModel):
|
||||
variable: str
|
||||
value_type: str
|
||||
|
||||
|
||||
class CustomizedMetrics(BaseModel):
|
||||
evaluation_workflow_id: str
|
||||
input_fields: dict[str, Any]
|
||||
output_fields: list[CustomizedMetricOutputField]
|
||||
|
||||
|
||||
class EvaluationConfigData(BaseModel):
|
||||
"""Structured data for saving evaluation configuration."""
|
||||
|
||||
evaluation_model: str = ""
|
||||
evaluation_model_provider: str = ""
|
||||
default_metrics: list[DefaultMetric] = Field(default_factory=list)
|
||||
customized_metrics: CustomizedMetrics | None = None
|
||||
judgment_config: JudgmentConfig | None = None
|
||||
|
||||
|
||||
class EvaluationRunRequest(EvaluationConfigData):
|
||||
"""Request body for starting an evaluation run."""
|
||||
|
||||
file_id: str
|
||||
|
||||
|
||||
class EvaluationRunData(BaseModel):
|
||||
"""Serializable data for Celery task."""
|
||||
|
||||
evaluation_run_id: str
|
||||
tenant_id: str
|
||||
target_type: str
|
||||
target_id: str
|
||||
evaluation_model_provider: str
|
||||
evaluation_model: str
|
||||
default_metrics: list[DefaultMetric] = Field(default_factory=list)
|
||||
customized_metrics: CustomizedMetrics | None = None
|
||||
judgment_config: JudgmentConfig | None = None
|
||||
input_list: list[EvaluationDatasetInput]
|
||||
@ -1,96 +0,0 @@
|
||||
"""Judgment condition entities for evaluation metric assessment.
|
||||
|
||||
Condition structure mirrors the workflow if-else ``Condition`` model from
|
||||
``graphon.utils.condition.entities``. The left-hand side uses
|
||||
``variable_selector`` — a two-element list ``[node_id, metric_name]`` — to
|
||||
uniquely identify an evaluation metric (different nodes may produce metrics
|
||||
with the same name).
|
||||
|
||||
Operators reuse ``SupportedComparisonOperator`` from the workflow engine so
|
||||
that type semantics stay consistent across the platform.
|
||||
|
||||
Typical usage::
|
||||
|
||||
judgment_config = JudgmentConfig(
|
||||
logical_operator="and",
|
||||
conditions=[
|
||||
JudgmentCondition(
|
||||
variable_selector=["node_abc", "faithfulness"],
|
||||
comparison_operator=">",
|
||||
value="0.8",
|
||||
)
|
||||
],
|
||||
)
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphon.utils.condition.entities import SupportedComparisonOperator
|
||||
|
||||
|
||||
class JudgmentCondition(BaseModel):
|
||||
"""A single judgment condition that checks one metric value.
|
||||
|
||||
Mirrors ``graphon.utils.condition.entities.Condition`` with the left-hand
|
||||
side being a metric selector instead of a workflow variable selector.
|
||||
|
||||
Attributes:
|
||||
variable_selector: ``[node_id, metric_name]`` identifying the metric.
|
||||
comparison_operator: Reuses workflow's ``SupportedComparisonOperator``.
|
||||
value: The comparison target (right side). For unary operators such
|
||||
as ``empty`` or ``null`` this can be ``None``.
|
||||
"""
|
||||
|
||||
variable_selector: list[str]
|
||||
comparison_operator: SupportedComparisonOperator
|
||||
value: str | Sequence[str] | bool | None = None
|
||||
|
||||
|
||||
class JudgmentConfig(BaseModel):
|
||||
"""A group of judgment conditions combined with a logical operator.
|
||||
|
||||
Attributes:
|
||||
logical_operator: How to combine condition results — "and" requires
|
||||
all conditions to pass, "or" requires at least one.
|
||||
conditions: The list of individual conditions to evaluate.
|
||||
"""
|
||||
|
||||
logical_operator: Literal["and", "or"] = "and"
|
||||
conditions: list[JudgmentCondition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class JudgmentConditionResult(BaseModel):
|
||||
"""Result of evaluating a single judgment condition.
|
||||
|
||||
Attributes:
|
||||
variable_selector: ``[node_id, metric_name]`` that was checked.
|
||||
comparison_operator: The operator that was applied.
|
||||
expected_value: The resolved comparison value.
|
||||
actual_value: The actual metric value that was evaluated.
|
||||
passed: Whether this individual condition passed.
|
||||
error: Error message if the condition evaluation failed.
|
||||
"""
|
||||
|
||||
variable_selector: list[str]
|
||||
comparison_operator: str
|
||||
expected_value: Any = None
|
||||
actual_value: Any = None
|
||||
passed: bool = False
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class JudgmentResult(BaseModel):
|
||||
"""Overall result of evaluating all judgment conditions for one item.
|
||||
|
||||
Attributes:
|
||||
passed: Whether the overall judgment passed (based on logical_operator).
|
||||
logical_operator: The logical operator used to combine conditions.
|
||||
condition_results: Detailed result for each individual condition.
|
||||
"""
|
||||
|
||||
passed: bool = False
|
||||
logical_operator: Literal["and", "or"] = "and"
|
||||
condition_results: list[JudgmentConditionResult] = Field(default_factory=list)
|
||||
@ -1,61 +0,0 @@
|
||||
import collections
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.config_entity import EvaluationFrameworkEnum
|
||||
from core.evaluation.entities.evaluation_entity import EvaluationCategory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EvaluationFrameworkConfigMap(collections.UserDict[str, dict[str, Any]]):
|
||||
"""Registry mapping framework enum -> {config_class, evaluator_class}."""
|
||||
|
||||
def __getitem__(self, framework: str) -> dict[str, Any]:
|
||||
match framework:
|
||||
case EvaluationFrameworkEnum.RAGAS:
|
||||
from core.evaluation.entities.config_entity import RagasConfig
|
||||
from core.evaluation.frameworks.ragas.ragas_evaluator import RagasEvaluator
|
||||
|
||||
return {
|
||||
"config_class": RagasConfig,
|
||||
"evaluator_class": RagasEvaluator,
|
||||
}
|
||||
case EvaluationFrameworkEnum.DEEPEVAL:
|
||||
raise NotImplementedError("DeepEval adapter is not yet implemented.")
|
||||
case _:
|
||||
raise ValueError(f"Unknown evaluation framework: {framework}")
|
||||
|
||||
|
||||
evaluation_framework_config_map = EvaluationFrameworkConfigMap()
|
||||
|
||||
|
||||
class EvaluationManager:
|
||||
"""Factory for evaluation instances based on global configuration."""
|
||||
|
||||
@staticmethod
|
||||
def get_evaluation_instance() -> BaseEvaluationInstance | None:
|
||||
"""Create and return an evaluation instance based on EVALUATION_FRAMEWORK env var."""
|
||||
framework = dify_config.EVALUATION_FRAMEWORK
|
||||
if not framework or framework == EvaluationFrameworkEnum.NONE:
|
||||
return None
|
||||
|
||||
try:
|
||||
config_map = evaluation_framework_config_map[framework]
|
||||
evaluator_class = config_map["evaluator_class"]
|
||||
config_class = config_map["config_class"]
|
||||
config = config_class()
|
||||
return evaluator_class(config)
|
||||
except Exception:
|
||||
logger.exception("Failed to create evaluation instance for framework: %s", framework)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_supported_metrics(category: EvaluationCategory) -> list[str]:
|
||||
"""Return supported metrics for the current framework and given category."""
|
||||
instance = EvaluationManager.get_evaluation_instance()
|
||||
if instance is None:
|
||||
return []
|
||||
return instance.get_supported_metrics(category)
|
||||
@ -1 +0,0 @@
|
||||
|
||||
@ -1,299 +0,0 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.config_entity import DeepEvalConfig
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
AGENT_METRIC_NAMES,
|
||||
LLM_METRIC_NAMES,
|
||||
RETRIEVAL_METRIC_NAMES,
|
||||
WORKFLOW_METRIC_NAMES,
|
||||
EvaluationCategory,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
EvaluationMetric,
|
||||
EvaluationMetricName,
|
||||
)
|
||||
from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maps canonical EvaluationMetricName to the corresponding deepeval metric class name.
|
||||
# deepeval metric field requirements (LLMTestCase fields):
|
||||
# - faithfulness: input, actual_output, retrieval_context
|
||||
# - answer_relevancy: input, actual_output
|
||||
# - context_precision: input, actual_output, expected_output, retrieval_context
|
||||
# - context_recall: input, actual_output, expected_output, retrieval_context
|
||||
# - context_relevance: input, actual_output, retrieval_context
|
||||
# - tool_correctness: input, actual_output, expected_tools
|
||||
# - task_completion: input, actual_output
|
||||
# Metrics not listed here are unsupported by deepeval and will be skipped.
|
||||
_DEEPEVAL_METRIC_MAP: dict[EvaluationMetricName, str] = {
|
||||
EvaluationMetricName.FAITHFULNESS: "FaithfulnessMetric",
|
||||
EvaluationMetricName.ANSWER_RELEVANCY: "AnswerRelevancyMetric",
|
||||
EvaluationMetricName.CONTEXT_PRECISION: "ContextualPrecisionMetric",
|
||||
EvaluationMetricName.CONTEXT_RECALL: "ContextualRecallMetric",
|
||||
EvaluationMetricName.CONTEXT_RELEVANCE: "ContextualRelevancyMetric",
|
||||
EvaluationMetricName.TOOL_CORRECTNESS: "ToolCorrectnessMetric",
|
||||
EvaluationMetricName.TASK_COMPLETION: "TaskCompletionMetric",
|
||||
}
|
||||
|
||||
|
||||
class DeepEvalEvaluator(BaseEvaluationInstance):
|
||||
"""DeepEval framework adapter for evaluation."""
|
||||
|
||||
def __init__(self, config: DeepEvalConfig):
|
||||
self.config = config
|
||||
|
||||
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
|
||||
match category:
|
||||
case EvaluationCategory.LLM:
|
||||
candidates = LLM_METRIC_NAMES
|
||||
case EvaluationCategory.RETRIEVAL:
|
||||
candidates = RETRIEVAL_METRIC_NAMES
|
||||
case EvaluationCategory.AGENT:
|
||||
candidates = AGENT_METRIC_NAMES
|
||||
case EvaluationCategory.WORKFLOW | EvaluationCategory.SNIPPET:
|
||||
candidates = WORKFLOW_METRIC_NAMES
|
||||
case _:
|
||||
return []
|
||||
return [m for m in candidates if m in _DEEPEVAL_METRIC_MAP]
|
||||
|
||||
def evaluate_llm(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM)
|
||||
|
||||
def evaluate_retrieval(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)
|
||||
|
||||
def evaluate_agent(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)
|
||||
|
||||
def evaluate_workflow(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)
|
||||
|
||||
def _evaluate(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
category: EvaluationCategory,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Core evaluation logic using DeepEval."""
|
||||
model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
|
||||
requested_metrics = metric_names or self.get_supported_metrics(category)
|
||||
|
||||
try:
|
||||
return self._evaluate_with_deepeval(items, requested_metrics, category)
|
||||
except ImportError:
|
||||
logger.warning("DeepEval not installed, falling back to simple evaluation")
|
||||
return self._evaluate_simple(items, requested_metrics, model_wrapper)
|
||||
|
||||
def _evaluate_with_deepeval(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
requested_metrics: list[str],
|
||||
category: EvaluationCategory,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate using DeepEval library.
|
||||
|
||||
Builds LLMTestCase differently per category:
|
||||
- LLM/Workflow: input=prompt, actual_output=output, retrieval_context=context
|
||||
- Retrieval: input=query, actual_output=output, expected_output, retrieval_context=context
|
||||
- Agent: input=query, actual_output=output
|
||||
"""
|
||||
metric_pairs = _build_deepeval_metrics(requested_metrics)
|
||||
if not metric_pairs:
|
||||
logger.warning("No valid DeepEval metrics found for: %s", requested_metrics)
|
||||
return [EvaluationItemResult(index=item.index) for item in items]
|
||||
|
||||
results: list[EvaluationItemResult] = []
|
||||
for item in items:
|
||||
test_case = self._build_test_case(item, category)
|
||||
metrics: list[EvaluationMetric] = []
|
||||
for canonical_name, metric in metric_pairs:
|
||||
try:
|
||||
metric.measure(test_case)
|
||||
if metric.score is not None:
|
||||
metrics.append(EvaluationMetric(name=canonical_name, value=float(metric.score)))
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to compute metric %s for item %d",
|
||||
canonical_name,
|
||||
item.index,
|
||||
)
|
||||
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _build_test_case(item: EvaluationItemInput, category: EvaluationCategory) -> Any:
|
||||
"""Build a deepeval LLMTestCase with the correct fields per category."""
|
||||
from deepeval.test_case import LLMTestCase
|
||||
|
||||
user_input = _format_input(item.inputs, category)
|
||||
|
||||
match category:
|
||||
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
|
||||
# faithfulness needs: input, actual_output, retrieval_context
|
||||
# answer_relevancy needs: input, actual_output
|
||||
return LLMTestCase(
|
||||
input=user_input,
|
||||
actual_output=item.output,
|
||||
expected_output=item.expected_output or None,
|
||||
retrieval_context=item.context or None,
|
||||
)
|
||||
case EvaluationCategory.RETRIEVAL:
|
||||
# contextual_precision/recall needs: input, actual_output, expected_output, retrieval_context
|
||||
return LLMTestCase(
|
||||
input=user_input,
|
||||
actual_output=item.output or "",
|
||||
expected_output=item.expected_output or "",
|
||||
retrieval_context=item.context or [],
|
||||
)
|
||||
case _:
|
||||
return LLMTestCase(
|
||||
input=user_input,
|
||||
actual_output=item.output,
|
||||
)
|
||||
|
||||
def _evaluate_simple(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
requested_metrics: list[str],
|
||||
model_wrapper: DifyModelWrapper,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Simple LLM-as-judge fallback when DeepEval is not available."""
|
||||
results: list[EvaluationItemResult] = []
|
||||
for item in items:
|
||||
metrics: list[EvaluationMetric] = []
|
||||
for m_name in requested_metrics:
|
||||
try:
|
||||
score = self._judge_with_llm(model_wrapper, m_name, item)
|
||||
metrics.append(EvaluationMetric(name=m_name, value=score))
|
||||
except Exception:
|
||||
logger.exception("Failed to compute metric %s for item %d", m_name, item.index)
|
||||
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
|
||||
return results
|
||||
|
||||
def _judge_with_llm(
|
||||
self,
|
||||
model_wrapper: DifyModelWrapper,
|
||||
metric_name: str,
|
||||
item: EvaluationItemInput,
|
||||
) -> float:
|
||||
"""Use the LLM to judge a single metric for a single item."""
|
||||
prompt = self._build_judge_prompt(metric_name, item)
|
||||
response = model_wrapper.invoke(prompt)
|
||||
return self._parse_score(response)
|
||||
|
||||
@staticmethod
|
||||
def _build_judge_prompt(metric_name: str, item: EvaluationItemInput) -> str:
|
||||
"""Build a scoring prompt for the LLM judge."""
|
||||
parts = [
|
||||
f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.",
|
||||
f"\nInput: {item.inputs}",
|
||||
f"\nOutput: {item.output}",
|
||||
]
|
||||
if item.expected_output:
|
||||
parts.append(f"\nExpected Output: {item.expected_output}")
|
||||
if item.context:
|
||||
parts.append(f"\nContext: {'; '.join(item.context)}")
|
||||
parts.append("\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else.")
|
||||
return "\n".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _parse_score(response: str) -> float:
|
||||
"""Parse a float score from LLM response."""
|
||||
import re
|
||||
|
||||
cleaned = response.strip()
|
||||
try:
|
||||
score = float(cleaned)
|
||||
return max(0.0, min(1.0, score))
|
||||
except ValueError:
|
||||
match = re.search(r"(\d+\.?\d*)", cleaned)
|
||||
if match:
|
||||
score = float(match.group(1))
|
||||
return max(0.0, min(1.0, score))
|
||||
return 0.0
|
||||
|
||||
|
||||
def _format_input(inputs: dict[str, Any], category: EvaluationCategory) -> str:
|
||||
"""Extract the user-facing input string from the inputs dict."""
|
||||
match category:
|
||||
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
|
||||
return str(inputs.get("prompt", ""))
|
||||
case EvaluationCategory.RETRIEVAL:
|
||||
return str(inputs.get("query", ""))
|
||||
case _:
|
||||
return str(next(iter(inputs.values()), "")) if inputs else ""
|
||||
|
||||
|
||||
def _build_deepeval_metrics(requested_metrics: list[str]) -> list[tuple[str, Any]]:
|
||||
"""Build DeepEval metric instances from canonical metric names.
|
||||
|
||||
Returns a list of (canonical_name, metric_instance) pairs so that callers
|
||||
can record the canonical name rather than the framework-internal class name.
|
||||
"""
|
||||
try:
|
||||
from deepeval.metrics import (
|
||||
AnswerRelevancyMetric,
|
||||
ContextualPrecisionMetric,
|
||||
ContextualRecallMetric,
|
||||
ContextualRelevancyMetric,
|
||||
FaithfulnessMetric,
|
||||
TaskCompletionMetric,
|
||||
ToolCorrectnessMetric,
|
||||
)
|
||||
|
||||
# Maps canonical name → deepeval metric class
|
||||
deepeval_class_map: dict[str, Any] = {
|
||||
EvaluationMetricName.FAITHFULNESS: FaithfulnessMetric,
|
||||
EvaluationMetricName.ANSWER_RELEVANCY: AnswerRelevancyMetric,
|
||||
EvaluationMetricName.CONTEXT_PRECISION: ContextualPrecisionMetric,
|
||||
EvaluationMetricName.CONTEXT_RECALL: ContextualRecallMetric,
|
||||
EvaluationMetricName.CONTEXT_RELEVANCE: ContextualRelevancyMetric,
|
||||
EvaluationMetricName.TOOL_CORRECTNESS: ToolCorrectnessMetric,
|
||||
EvaluationMetricName.TASK_COMPLETION: TaskCompletionMetric,
|
||||
}
|
||||
|
||||
pairs: list[tuple[str, Any]] = []
|
||||
for name in requested_metrics:
|
||||
metric_class = deepeval_class_map.get(name)
|
||||
if metric_class:
|
||||
pairs.append((name, metric_class(threshold=0.5)))
|
||||
else:
|
||||
logger.warning("Metric '%s' is not supported by DeepEval, skipping", name)
|
||||
return pairs
|
||||
except ImportError:
|
||||
logger.warning("DeepEval metrics not available")
|
||||
return []
|
||||
@ -1,312 +0,0 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.config_entity import RagasConfig
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
AGENT_METRIC_NAMES,
|
||||
LLM_METRIC_NAMES,
|
||||
RETRIEVAL_METRIC_NAMES,
|
||||
WORKFLOW_METRIC_NAMES,
|
||||
EvaluationCategory,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
EvaluationMetric,
|
||||
EvaluationMetricName,
|
||||
)
|
||||
from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maps canonical EvaluationMetricName to the corresponding ragas metric class.
|
||||
# Metrics not listed here are unsupported by ragas and will be skipped.
|
||||
_RAGAS_METRIC_MAP: dict[EvaluationMetricName, str] = {
|
||||
EvaluationMetricName.FAITHFULNESS: "Faithfulness",
|
||||
EvaluationMetricName.ANSWER_RELEVANCY: "AnswerRelevancy",
|
||||
EvaluationMetricName.ANSWER_CORRECTNESS: "AnswerCorrectness",
|
||||
EvaluationMetricName.SEMANTIC_SIMILARITY: "SemanticSimilarity",
|
||||
EvaluationMetricName.CONTEXT_PRECISION: "ContextPrecision",
|
||||
EvaluationMetricName.CONTEXT_RECALL: "ContextRecall",
|
||||
EvaluationMetricName.CONTEXT_RELEVANCE: "ContextRelevance",
|
||||
EvaluationMetricName.TOOL_CORRECTNESS: "ToolCallAccuracy",
|
||||
}
|
||||
|
||||
|
||||
class RagasEvaluator(BaseEvaluationInstance):
|
||||
"""RAGAS framework adapter for evaluation."""
|
||||
|
||||
def __init__(self, config: RagasConfig):
|
||||
self.config = config
|
||||
|
||||
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
|
||||
match category:
|
||||
case EvaluationCategory.LLM:
|
||||
candidates = LLM_METRIC_NAMES
|
||||
case EvaluationCategory.RETRIEVAL:
|
||||
candidates = RETRIEVAL_METRIC_NAMES
|
||||
case EvaluationCategory.AGENT:
|
||||
candidates = AGENT_METRIC_NAMES
|
||||
case EvaluationCategory.WORKFLOW | EvaluationCategory.SNIPPET:
|
||||
candidates = WORKFLOW_METRIC_NAMES
|
||||
case _:
|
||||
return []
|
||||
return [m for m in candidates if m in _RAGAS_METRIC_MAP]
|
||||
|
||||
def evaluate_llm(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM)
|
||||
|
||||
def evaluate_retrieval(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)
|
||||
|
||||
def evaluate_agent(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)
|
||||
|
||||
def evaluate_workflow(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)
|
||||
|
||||
def _evaluate(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
metric_names: list[str],
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
category: EvaluationCategory,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Core evaluation logic using RAGAS."""
|
||||
model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
|
||||
requested_metrics = metric_names or self.get_supported_metrics(category)
|
||||
|
||||
try:
|
||||
return self._evaluate_with_ragas(items, requested_metrics, model_wrapper, category)
|
||||
except ImportError:
|
||||
logger.warning("RAGAS not installed, falling back to simple evaluation")
|
||||
return self._evaluate_simple(items, requested_metrics, model_wrapper)
|
||||
|
||||
def _evaluate_with_ragas(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
requested_metrics: list[str],
|
||||
model_wrapper: DifyModelWrapper,
|
||||
category: EvaluationCategory,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Evaluate using RAGAS library.
|
||||
|
||||
Builds SingleTurnSample differently per category to match ragas requirements:
|
||||
- LLM/Workflow: user_input=prompt, response=output, reference=expected_output
|
||||
- Retrieval: user_input=query, reference=expected_output, retrieved_contexts=context
|
||||
- Agent: Not supported via EvaluationDataset (requires message-based API)
|
||||
"""
|
||||
from ragas import evaluate as ragas_evaluate
|
||||
from ragas.dataset_schema import EvaluationDataset
|
||||
|
||||
samples: list[Any] = []
|
||||
for item in items:
|
||||
sample = self._build_sample(item, category)
|
||||
samples.append(sample)
|
||||
|
||||
dataset = EvaluationDataset(samples=samples)
|
||||
|
||||
ragas_metrics = self._build_ragas_metrics(requested_metrics)
|
||||
if not ragas_metrics:
|
||||
logger.warning("No valid RAGAS metrics found for: %s", requested_metrics)
|
||||
return [EvaluationItemResult(index=item.index) for item in items]
|
||||
|
||||
try:
|
||||
result = ragas_evaluate(
|
||||
dataset=dataset,
|
||||
metrics=ragas_metrics,
|
||||
)
|
||||
|
||||
results: list[EvaluationItemResult] = []
|
||||
result_df = result.to_pandas()
|
||||
for i, item in enumerate(items):
|
||||
metrics: list[EvaluationMetric] = []
|
||||
for m_name in requested_metrics:
|
||||
if m_name in result_df.columns:
|
||||
score = result_df.iloc[i][m_name]
|
||||
if score is not None and not (isinstance(score, float) and score != score):
|
||||
metrics.append(EvaluationMetric(name=m_name, value=float(score)))
|
||||
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
|
||||
return results
|
||||
except Exception:
|
||||
logger.exception("RAGAS evaluation failed, falling back to simple evaluation")
|
||||
return self._evaluate_simple(items, requested_metrics, model_wrapper)
|
||||
|
||||
@staticmethod
|
||||
def _build_sample(item: EvaluationItemInput, category: EvaluationCategory) -> Any:
|
||||
"""Build a ragas SingleTurnSample with the correct fields per category.
|
||||
|
||||
ragas metric field requirements:
|
||||
- faithfulness: user_input, response, retrieved_contexts
|
||||
- answer_relevancy: user_input, response
|
||||
- answer_correctness: user_input, response, reference
|
||||
- semantic_similarity: user_input, response, reference
|
||||
- context_precision: user_input, reference, retrieved_contexts
|
||||
- context_recall: user_input, reference, retrieved_contexts
|
||||
- context_relevance: user_input, retrieved_contexts
|
||||
"""
|
||||
from ragas.dataset_schema import SingleTurnSample
|
||||
|
||||
user_input = _format_input(item.inputs, category)
|
||||
|
||||
match category:
|
||||
case EvaluationCategory.LLM:
|
||||
# response = actual LLM output, reference = expected output
|
||||
return SingleTurnSample(
|
||||
user_input=user_input,
|
||||
response=item.output,
|
||||
reference=item.expected_output or "",
|
||||
retrieved_contexts=item.context or [],
|
||||
)
|
||||
case EvaluationCategory.RETRIEVAL:
|
||||
# context_precision/recall only need reference + retrieved_contexts
|
||||
return SingleTurnSample(
|
||||
user_input=user_input,
|
||||
reference=item.expected_output or "",
|
||||
retrieved_contexts=item.context or [],
|
||||
)
|
||||
case _:
|
||||
return SingleTurnSample(
|
||||
user_input=user_input,
|
||||
response=item.output,
|
||||
)
|
||||
|
||||
def _evaluate_simple(
|
||||
self,
|
||||
items: list[EvaluationItemInput],
|
||||
requested_metrics: list[str],
|
||||
model_wrapper: DifyModelWrapper,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Simple LLM-as-judge fallback when RAGAS is not available."""
|
||||
results: list[EvaluationItemResult] = []
|
||||
for item in items:
|
||||
metrics: list[EvaluationMetric] = []
|
||||
for m_name in requested_metrics:
|
||||
try:
|
||||
score = self._judge_with_llm(model_wrapper, m_name, item)
|
||||
metrics.append(EvaluationMetric(name=m_name, value=score))
|
||||
except Exception:
|
||||
logger.exception("Failed to compute metric %s for item %d", m_name, item.index)
|
||||
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
|
||||
return results
|
||||
|
||||
def _judge_with_llm(
|
||||
self,
|
||||
model_wrapper: DifyModelWrapper,
|
||||
metric_name: str,
|
||||
item: EvaluationItemInput,
|
||||
) -> float:
|
||||
"""Use the LLM to judge a single metric for a single item."""
|
||||
prompt = self._build_judge_prompt(metric_name, item)
|
||||
response = model_wrapper.invoke(prompt)
|
||||
return self._parse_score(response)
|
||||
|
||||
@staticmethod
|
||||
def _build_judge_prompt(metric_name: str, item: EvaluationItemInput) -> str:
|
||||
"""Build a scoring prompt for the LLM judge."""
|
||||
parts = [
|
||||
f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.",
|
||||
f"\nInput: {item.inputs}",
|
||||
f"\nOutput: {item.output}",
|
||||
]
|
||||
if item.expected_output:
|
||||
parts.append(f"\nExpected Output: {item.expected_output}")
|
||||
if item.context:
|
||||
parts.append(f"\nContext: {'; '.join(item.context)}")
|
||||
parts.append("\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else.")
|
||||
return "\n".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _parse_score(response: str) -> float:
|
||||
"""Parse a float score from LLM response."""
|
||||
import re
|
||||
|
||||
cleaned = response.strip()
|
||||
try:
|
||||
score = float(cleaned)
|
||||
return max(0.0, min(1.0, score))
|
||||
except ValueError:
|
||||
match = re.search(r"(\d+\.?\d*)", cleaned)
|
||||
if match:
|
||||
score = float(match.group(1))
|
||||
return max(0.0, min(1.0, score))
|
||||
return 0.0
|
||||
|
||||
@staticmethod
|
||||
def _build_ragas_metrics(requested_metrics: list[str]) -> list[Any]:
|
||||
"""Build RAGAS metric instances from canonical metric names."""
|
||||
try:
|
||||
from ragas.metrics.collections import (
|
||||
AnswerCorrectness,
|
||||
AnswerRelevancy,
|
||||
ContextPrecision,
|
||||
ContextRecall,
|
||||
ContextRelevance,
|
||||
Faithfulness,
|
||||
SemanticSimilarity,
|
||||
ToolCallAccuracy,
|
||||
)
|
||||
|
||||
# Maps canonical name → ragas metric class
|
||||
ragas_class_map: dict[str, Any] = {
|
||||
EvaluationMetricName.FAITHFULNESS: Faithfulness,
|
||||
EvaluationMetricName.ANSWER_RELEVANCY: AnswerRelevancy,
|
||||
EvaluationMetricName.ANSWER_CORRECTNESS: AnswerCorrectness,
|
||||
EvaluationMetricName.SEMANTIC_SIMILARITY: SemanticSimilarity,
|
||||
EvaluationMetricName.CONTEXT_PRECISION: ContextPrecision,
|
||||
EvaluationMetricName.CONTEXT_RECALL: ContextRecall,
|
||||
EvaluationMetricName.CONTEXT_RELEVANCE: ContextRelevance,
|
||||
EvaluationMetricName.TOOL_CORRECTNESS: ToolCallAccuracy,
|
||||
}
|
||||
|
||||
metrics = []
|
||||
for name in requested_metrics:
|
||||
metric_class = ragas_class_map.get(name)
|
||||
if metric_class:
|
||||
metrics.append(metric_class())
|
||||
else:
|
||||
logger.warning("Metric '%s' is not supported by RAGAS, skipping", name)
|
||||
return metrics
|
||||
except ImportError:
|
||||
logger.warning("RAGAS metrics not available")
|
||||
return []
|
||||
|
||||
|
||||
def _format_input(inputs: dict[str, Any], category: EvaluationCategory) -> str:
|
||||
"""Extract the user-facing input string from the inputs dict."""
|
||||
match category:
|
||||
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
|
||||
return str(inputs.get("prompt", ""))
|
||||
case EvaluationCategory.RETRIEVAL:
|
||||
return str(inputs.get("query", ""))
|
||||
case _:
|
||||
return str(next(iter(inputs.values()), "")) if inputs else ""
|
||||
@ -1,48 +0,0 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DifyModelWrapper:
|
||||
"""Wraps Dify's model invocation interface for use by RAGAS as an LLM judge.
|
||||
|
||||
RAGAS requires an LLM to compute certain metrics (faithfulness, answer_relevancy, etc.).
|
||||
This wrapper bridges Dify's ModelInstance to a callable that RAGAS can use.
|
||||
"""
|
||||
|
||||
def __init__(self, model_provider: str, model_name: str, tenant_id: str):
|
||||
self.model_provider = model_provider
|
||||
self.model_name = model_name
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
def _get_model_instance(self) -> Any:
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=self.model_provider,
|
||||
model_type=ModelType.LLM,
|
||||
model=self.model_name,
|
||||
)
|
||||
return model_instance
|
||||
|
||||
def invoke(self, prompt: str) -> str:
|
||||
"""Invoke the model with a text prompt and return the text response."""
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
model_instance = self._get_model_instance()
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(content="You are an evaluation judge. Answer precisely and concisely."),
|
||||
UserPromptMessage(content=prompt),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 2048},
|
||||
stream=False,
|
||||
)
|
||||
return result.message.content
|
||||
@ -1,160 +0,0 @@
|
||||
"""Judgment condition processor for evaluation metrics.
|
||||
|
||||
Evaluates pass/fail judgment conditions against evaluation metric values.
|
||||
Each condition uses ``variable_selector`` (``[node_id, metric_name]``) to
|
||||
look up the metric value, then delegates the actual comparison to the
|
||||
workflow condition engine (``graphon.utils.condition.processor``).
|
||||
|
||||
The processor is intentionally decoupled from evaluation frameworks and
|
||||
runners. It operates on plain ``dict`` mappings and can be invoked
|
||||
anywhere that already has per-item metric results.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from core.evaluation.entities.judgment_entity import (
|
||||
JudgmentCondition,
|
||||
JudgmentConditionResult,
|
||||
JudgmentConfig,
|
||||
JudgmentResult,
|
||||
)
|
||||
from graphon.utils.condition.entities import SupportedComparisonOperator
|
||||
from graphon.utils.condition.processor import _evaluate_condition # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_UNARY_OPERATORS = frozenset({"null", "not null", "empty", "not empty"})
|
||||
|
||||
|
||||
class JudgmentProcessor:
|
||||
@staticmethod
|
||||
def evaluate(
|
||||
metric_values: dict[tuple[str, str], Any],
|
||||
config: JudgmentConfig,
|
||||
) -> JudgmentResult:
|
||||
"""Evaluate all judgment conditions against the given metric values.
|
||||
|
||||
Args:
|
||||
metric_values: Mapping of ``(node_id, metric_name)`` → metric
|
||||
value (e.g. ``{("node_abc", "faithfulness"): 0.85}``).
|
||||
config: The judgment configuration with logical_operator and
|
||||
conditions.
|
||||
|
||||
Returns:
|
||||
JudgmentResult with overall pass/fail and per-condition details.
|
||||
"""
|
||||
if not config.conditions:
|
||||
return JudgmentResult(
|
||||
passed=True,
|
||||
logical_operator=config.logical_operator,
|
||||
condition_results=[],
|
||||
)
|
||||
|
||||
condition_results: list[JudgmentConditionResult] = []
|
||||
|
||||
for condition in config.conditions:
|
||||
result = JudgmentProcessor._evaluate_single_condition(metric_values, condition)
|
||||
condition_results.append(result)
|
||||
|
||||
if config.logical_operator == "and" and not result.passed:
|
||||
return JudgmentResult(
|
||||
passed=False,
|
||||
logical_operator=config.logical_operator,
|
||||
condition_results=condition_results,
|
||||
)
|
||||
if config.logical_operator == "or" and result.passed:
|
||||
return JudgmentResult(
|
||||
passed=True,
|
||||
logical_operator=config.logical_operator,
|
||||
condition_results=condition_results,
|
||||
)
|
||||
|
||||
if config.logical_operator == "and":
|
||||
final_passed = all(r.passed for r in condition_results)
|
||||
else:
|
||||
final_passed = any(r.passed for r in condition_results)
|
||||
|
||||
return JudgmentResult(
|
||||
passed=final_passed,
|
||||
logical_operator=config.logical_operator,
|
||||
condition_results=condition_results,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _evaluate_single_condition(
|
||||
metric_values: dict[tuple[str, str], Any],
|
||||
condition: JudgmentCondition,
|
||||
) -> JudgmentConditionResult:
|
||||
"""Evaluate a single judgment condition.
|
||||
|
||||
Steps:
|
||||
1. Extract ``(node_id, metric_name)`` from ``variable_selector``.
|
||||
2. Look up the metric value from ``metric_values``.
|
||||
3. Delegate comparison to the workflow condition engine.
|
||||
"""
|
||||
selector = condition.variable_selector
|
||||
if len(selector) < 2:
|
||||
return JudgmentConditionResult(
|
||||
variable_selector=selector,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
expected_value=condition.value,
|
||||
actual_value=None,
|
||||
passed=False,
|
||||
error=f"variable_selector must have at least 2 elements, got {len(selector)}",
|
||||
)
|
||||
|
||||
node_id, metric_name = selector[0], selector[1]
|
||||
actual_value = metric_values.get((node_id, metric_name))
|
||||
|
||||
if actual_value is None and condition.comparison_operator not in _UNARY_OPERATORS:
|
||||
return JudgmentConditionResult(
|
||||
variable_selector=selector,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
expected_value=condition.value,
|
||||
actual_value=None,
|
||||
passed=False,
|
||||
error=f"Metric '{metric_name}' on node '{node_id}' not found in evaluation results",
|
||||
)
|
||||
|
||||
try:
|
||||
expected = condition.value
|
||||
# Numeric operators need the actual value coerced to int/float
|
||||
# so that the workflow engine's numeric assertions work correctly.
|
||||
coerced_actual: object = actual_value
|
||||
if (
|
||||
condition.comparison_operator in {"=", "≠", ">", "<", "≥", "≤"}
|
||||
and actual_value is not None
|
||||
and not isinstance(actual_value, (int, float, bool))
|
||||
):
|
||||
coerced_actual = float(actual_value)
|
||||
|
||||
passed = _evaluate_condition(
|
||||
operator=cast(SupportedComparisonOperator, condition.comparison_operator),
|
||||
value=coerced_actual,
|
||||
expected=cast(str | Sequence[str] | bool | Sequence[bool] | None, expected),
|
||||
)
|
||||
|
||||
return JudgmentConditionResult(
|
||||
variable_selector=selector,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
expected_value=expected,
|
||||
actual_value=actual_value,
|
||||
passed=passed,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Judgment condition evaluation failed for [%s, %s]: %s",
|
||||
node_id,
|
||||
metric_name,
|
||||
str(e),
|
||||
)
|
||||
return JudgmentConditionResult(
|
||||
variable_selector=selector,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
expected_value=condition.value,
|
||||
actual_value=actual_value,
|
||||
passed=False,
|
||||
error=str(e),
|
||||
)
|
||||
@ -1,52 +0,0 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import Account, App, CustomizedSnippet, TenantAccountJoin
|
||||
|
||||
|
||||
def get_service_account_for_app(session: Session, app_id: str) -> Account:
|
||||
"""Get the creator account for an app with tenant context set up.
|
||||
|
||||
This follows the same pattern as BaseTraceInstance.get_service_account_with_tenant().
|
||||
"""
|
||||
app = session.scalar(select(App).where(App.id == app_id))
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator")
|
||||
|
||||
account = session.scalar(select(Account).where(Account.id == app.created_by))
|
||||
if not account:
|
||||
raise ValueError(f"Creator account not found for app {app_id}")
|
||||
|
||||
current_tenant = session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
|
||||
if not current_tenant:
|
||||
raise ValueError(f"Current tenant not found for account {account.id}")
|
||||
|
||||
account.set_tenant_id(current_tenant.tenant_id)
|
||||
return account
|
||||
|
||||
|
||||
def get_service_account_for_snippet(session: Session, snippet_id: str) -> Account:
|
||||
"""Get the creator account for a snippet with tenant context set up.
|
||||
|
||||
Mirrors :func:`get_service_account_for_app` but queries CustomizedSnippet.
|
||||
"""
|
||||
snippet = session.scalar(select(CustomizedSnippet).where(CustomizedSnippet.id == snippet_id))
|
||||
if not snippet:
|
||||
raise ValueError(f"Snippet with id {snippet_id} not found")
|
||||
|
||||
if not snippet.created_by:
|
||||
raise ValueError(f"Snippet with id {snippet_id} has no creator")
|
||||
|
||||
account = session.scalar(select(Account).where(Account.id == snippet.created_by))
|
||||
if not account:
|
||||
raise ValueError(f"Creator account not found for snippet {snippet_id}")
|
||||
|
||||
current_tenant = session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
|
||||
if not current_tenant:
|
||||
raise ValueError(f"Current tenant not found for account {account.id}")
|
||||
|
||||
account.set_tenant_id(current_tenant.tenant_id)
|
||||
return account
|
||||
@ -1,62 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentEvaluationRunner(BaseEvaluationRunner):
|
||||
"""Runner for agent evaluation: collects tool calls and final output."""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance):
|
||||
super().__init__(evaluation_instance)
|
||||
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
node_run_result_list: list[NodeRunResult],
|
||||
default_metric: DefaultMetric,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Compute agent evaluation metrics."""
|
||||
if not node_run_result_list:
|
||||
return []
|
||||
merged_items = self._merge_results_into_items(node_run_result_list)
|
||||
return self.evaluation_instance.evaluate_agent(
|
||||
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
|
||||
"""Create EvaluationItemInput list from NodeRunResult for agent evaluation."""
|
||||
merged = []
|
||||
for i, item in enumerate(items):
|
||||
output = _extract_agent_output(item.outputs)
|
||||
merged.append(
|
||||
EvaluationItemInput(
|
||||
index=i,
|
||||
inputs=dict(item.inputs),
|
||||
output=output,
|
||||
)
|
||||
)
|
||||
return merged
|
||||
|
||||
|
||||
def _extract_agent_output(outputs: Mapping[str, Any]) -> str:
|
||||
"""Extract the primary output text from agent NodeRunResult.outputs."""
|
||||
if "answer" in outputs:
|
||||
return str(outputs["answer"])
|
||||
if "text" in outputs:
|
||||
return str(outputs["text"])
|
||||
values = list(outputs.values())
|
||||
return str(values[0]) if values else ""
|
||||
@ -1,51 +0,0 @@
|
||||
"""Base evaluation runner.
|
||||
|
||||
Provides the abstract interface for metric computation. Each concrete runner
|
||||
(LLM, Retrieval, Agent, Workflow, Snippet) implements ``evaluate_metrics``
|
||||
to compute scores for a specific node type.
|
||||
|
||||
Orchestration (merging results from multiple runners, applying judgment, and
|
||||
persisting to the database) is handled by the evaluation task, not the runner.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseEvaluationRunner(ABC):
|
||||
"""Abstract base class for evaluation runners.
|
||||
|
||||
Runners are stateless metric calculators: they receive node execution
|
||||
results and a metric specification, then return scored results. They
|
||||
do **not** touch the database or apply judgment logic.
|
||||
"""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance):
|
||||
self.evaluation_instance = evaluation_instance
|
||||
|
||||
@abstractmethod
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
node_run_result_list: list[NodeRunResult],
|
||||
default_metric: DefaultMetric,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Compute evaluation metrics on the collected results.
|
||||
|
||||
The returned ``EvaluationItemResult.index`` values are positional
|
||||
(0-based) and correspond to the order of *node_run_result_list*.
|
||||
The caller is responsible for mapping them back to the original
|
||||
dataset indices.
|
||||
"""
|
||||
...
|
||||
@ -1,83 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMEvaluationRunner(BaseEvaluationRunner):
|
||||
"""Runner for LLM evaluation: extracts prompts/outputs then evaluates."""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance):
|
||||
super().__init__(evaluation_instance)
|
||||
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
node_run_result_list: list[NodeRunResult],
|
||||
default_metric: DefaultMetric,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Use the evaluation instance to compute LLM metrics."""
|
||||
if not node_run_result_list:
|
||||
return []
|
||||
merged_items = self._merge_results_into_items(node_run_result_list)
|
||||
return self.evaluation_instance.evaluate_llm(
|
||||
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _merge_results_into_items(
|
||||
items: list[NodeRunResult],
|
||||
) -> list[EvaluationItemInput]:
|
||||
"""Create new items from NodeRunResult for ragas evaluation.
|
||||
|
||||
Extracts prompts from process_data and concatenates them into a single
|
||||
string with role prefixes (e.g. "system: ...\nuser: ...\nassistant: ...").
|
||||
The last assistant message in outputs is used as the actual output.
|
||||
"""
|
||||
merged = []
|
||||
for i, item in enumerate(items):
|
||||
prompt = _format_prompts(item.process_data.get("prompts", []))
|
||||
output = _extract_llm_output(item.outputs)
|
||||
merged.append(
|
||||
EvaluationItemInput(
|
||||
index=i,
|
||||
inputs={"prompt": prompt},
|
||||
output=output,
|
||||
)
|
||||
)
|
||||
return merged
|
||||
|
||||
|
||||
def _format_prompts(prompts: list[dict[str, Any]]) -> str:
|
||||
"""Concatenate a list of prompt messages into a single string for evaluation.
|
||||
|
||||
Each message is formatted as "role: text" and joined with newlines.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
for msg in prompts:
|
||||
role = msg.get("role", "unknown")
|
||||
text = msg.get("text", "")
|
||||
parts.append(f"{role}: {text}")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _extract_llm_output(outputs: Mapping[str, Any]) -> str:
|
||||
"""Extract the LLM output text from NodeRunResult.outputs."""
|
||||
if "text" in outputs:
|
||||
return str(outputs["text"])
|
||||
if "answer" in outputs:
|
||||
return str(outputs["answer"])
|
||||
values = list(outputs.values())
|
||||
return str(values[0]) if values else ""
|
||||
@ -1,61 +0,0 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RetrievalEvaluationRunner(BaseEvaluationRunner):
|
||||
"""Runner for retrieval evaluation: performs knowledge base retrieval, then evaluates."""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance):
|
||||
super().__init__(evaluation_instance)
|
||||
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
node_run_result_list: list[NodeRunResult],
|
||||
default_metric: DefaultMetric,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Compute retrieval evaluation metrics."""
|
||||
if not node_run_result_list:
|
||||
return []
|
||||
|
||||
merged_items = []
|
||||
for i, node_result in enumerate(node_run_result_list):
|
||||
outputs = node_result.outputs
|
||||
query = self._extract_query(dict(node_result.inputs))
|
||||
result_list = outputs.get("result", [])
|
||||
contexts = [item.get("content", "") for item in result_list if item.get("content")]
|
||||
output = "\n---\n".join(contexts)
|
||||
|
||||
merged_items.append(
|
||||
EvaluationItemInput(
|
||||
index=i,
|
||||
inputs={"query": query},
|
||||
output=output,
|
||||
context=contexts,
|
||||
)
|
||||
)
|
||||
|
||||
return self.evaluation_instance.evaluate_retrieval(
|
||||
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_query(inputs: dict[str, Any]) -> str:
|
||||
for key in ("query", "question", "input", "text"):
|
||||
if key in inputs:
|
||||
return str(inputs[key])
|
||||
values = list(inputs.values())
|
||||
return str(values[0]) if values else ""
|
||||
@ -1,68 +0,0 @@
|
||||
"""Runner for Snippet evaluation.
|
||||
|
||||
Snippets are essentially workflows, so we reuse ``evaluate_workflow`` from
|
||||
the evaluation instance for metric computation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SnippetEvaluationRunner(BaseEvaluationRunner):
|
||||
"""Runner for snippet evaluation: evaluates a published Snippet workflow."""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance):
|
||||
super().__init__(evaluation_instance)
|
||||
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
node_run_result_list: list[NodeRunResult],
|
||||
default_metric: DefaultMetric,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Compute evaluation metrics for snippet outputs."""
|
||||
if not node_run_result_list:
|
||||
return []
|
||||
merged_items = self._merge_results_into_items(node_run_result_list)
|
||||
return self.evaluation_instance.evaluate_workflow(
|
||||
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
|
||||
"""Create EvaluationItemInput list from NodeRunResult for snippet evaluation."""
|
||||
merged = []
|
||||
for i, item in enumerate(items):
|
||||
output = _extract_snippet_output(item.outputs)
|
||||
merged.append(
|
||||
EvaluationItemInput(
|
||||
index=i,
|
||||
inputs=dict(item.inputs),
|
||||
output=output,
|
||||
)
|
||||
)
|
||||
return merged
|
||||
|
||||
|
||||
def _extract_snippet_output(outputs: Mapping[str, Any]) -> str:
|
||||
"""Extract the primary output text from snippet NodeRunResult.outputs."""
|
||||
if "answer" in outputs:
|
||||
return str(outputs["answer"])
|
||||
if "text" in outputs:
|
||||
return str(outputs["text"])
|
||||
values = list(outputs.values())
|
||||
return str(values[0]) if values else ""
|
||||
@ -1,62 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
EvaluationItemInput,
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowEvaluationRunner(BaseEvaluationRunner):
|
||||
"""Runner for workflow evaluation: executes workflow App in non-streaming mode."""
|
||||
|
||||
def __init__(self, evaluation_instance: BaseEvaluationInstance):
|
||||
super().__init__(evaluation_instance)
|
||||
|
||||
def evaluate_metrics(
|
||||
self,
|
||||
node_run_result_list: list[NodeRunResult],
|
||||
default_metric: DefaultMetric,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
tenant_id: str,
|
||||
) -> list[EvaluationItemResult]:
|
||||
"""Compute workflow evaluation metrics (end-to-end)."""
|
||||
if not node_run_result_list:
|
||||
return []
|
||||
merged_items = self._merge_results_into_items(node_run_result_list)
|
||||
return self.evaluation_instance.evaluate_workflow(
|
||||
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
|
||||
"""Create EvaluationItemInput list from NodeRunResult for workflow evaluation."""
|
||||
merged = []
|
||||
for i, item in enumerate(items):
|
||||
output = _extract_workflow_output(item.outputs)
|
||||
merged.append(
|
||||
EvaluationItemInput(
|
||||
index=i,
|
||||
inputs=dict(item.inputs),
|
||||
output=output,
|
||||
)
|
||||
)
|
||||
return merged
|
||||
|
||||
|
||||
def _extract_workflow_output(outputs: Mapping[str, Any]) -> str:
|
||||
"""Extract the primary output text from workflow NodeRunResult.outputs."""
|
||||
if "answer" in outputs:
|
||||
return str(outputs["answer"])
|
||||
if "text" in outputs:
|
||||
return str(outputs["text"])
|
||||
values = list(outputs.values())
|
||||
return str(values[0]) if values else ""
|
||||
@ -102,7 +102,7 @@ class TemplateTransformer(ABC):
|
||||
|
||||
@classmethod
|
||||
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
|
||||
inputs_json_str = dumps_with_segments(inputs).encode()
|
||||
inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
|
||||
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
|
||||
return input_base64_encoded
|
||||
|
||||
|
||||
@ -2,13 +2,14 @@ import logging
|
||||
import secrets
|
||||
from typing import cast
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
|
||||
from extensions.ext_hosting_provider import hosting_configuration
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
|
||||
from models.provider import ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -12,7 +12,6 @@ from pydantic import TypeAdapter, ValidationError
|
||||
from configs import dify_config
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from core.tools.errors import ToolSSRFError
|
||||
from graphon.http.response import HttpResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -268,47 +267,4 @@ class SSRFProxy:
|
||||
return patch(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def _to_graphon_http_response(response: httpx.Response) -> HttpResponse:
|
||||
"""Convert an ``httpx`` response into Graphon's transport-agnostic wrapper."""
|
||||
return HttpResponse(
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
content=response.content,
|
||||
url=str(response.url) if response.url else None,
|
||||
reason_phrase=response.reason_phrase,
|
||||
fallback_text=response.text,
|
||||
)
|
||||
|
||||
|
||||
class GraphonSSRFProxy:
|
||||
"""Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``."""
|
||||
|
||||
@property
|
||||
def max_retries_exceeded_error(self) -> type[Exception]:
|
||||
return max_retries_exceeded_error
|
||||
|
||||
@property
|
||||
def request_error(self) -> type[Exception]:
|
||||
return request_error
|
||||
|
||||
def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
|
||||
ssrf_proxy = SSRFProxy()
|
||||
graphon_ssrf_proxy = GraphonSSRFProxy()
|
||||
|
||||
@ -9,6 +9,7 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||
|
||||
@ -34,7 +35,6 @@ from core.tools.utils.web_reader_tool import get_image_upload_file_ids
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
|
||||
@ -5,6 +5,11 @@ from collections.abc import Sequence
|
||||
from typing import Any, Protocol, TypedDict, cast
|
||||
|
||||
import json_repair
|
||||
from graphon.enums import WorkflowNodeExecutionMetadataKey
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
@ -30,11 +35,6 @@ from core.ops.utils import measure_time
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from graphon.enums import WorkflowNodeExecutionMetadataKey
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from models import App, Message, WorkflowNodeExecutionModel
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
@ -1,6 +1,20 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
|
||||
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
|
||||
from typing import IO, Any, Literal, Optional, Union, cast, overload
|
||||
|
||||
from graphon.model_runtime.callbacks.base_callback import Callback
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities import PluginCredentialType
|
||||
@ -11,24 +25,9 @@ from core.errors.error import ProviderTokenNotInitError
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.provider_manager import ProviderManager
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.model_runtime.callbacks.base_callback import Callback
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
|
||||
from models.provider import ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class ModelInstance:
|
||||
@ -170,7 +169,7 @@ class ModelInstance:
|
||||
return cast(
|
||||
Union[LLMResult, Generator],
|
||||
self._round_robin_invoke(
|
||||
self.model_type_instance.invoke,
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=list(prompt_messages),
|
||||
@ -195,7 +194,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
||||
raise Exception("Model type instance is not LargeLanguageModel")
|
||||
return self._round_robin_invoke(
|
||||
self.model_type_instance.get_num_tokens,
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=list(prompt_messages),
|
||||
@ -215,7 +214,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return self._round_robin_invoke(
|
||||
self.model_type_instance.invoke,
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
@ -237,7 +236,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return self._round_robin_invoke(
|
||||
self.model_type_instance.invoke,
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
multimodel_documents=multimodel_documents,
|
||||
@ -254,7 +253,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return self._round_robin_invoke(
|
||||
self.model_type_instance.get_num_tokens,
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
@ -279,7 +278,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
return self._round_robin_invoke(
|
||||
self.model_type_instance.invoke,
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
@ -307,7 +306,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
return self._round_robin_invoke(
|
||||
self.model_type_instance.invoke_multimodal_rerank,
|
||||
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
@ -326,7 +325,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, ModerationModel):
|
||||
raise Exception("Model type instance is not ModerationModel")
|
||||
return self._round_robin_invoke(
|
||||
self.model_type_instance.invoke,
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
@ -342,7 +341,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, Speech2TextModel):
|
||||
raise Exception("Model type instance is not Speech2TextModel")
|
||||
return self._round_robin_invoke(
|
||||
self.model_type_instance.invoke,
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
@ -359,14 +358,14 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
raise Exception("Model type instance is not TTSModel")
|
||||
return self._round_robin_invoke(
|
||||
self.model_type_instance.invoke,
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""
|
||||
Round-robin invoke
|
||||
:param function: function to invoke
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
class OpenAIModeration(Moderation):
|
||||
|
||||
@ -6,20 +6,7 @@ from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from opentelemetry.trace import SpanKind
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_trace_aliyun.config import AliyunConfig
|
||||
from dify_trace_aliyun.data_exporter.traceclient import (
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
TraceClient,
|
||||
build_endpoint,
|
||||
convert_datetime_to_nanoseconds,
|
||||
@ -27,8 +14,8 @@ from dify_trace_aliyun.data_exporter.traceclient import (
|
||||
convert_to_trace_id,
|
||||
generate_span_id,
|
||||
)
|
||||
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||
from dify_trace_aliyun.entities.semconv import (
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
DIFY_APP_ID,
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_INPUT_MESSAGE,
|
||||
@ -47,7 +34,7 @@ from dify_trace_aliyun.entities.semconv import (
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from dify_trace_aliyun.utils import (
|
||||
from core.ops.aliyun_trace.utils import (
|
||||
create_common_span_attributes,
|
||||
create_links_from_trace_id,
|
||||
create_status_from_error,
|
||||
@ -59,6 +46,19 @@ from dify_trace_aliyun.utils import (
|
||||
get_workflow_node_status,
|
||||
serialize_json_data,
|
||||
)
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import AliyunConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from models import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
@ -26,8 +26,8 @@ from opentelemetry.semconv.attributes import service_attributes
|
||||
from opentelemetry.trace import Link, SpanContext, TraceFlags
|
||||
|
||||
from configs import dify_config
|
||||
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData
|
||||
from dify_trace_aliyun.entities.semconv import ACS_ARMS_SERVICE_FEATURE
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE
|
||||
|
||||
INVALID_SPAN_ID: Final[int] = 0x0000000000000000
|
||||
INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000
|
||||
@ -6,8 +6,7 @@ from graphon.entities import WorkflowNodeExecution
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from opentelemetry.trace import Link, Status, StatusCode
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from dify_trace_aliyun.entities.semconv import (
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_SESSION_ID,
|
||||
GEN_AI_SPAN_KIND,
|
||||
@ -16,6 +15,7 @@ from dify_trace_aliyun.entities.semconv import (
|
||||
OUTPUT_VALUE,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser
|
||||
|
||||
@ -48,7 +48,7 @@ def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
|
||||
|
||||
|
||||
def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
|
||||
from dify_trace_aliyun.data_exporter.traceclient import create_link
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import create_link
|
||||
|
||||
links = []
|
||||
if trace_id:
|
||||
@ -26,6 +26,7 @@ from opentelemetry.util.types import AttributeValue
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
@ -39,7 +40,6 @@ from core.ops.entities.trace_entity import (
|
||||
)
|
||||
from core.ops.utils import JSON_DICT_ADAPTER
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
@ -1,8 +1,8 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ValidationInfo, field_validator
|
||||
|
||||
from core.ops.utils import validate_project_name, validate_url
|
||||
from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path
|
||||
|
||||
|
||||
class TracingProviderEnum(StrEnum):
|
||||
@ -52,5 +52,220 @@ class BaseTracingConfig(BaseModel):
|
||||
return validate_project_name(v, default_name)
|
||||
|
||||
|
||||
class ArizeConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Arize tracing config.
|
||||
"""
|
||||
|
||||
api_key: str | None = None
|
||||
space_id: str | None = None
|
||||
project: str | None = None
|
||||
endpoint: str = "https://otlp.arize.com"
|
||||
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "default")
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://otlp.arize.com")
|
||||
|
||||
|
||||
class PhoenixConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Phoenix tracing config.
|
||||
"""
|
||||
|
||||
api_key: str | None = None
|
||||
project: str | None = None
|
||||
endpoint: str = "https://app.phoenix.arize.com"
|
||||
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "default")
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return validate_url_with_path(v, "https://app.phoenix.arize.com")
|
||||
|
||||
|
||||
class LangfuseConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Langfuse tracing config.
|
||||
"""
|
||||
|
||||
public_key: str
|
||||
secret_key: str
|
||||
host: str = "https://api.langfuse.com"
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def host_validator(cls, v, info: ValidationInfo):
|
||||
return validate_url_with_path(v, "https://api.langfuse.com")
|
||||
|
||||
|
||||
class LangSmithConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Langsmith tracing config.
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
project: str
|
||||
endpoint: str = "https://api.smith.langchain.com"
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
# LangSmith only allows HTTPS
|
||||
return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
|
||||
|
||||
|
||||
class OpikConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Opik tracing config.
|
||||
"""
|
||||
|
||||
api_key: str | None = None
|
||||
project: str | None = None
|
||||
workspace: str | None = None
|
||||
url: str = "https://www.comet.com/opik/api/"
|
||||
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "Default Project")
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def url_validator(cls, v, info: ValidationInfo):
|
||||
return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
|
||||
|
||||
|
||||
class WeaveConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Weave tracing config.
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
entity: str | None = None
|
||||
project: str
|
||||
endpoint: str = "https://trace.wandb.ai"
|
||||
host: str | None = None
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
# Weave only allows HTTPS for endpoint
|
||||
return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def host_validator(cls, v, info: ValidationInfo):
|
||||
if v is not None and v.strip() != "":
|
||||
return validate_url(v, v, allowed_schemes=("https", "http"))
|
||||
return v
|
||||
|
||||
|
||||
class AliyunConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Aliyun tracing config.
|
||||
"""
|
||||
|
||||
app_name: str = "dify_app"
|
||||
license_key: str
|
||||
endpoint: str
|
||||
|
||||
@field_validator("app_name")
|
||||
@classmethod
|
||||
def app_name_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "dify_app")
|
||||
|
||||
@field_validator("license_key")
|
||||
@classmethod
|
||||
def license_key_validator(cls, v, info: ValidationInfo):
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("License key cannot be empty")
|
||||
return v
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
# aliyun uses two URL formats, which may include a URL path
|
||||
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
|
||||
class TencentConfig(BaseTracingConfig):
|
||||
"""
|
||||
Tencent APM tracing config
|
||||
"""
|
||||
|
||||
token: str
|
||||
endpoint: str
|
||||
service_name: str
|
||||
|
||||
@field_validator("token")
|
||||
@classmethod
|
||||
def token_validator(cls, v, info: ValidationInfo):
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("Token cannot be empty")
|
||||
return v
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
|
||||
|
||||
@field_validator("service_name")
|
||||
@classmethod
|
||||
def service_name_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "dify_app")
|
||||
|
||||
|
||||
class MLflowConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for MLflow tracing config.
|
||||
"""
|
||||
|
||||
tracking_uri: str = "http://localhost:5000"
|
||||
experiment_id: str = "0" # Default experiment id in MLflow is 0
|
||||
username: str | None = None
|
||||
password: str | None = None
|
||||
|
||||
@field_validator("tracking_uri")
|
||||
@classmethod
|
||||
def tracking_uri_validator(cls, v, info: ValidationInfo):
|
||||
if isinstance(v, str) and v.startswith("databricks"):
|
||||
raise ValueError(
|
||||
"Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances."
|
||||
)
|
||||
return validate_url_with_path(v, "http://localhost:5000")
|
||||
|
||||
@field_validator("experiment_id")
|
||||
@classmethod
|
||||
def experiment_id_validator(cls, v, info: ValidationInfo):
|
||||
return validate_integer_id(v)
|
||||
|
||||
|
||||
class DatabricksConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Databricks (Databricks-managed MLflow) tracing config.
|
||||
"""
|
||||
|
||||
experiment_id: str
|
||||
host: str
|
||||
client_id: str | None = None
|
||||
client_secret: str | None = None
|
||||
personal_access_token: str | None = None
|
||||
|
||||
@field_validator("experiment_id")
|
||||
@classmethod
|
||||
def experiment_id_validator(cls, v, info: ValidationInfo):
|
||||
return validate_integer_id(v)
|
||||
|
||||
|
||||
OPS_FILE_PATH = "ops_trace/"
|
||||
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
|
||||
|
||||
@ -3,6 +3,7 @@ import os
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from langfuse import Langfuse
|
||||
from langfuse.api import (
|
||||
CreateGenerationBody,
|
||||
@ -16,6 +17,7 @@ from langfuse.api.commons.types.usage import Usage
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
@ -27,10 +29,7 @@ from core.ops.entities.trace_entity import (
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import filter_none_values
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_trace_langfuse.config import LangfuseConfig
|
||||
from dify_trace_langfuse.entities.langfuse_trace_entity import (
|
||||
from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
||||
GenerationUsage,
|
||||
LangfuseGeneration,
|
||||
LangfuseSpan,
|
||||
@ -38,8 +37,9 @@ from dify_trace_langfuse.entities.langfuse_trace_entity import (
|
||||
LevelEnum,
|
||||
UnitEnum,
|
||||
)
|
||||
from core.ops.utils import filter_none_values
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import MessageStatus
|
||||
|
||||
@ -10,6 +10,7 @@ from langsmith.schemas import RunBase
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import LangSmithConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
@ -21,14 +22,13 @@ from core.ops.entities.trace_entity import (
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_trace_langsmith.config import LangSmithConfig
|
||||
from dify_trace_langsmith.entities.langsmith_trace_entity import (
|
||||
from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
||||
LangSmithRunModel,
|
||||
LangSmithRunType,
|
||||
LangSmithRunUpdateModel,
|
||||
)
|
||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
@ -12,6 +12,7 @@ from mlflow.tracing.provider import detach_span_from_context, set_span_in_contex
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
@ -24,7 +25,6 @@ from core.ops.entities.trace_entity import (
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import JSON_DICT_ADAPTER
|
||||
from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user