mirror of
https://github.com/langgenius/dify.git
synced 2026-02-08 12:35:27 +08:00
Compare commits
116 Commits
fix/20326-
...
1.4.2
| Author | SHA1 | Date | |
|---|---|---|---|
| acb2488fc8 | |||
| d6d8cca053 | |||
| f601093ccc | |||
| 0f3d4d0b6e | |||
| 60777bc610 | |||
| 21a50e22d2 | |||
| fc6e2d14a5 | |||
| c439e82038 | |||
| a97ff587d2 | |||
| 91144207e0 | |||
| 0720bc7408 | |||
| ab62a9662c | |||
| d6a8af03b4 | |||
| 65c7c01d90 | |||
| e6e76852d5 | |||
| 930c4cb609 | |||
| 0c8447fd0e | |||
| 37c3283450 | |||
| 723b69cf8d | |||
| 85859b6723 | |||
| c1a13fa553 | |||
| 4f0c9fdf2b | |||
| 4271602cfc | |||
| 4f14d7c0ca | |||
| 38554c5f3e | |||
| 138ad6e8b3 | |||
| f76f70f0b6 | |||
| 7094680e23 | |||
| 59dc7c880e | |||
| 3fb9b41fe5 | |||
| 0ccf8cb23e | |||
| 837f769960 | |||
| 3367d4258d | |||
| d608be6e7f | |||
| de9c7f2ea4 | |||
| 1fbbbb735d | |||
| 9915a70d7f | |||
| 822298f69d | |||
| ad2f25875e | |||
| ad8e79c440 | |||
| f2dcfc976d | |||
| 5ccfb1f4ba | |||
| 92614765ff | |||
| 4f066454d0 | |||
| 7ae5819c67 | |||
| 2b0f3edcef | |||
| 244687c9a7 | |||
| d22c351221 | |||
| 006496f24e | |||
| 01d500db14 | |||
| 4ac3600f81 | |||
| 6aba223383 | |||
| f1c19cda74 | |||
| 275e86a26c | |||
| 077d627953 | |||
| ca0b268ae5 | |||
| 25be7c1ad5 | |||
| 888cd86afd | |||
| 157d916154 | |||
| e40e9db39a | |||
| 36f1b4b222 | |||
| 257bf13fef | |||
| 957f5b212e | |||
| 72fdafc180 | |||
| db83bfc53a | |||
| 744159a079 | |||
| d6b30efe2c | |||
| 3f7aa38d77 | |||
| a145c2a8fe | |||
| c29cb503be | |||
| 8025ad0661 | |||
| b4b59148dc | |||
| 23c9f1b444 | |||
| b33f8b47ca | |||
| c26e1929d6 | |||
| e01d975b80 | |||
| 92528360f9 | |||
| 1d9c90089c | |||
| e303417e04 | |||
| c8d9f8e2e4 | |||
| 51f64797cd | |||
| 1ea4459d9f | |||
| 55371e5abf | |||
| fb12a3033d | |||
| a6ea15e63c | |||
| 5a991295e0 | |||
| 9b47f9f786 | |||
| f65c2fcb1d | |||
| 156bb8238d | |||
| db488bef51 | |||
| d72d02b970 | |||
| dd2725be68 | |||
| 8e2d342de6 | |||
| 91eeb2ab76 | |||
| f2e0d161a1 | |||
| 2ebf4e767b | |||
| f7fb10635f | |||
| 32e779eef3 | |||
| 482e50aae9 | |||
| cd0a05f114 | |||
| d4408e0f54 | |||
| eee88a8012 | |||
| 0368e1769a | |||
| 2d4f8f1377 | |||
| 8ef91222ea | |||
| 808aa4467c | |||
| b2ab401279 | |||
| 9bbd646f40 | |||
| 57ece83c30 | |||
| c3c67d9608 | |||
| f59fb94dae | |||
| 00199c41bb | |||
| 400ae664bb | |||
| b39ca7ee31 | |||
| 4250501058 | |||
| eaaf551497 |
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
npm add -g pnpm@10.8.0
|
||||
npm add -g pnpm@10.11.1
|
||||
cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
|
||||
2
.github/actions/setup-uv/action.yml
vendored
2
.github/actions/setup-uv/action.yml
vendored
@ -8,7 +8,7 @@ inputs:
|
||||
uv-version:
|
||||
description: UV version to set up
|
||||
required: true
|
||||
default: '0.6.14'
|
||||
default: '~=0.7.11'
|
||||
uv-lockfile:
|
||||
description: Path to the UV lockfile to restore cache from
|
||||
required: true
|
||||
|
||||
20
.github/pull_request_template.md
vendored
20
.github/pull_request_template.md
vendored
@ -1,25 +1,23 @@
|
||||
# Summary
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
> 1. Make sure you have read our [contribution guidelines](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)
|
||||
> 2. Ensure there is an associated issue and you have been assigned to it
|
||||
> 3. Use the correct syntax to link this PR: `Fixes #<issue number>`.
|
||||
|
||||
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
|
||||
## Summary
|
||||
|
||||
> [!Tip]
|
||||
> Close issue syntax: `Fixes #<issue number>` or `Resolves #<issue number>`, see [documentation](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword) for more details.
|
||||
<!-- Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. -->
|
||||
|
||||
|
||||
# Screenshots
|
||||
## Screenshots
|
||||
|
||||
| Before | After |
|
||||
|--------|-------|
|
||||
| ... | ... |
|
||||
|
||||
# Checklist
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Please review the checklist below before submitting your pull request.
|
||||
## Checklist
|
||||
|
||||
- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
|
||||
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
|
||||
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
|
||||
- [x] I've updated the documentation accordingly.
|
||||
- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
|
||||
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@ -192,12 +192,12 @@ sdks/python-client/dist
|
||||
sdks/python-client/dify_client.egg-info
|
||||
|
||||
.vscode/*
|
||||
!.vscode/launch.json
|
||||
!.vscode/launch.json.template
|
||||
!.vscode/README.md
|
||||
pyrightconfig.json
|
||||
api/.vscode
|
||||
|
||||
.idea/
|
||||
.vscode
|
||||
|
||||
# pnpm
|
||||
/.pnpm-store
|
||||
@ -207,3 +207,6 @@ plugins.jsonl
|
||||
|
||||
# mise
|
||||
mise.toml
|
||||
|
||||
# Next.js build output
|
||||
.next/
|
||||
|
||||
14
.vscode/README.md
vendored
Normal file
14
.vscode/README.md
vendored
Normal file
@ -0,0 +1,14 @@
|
||||
# Debugging with VS Code
|
||||
|
||||
This `launch.json.template` file provides various debug configurations for the Dify project within VS Code / Cursor. To use these configurations, you should copy the contents of this file into a new file named `launch.json` in the same `.vscode` directory.
|
||||
|
||||
## How to Use
|
||||
|
||||
1. **Create `launch.json`**: If you don't have one, create a file named `launch.json` inside the `.vscode` directory.
|
||||
2. **Copy Content**: Copy the entire content from `launch.json.template` into your newly created `launch.json` file.
|
||||
3. **Select Debug Configuration**: Go to the Run and Debug view in VS Code / Cursor (Ctrl+Shift+D or Cmd+Shift+D).
|
||||
4. **Start Debugging**: Select the desired configuration from the dropdown menu and click the green play button.
|
||||
|
||||
## Tips
|
||||
|
||||
- If you need to debug with Edge browser instead of Chrome, modify the `serverReadyAction` configuration in the "Next.js: debug full stack" section, change `"debugWithChrome"` to `"debugWithEdge"` to use Microsoft Edge for debugging.
|
||||
68
.vscode/launch.json.template
vendored
Normal file
68
.vscode/launch.json.template
vendored
Normal file
@ -0,0 +1,68 @@
|
||||
{
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: Flask API",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "flask",
|
||||
"env": {
|
||||
"FLASK_APP": "app.py",
|
||||
"FLASK_ENV": "development",
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
"args": [
|
||||
"run",
|
||||
"--host=0.0.0.0",
|
||||
"--port=5001",
|
||||
"--no-debugger",
|
||||
"--no-reload"
|
||||
],
|
||||
"jinja": true,
|
||||
"justMyCode": true,
|
||||
"cwd": "${workspaceFolder}/api",
|
||||
"python": "${workspaceFolder}/api/.venv/bin/python"
|
||||
},
|
||||
{
|
||||
"name": "Python: Celery Worker (Solo)",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"env": {
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"app.celery",
|
||||
"worker",
|
||||
"-P",
|
||||
"solo",
|
||||
"-c",
|
||||
"1",
|
||||
"-Q",
|
||||
"dataset,generation,mail,ops_trace",
|
||||
"--loglevel",
|
||||
"INFO"
|
||||
],
|
||||
"justMyCode": false,
|
||||
"cwd": "${workspaceFolder}/api",
|
||||
"python": "${workspaceFolder}/api/.venv/bin/python"
|
||||
},
|
||||
{
|
||||
"name": "Next.js: debug full stack",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/web/node_modules/next/dist/bin/next",
|
||||
"runtimeArgs": ["--inspect"],
|
||||
"skipFiles": ["<node_internals>/**"],
|
||||
"serverReadyAction": {
|
||||
"action": "debugWithChrome",
|
||||
"killOnServerStop": true,
|
||||
"pattern": "- Local:.+(https?://.+)",
|
||||
"uriFormat": "%s",
|
||||
"webRoot": "${workspaceFolder}/web"
|
||||
},
|
||||
"cwd": "${workspaceFolder}/web"
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -491,3 +491,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000
|
||||
|
||||
# Prevent Clickjacking
|
||||
ALLOW_EMBED=false
|
||||
|
||||
# Dataset queue monitor configuration
|
||||
QUEUE_MONITOR_THRESHOLD=200
|
||||
# You can configure multiple ones, separated by commas. eg: test1@dify.ai,test2@dify.ai
|
||||
QUEUE_MONITOR_ALERT_EMAILS=
|
||||
# Monitor interval in minutes, default is 30 minutes
|
||||
QUEUE_MONITOR_INTERVAL=30
|
||||
|
||||
@ -43,6 +43,7 @@ select = [
|
||||
"S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval`
|
||||
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
|
||||
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
|
||||
"S311", # suspicious-non-cryptographic-random-usage
|
||||
]
|
||||
|
||||
ignore = [
|
||||
|
||||
@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base
|
||||
WORKDIR /app/api
|
||||
|
||||
# Install uv
|
||||
ENV UV_VERSION=0.6.14
|
||||
ENV UV_VERSION=0.7.11
|
||||
|
||||
RUN pip install --no-cache-dir uv==${UV_VERSION}
|
||||
|
||||
|
||||
@ -846,6 +846,9 @@ def clear_orphaned_file_records(force: bool):
|
||||
{"type": "text", "table": "workflow_node_executions", "column": "outputs"},
|
||||
{"type": "text", "table": "conversations", "column": "introduction"},
|
||||
{"type": "text", "table": "conversations", "column": "system_instruction"},
|
||||
{"type": "text", "table": "accounts", "column": "avatar"},
|
||||
{"type": "text", "table": "apps", "column": "icon"},
|
||||
{"type": "text", "table": "sites", "column": "icon"},
|
||||
{"type": "json", "table": "messages", "column": "inputs"},
|
||||
{"type": "json", "table": "messages", "column": "message"},
|
||||
]
|
||||
|
||||
@ -2,7 +2,7 @@ import os
|
||||
from typing import Any, Literal, Optional
|
||||
from urllib.parse import parse_qsl, quote_plus
|
||||
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
|
||||
from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from .cache.redis_config import RedisConfig
|
||||
@ -256,6 +256,25 @@ class InternalTestConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class DatasetQueueMonitorConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for Dataset Queue Monitor
|
||||
"""
|
||||
|
||||
QUEUE_MONITOR_THRESHOLD: Optional[NonNegativeInt] = Field(
|
||||
description="Threshold for dataset queue monitor",
|
||||
default=200,
|
||||
)
|
||||
QUEUE_MONITOR_ALERT_EMAILS: Optional[str] = Field(
|
||||
description="Emails for dataset queue monitor alert, separated by commas",
|
||||
default=None,
|
||||
)
|
||||
QUEUE_MONITOR_INTERVAL: Optional[NonNegativeFloat] = Field(
|
||||
description="Interval for dataset queue monitor in minutes",
|
||||
default=30,
|
||||
)
|
||||
|
||||
|
||||
class MiddlewareConfig(
|
||||
# place the configs in alphabet order
|
||||
CeleryConfig,
|
||||
@ -303,5 +322,6 @@ class MiddlewareConfig(
|
||||
BaiduVectorDBConfig,
|
||||
OpenGaussConfig,
|
||||
TableStoreConfig,
|
||||
DatasetQueueMonitorConfig,
|
||||
):
|
||||
pass
|
||||
|
||||
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="1.4.1",
|
||||
default="1.4.2",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
||||
@ -60,8 +60,7 @@ class NacosHttpClient:
|
||||
sign_str = tenant + "+"
|
||||
if group:
|
||||
sign_str = sign_str + group + "+"
|
||||
if sign_str:
|
||||
sign_str += ts
|
||||
sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it.
|
||||
return sign_str
|
||||
|
||||
def get_access_token(self, force_refresh=False):
|
||||
|
||||
@ -208,7 +208,7 @@ class AnnotationBatchImportApi(Resource):
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
# check file type
|
||||
if not file.filename.endswith(".csv"):
|
||||
if not file.filename or not file.filename.endswith(".csv"):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
return AppAnnotationService.batch_import_app_annotations(app_id, file)
|
||||
|
||||
|
||||
@ -6,12 +6,12 @@ from sqlalchemy.orm import Session
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRunStatus
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ class WorkflowAppLogApi(Resource):
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowRunStatus(args.status) if args.status else None
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
|
||||
@ -119,9 +119,6 @@ class ForgotPasswordResetApi(Resource):
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
if reset_data.get("phase", "") != "reset":
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
if reset_data.get("phase", "") != "reset":
|
||||
raise InvalidTokenError()
|
||||
|
||||
|
||||
@ -374,7 +374,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
# check file type
|
||||
if not file.filename.endswith(".csv"):
|
||||
if not file.filename or not file.filename.endswith(".csv"):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
|
||||
try:
|
||||
|
||||
@ -59,7 +59,14 @@ class InstalledAppsListApi(Resource):
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
user_id = current_user.id
|
||||
res = []
|
||||
app_ids = [installed_app["app"].id for installed_app in installed_app_list]
|
||||
webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids)
|
||||
for installed_app in installed_app_list:
|
||||
webapp_setting = webapp_settings.get(installed_app["app"].id)
|
||||
if not webapp_setting:
|
||||
continue
|
||||
if webapp_setting.access_mode == "sso_verified":
|
||||
continue
|
||||
app_code = AppService.get_app_code_by_id(str(installed_app["app"].id))
|
||||
if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=user_id,
|
||||
|
||||
@ -44,6 +44,17 @@ def only_edition_cloud(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_enterprise(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
abort(404)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_self_hosted(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
|
||||
@ -29,7 +29,7 @@ from core.plugin.entities.request import (
|
||||
RequestRequestUploadFile,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from libs.helper import compact_generate_response
|
||||
from libs.helper import length_prefixed_response
|
||||
from models.account import Account, Tenant
|
||||
from models.model import EndUser
|
||||
|
||||
@ -44,7 +44,7 @@ class PluginInvokeLLMApi(Resource):
|
||||
response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
|
||||
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
|
||||
|
||||
return compact_generate_response(generator())
|
||||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
class PluginInvokeTextEmbeddingApi(Resource):
|
||||
@ -101,7 +101,7 @@ class PluginInvokeTTSApi(Resource):
|
||||
)
|
||||
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
|
||||
|
||||
return compact_generate_response(generator())
|
||||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
class PluginInvokeSpeech2TextApi(Resource):
|
||||
@ -162,7 +162,7 @@ class PluginInvokeToolApi(Resource):
|
||||
),
|
||||
)
|
||||
|
||||
return compact_generate_response(generator())
|
||||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
class PluginInvokeParameterExtractorNodeApi(Resource):
|
||||
@ -228,7 +228,7 @@ class PluginInvokeAppApi(Resource):
|
||||
files=payload.files,
|
||||
)
|
||||
|
||||
return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response))
|
||||
return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response))
|
||||
|
||||
|
||||
class PluginInvokeEncryptApi(Resource):
|
||||
|
||||
@ -2,12 +2,14 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from flask import request
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from flask_restful import reqparse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import _get_user
|
||||
from models.account import Account, Tenant
|
||||
from models.model import EndUser
|
||||
from services.account_service import AccountService
|
||||
@ -30,6 +32,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
|
||||
)
|
||||
session.add(user_model)
|
||||
session.commit()
|
||||
session.refresh(user_model)
|
||||
else:
|
||||
user_model = AccountService.load_user(user_id)
|
||||
if not user_model:
|
||||
@ -80,7 +83,12 @@ def get_user_tenant(view: Optional[Callable] = None):
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
kwargs["tenant_model"] = tenant_model
|
||||
kwargs["user_model"] = get_user(tenant_id, user_id)
|
||||
|
||||
user = get_user(tenant_id, user_id)
|
||||
kwargs["user_model"] = user
|
||||
|
||||
current_app.login_manager._update_request_context_with_user(user) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
|
||||
@ -9,13 +9,13 @@ from fields.annotation_fields import (
|
||||
annotation_fields,
|
||||
)
|
||||
from libs.login import current_user
|
||||
from models.model import App, EndUser
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@validate_app_token
|
||||
def post(self, app_model: App, end_user: EndUser, action):
|
||||
def post(self, app_model: App, action):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||
@ -32,7 +32,7 @@ class AnnotationReplyActionApi(Resource):
|
||||
|
||||
class AnnotationReplyActionStatusApi(Resource):
|
||||
@validate_app_token
|
||||
def get(self, app_model: App, end_user: EndUser, job_id, action):
|
||||
def get(self, app_model: App, job_id, action):
|
||||
job_id = str(job_id)
|
||||
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
|
||||
cache_result = redis_client.get(app_annotation_job_key)
|
||||
@ -50,7 +50,7 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
|
||||
class AnnotationListApi(Resource):
|
||||
@validate_app_token
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
def get(self, app_model: App):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
keyword = request.args.get("keyword", default="", type=str)
|
||||
@ -67,7 +67,7 @@ class AnnotationListApi(Resource):
|
||||
|
||||
@validate_app_token
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
def post(self, app_model: App):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
@ -79,7 +79,7 @@ class AnnotationListApi(Resource):
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@validate_app_token
|
||||
@marshal_with(annotation_fields)
|
||||
def put(self, app_model: App, end_user: EndUser, annotation_id):
|
||||
def put(self, app_model: App, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -92,7 +92,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
return annotation
|
||||
|
||||
@validate_app_token
|
||||
def delete(self, app_model: App, end_user: EndUser, annotation_id):
|
||||
def delete(self, app_model: App, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
@ -24,12 +24,13 @@ from core.errors.error import (
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun, WorkflowRunStatus
|
||||
from models.workflow import WorkflowRun
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
@ -138,7 +139,7 @@ class WorkflowAppLogApi(Resource):
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowRunStatus(args.status) if args.status else None
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
|
||||
@ -1,19 +1,21 @@
|
||||
from flask import request
|
||||
from flask_restful import marshal, reqparse
|
||||
from flask_restful import marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from controllers.service_api.wraps import DatasetApiResource, validate_dataset_token
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import tag_fields
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
@ -320,5 +322,135 @@ class DatasetApi(DatasetApiResource):
|
||||
raise DatasetInUseError()
|
||||
|
||||
|
||||
class DatasetTagsApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
@marshal_with(tag_fields)
|
||||
def get(self, _, dataset_id):
|
||||
"""Get all knowledge type tags."""
|
||||
tags = TagService.get_tags("knowledge", current_user.current_tenant_id)
|
||||
|
||||
return tags, 200
|
||||
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
"""Add a knowledge type tag."""
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=DatasetTagsApi._validate_tag_name,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
tag = TagService.save_tags(args)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
|
||||
return response, 200
|
||||
|
||||
@validate_dataset_token
|
||||
def patch(self, _, dataset_id):
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=DatasetTagsApi._validate_tag_name,
|
||||
)
|
||||
parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
|
||||
args = parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
tag = TagService.update_tags(args, args.get("tag_id"))
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
|
||||
return response, 200
|
||||
|
||||
@validate_dataset_token
|
||||
def delete(self, _, dataset_id):
|
||||
"""Delete a knowledge type tag."""
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
|
||||
args = parser.parse_args()
|
||||
TagService.delete_tag(args.get("tag_id"))
|
||||
|
||||
return 204
|
||||
|
||||
@staticmethod
|
||||
def _validate_tag_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 50:
|
||||
raise ValueError("Name must be between 1 to 50 characters.")
|
||||
return name
|
||||
|
||||
|
||||
class DatasetTagBindingApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
|
||||
)
|
||||
parser.add_argument(
|
||||
"target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
TagService.save_tag_binding(args)
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
|
||||
parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
||||
|
||||
args = parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
TagService.delete_tag_binding(args)
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def get(self, _, *args, **kwargs):
|
||||
"""Get all knowledge type tags."""
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
|
||||
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
|
||||
response = {"data": tags_list, "total": len(tags)}
|
||||
return response, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetListApi, "/datasets")
|
||||
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
|
||||
api.add_resource(DatasetTagsApi, "/datasets/tags")
|
||||
api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding")
|
||||
api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding")
|
||||
api.add_resource(DatasetTagsBindingStatusApi, "/datasets/<uuid:dataset_id>/tags")
|
||||
|
||||
@ -175,8 +175,11 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
if not dataset.indexing_technique and not args.get("indexing_technique"):
|
||||
|
||||
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
|
||||
if not indexing_technique:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
args["indexing_technique"] = indexing_technique
|
||||
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
@ -206,12 +209,16 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None
|
||||
if not knowledge_config.original_document_id and not dataset_process_rule and not knowledge_config.process_rule:
|
||||
raise ValueError("process_rule is required.")
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
dataset_process_rule=dataset_process_rule,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
|
||||
@ -208,6 +208,28 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
)
|
||||
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
def get(self, tenant_id, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
class ChildChunkApi(DatasetApiResource):
|
||||
"""Resource for child chunks."""
|
||||
|
||||
@ -15,4 +15,17 @@ api.add_resource(FileApi, "/files/upload")
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
|
||||
|
||||
from . import app, audio, completion, conversation, feature, message, passport, saved_message, site, workflow
|
||||
from . import (
|
||||
app,
|
||||
audio,
|
||||
completion,
|
||||
conversation,
|
||||
feature,
|
||||
forgot_password,
|
||||
login,
|
||||
message,
|
||||
passport,
|
||||
saved_message,
|
||||
site,
|
||||
workflow,
|
||||
)
|
||||
|
||||
@ -10,6 +10,8 @@ from libs.passport import PassportService
|
||||
from models.model import App, AppMode
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
|
||||
class AppParameterApi(WebApiResource):
|
||||
@ -46,10 +48,22 @@ class AppMeta(WebApiResource):
|
||||
class AppAccessMode(Resource):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("appId", type=str, required=True, location="args")
|
||||
parser.add_argument("appId", type=str, required=False, location="args")
|
||||
parser.add_argument("appCode", type=str, required=False, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_id = args["appId"]
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.webapp_auth.enabled:
|
||||
return {"accessMode": "public"}
|
||||
|
||||
app_id = args.get("appId")
|
||||
if args.get("appCode"):
|
||||
app_code = args["appCode"]
|
||||
app_id = AppService.get_app_id_by_code(app_code)
|
||||
|
||||
if not app_id:
|
||||
raise ValueError("appId or appCode must be provided")
|
||||
|
||||
res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
|
||||
|
||||
return {"accessMode": res.access_mode}
|
||||
@ -75,6 +89,10 @@ class AppWebAuthPermission(Resource):
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.webapp_auth.enabled:
|
||||
return {"result": True}
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("appId", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
@ -82,7 +100,9 @@ class AppWebAuthPermission(Resource):
|
||||
app_id = args["appId"]
|
||||
app_code = AppService.get_app_code_by_id(app_id)
|
||||
|
||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
|
||||
res = True
|
||||
if WebAppAuthService.is_app_require_permission_check(app_id=app_id):
|
||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
|
||||
return {"result": res}
|
||||
|
||||
|
||||
|
||||
147
api/controllers/web/forgot_password.py
Normal file
147
api/controllers/web/forgot_password.py
Normal file
@ -0,0 +1,147 @@
|
||||
import base64
|
||||
import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
EmailPasswordResetLimitError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
|
||||
from controllers.web import api
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
@only_edition_enterprise
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
token = None
|
||||
if account is None:
|
||||
raise AccountNotFound()
|
||||
else:
|
||||
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
class ForgotPasswordCheckApi(Resource):
|
||||
@only_edition_enterprise
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
|
||||
if is_forgot_password_error_rate_limit:
|
||||
raise EmailPasswordResetLimitError()
|
||||
|
||||
token_data = AccountService.get_reset_password_data(args["token"])
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args["code"] != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(args["email"])
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_reset_password_token(args["token"])
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_reset_password_token(
|
||||
user_email, code=args["code"], additional_data={"phase": "reset"}
|
||||
)
|
||||
|
||||
AccountService.reset_forgot_password_error_rate_limit(args["email"])
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
class ForgotPasswordResetApi(Resource):
|
||||
@only_edition_enterprise
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate passwords match
|
||||
if args["new_password"] != args["password_confirm"]:
|
||||
raise PasswordMismatchError()
|
||||
|
||||
# Validate token and get reset data
|
||||
reset_data = AccountService.get_reset_password_data(args["token"])
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
if reset_data.get("phase", "") != "reset":
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Revoke token to prevent reuse
|
||||
AccountService.revoke_reset_password_token(args["token"])
|
||||
|
||||
# Generate secure salt and hash password
|
||||
salt = secrets.token_bytes(16)
|
||||
password_hashed = hash_password(args["new_password"], salt)
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
|
||||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt, session)
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
def _update_existing_account(self, account, password_hashed, salt, session):
|
||||
# Update existing account credentials
|
||||
account.password = base64.b64encode(password_hashed).decode()
|
||||
account.password_salt = base64.b64encode(salt).decode()
|
||||
session.commit()
|
||||
|
||||
|
||||
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
|
||||
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
|
||||
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
|
||||
@ -1,12 +1,11 @@
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from jwt import InvalidTokenError # type: ignore
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
import services
|
||||
from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError
|
||||
from controllers.console.error import AccountBannedError, AccountNotFound
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.console.wraps import only_edition_enterprise, setup_required
|
||||
from controllers.web import api
|
||||
from libs.helper import email
|
||||
from libs.password import valid_password
|
||||
from services.account_service import AccountService
|
||||
@ -16,6 +15,8 @@ from services.webapp_auth_service import WebAppAuthService
|
||||
class LoginApi(Resource):
|
||||
"""Resource for web app email/password login."""
|
||||
|
||||
@setup_required
|
||||
@only_edition_enterprise
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
parser = reqparse.RequestParser()
|
||||
@ -23,10 +24,6 @@ class LoginApi(Resource):
|
||||
parser.add_argument("password", type=valid_password, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_code = request.headers.get("X-App-Code")
|
||||
if app_code is None:
|
||||
raise BadRequest("X-App-Code header is missing.")
|
||||
|
||||
try:
|
||||
account = WebAppAuthService.authenticate(args["email"], args["password"])
|
||||
except services.errors.account.AccountLoginError:
|
||||
@ -36,12 +33,8 @@ class LoginApi(Resource):
|
||||
except services.errors.account.AccountNotFoundError:
|
||||
raise AccountNotFound()
|
||||
|
||||
WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)
|
||||
|
||||
end_user = WebAppAuthService.create_end_user(email=args["email"], app_code=app_code)
|
||||
|
||||
token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
|
||||
return {"result": "success", "token": token}
|
||||
token = WebAppAuthService.login(account=account)
|
||||
return {"result": "success", "data": {"access_token": token}}
|
||||
|
||||
|
||||
# class LogoutApi(Resource):
|
||||
@ -56,6 +49,7 @@ class LoginApi(Resource):
|
||||
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@only_edition_enterprise
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
@ -78,6 +72,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
|
||||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
@only_edition_enterprise
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
@ -86,9 +81,6 @@ class EmailCodeLoginApi(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
app_code = request.headers.get("X-App-Code")
|
||||
if app_code is None:
|
||||
raise BadRequest("X-App-Code header is missing.")
|
||||
|
||||
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
|
||||
if token_data is None:
|
||||
@ -105,16 +97,12 @@ class EmailCodeLoginApi(Resource):
|
||||
if not account:
|
||||
raise AccountNotFound()
|
||||
|
||||
WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)
|
||||
|
||||
end_user = WebAppAuthService.create_end_user(email=user_email, app_code=app_code)
|
||||
|
||||
token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
|
||||
token = WebAppAuthService.login(account=account)
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
return {"result": "success", "token": token}
|
||||
return {"result": "success", "data": {"access_token": token}}
|
||||
|
||||
|
||||
# api.add_resource(LoginApi, "/login")
|
||||
api.add_resource(LoginApi, "/login")
|
||||
# api.add_resource(LogoutApi, "/logout")
|
||||
# api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
||||
# api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
||||
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
||||
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web import api
|
||||
from controllers.web.error import WebAppAuthRequiredError
|
||||
from extensions.ext_database import db
|
||||
@ -11,6 +13,7 @@ from libs.passport import PassportService
|
||||
from models.model import App, EndUser, Site
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
|
||||
|
||||
|
||||
class PassportResource(Resource):
|
||||
@ -20,10 +23,19 @@ class PassportResource(Resource):
|
||||
system_features = FeatureService.get_system_features()
|
||||
app_code = request.headers.get("X-App-Code")
|
||||
user_id = request.args.get("user_id")
|
||||
web_app_access_token = request.args.get("web_app_access_token")
|
||||
|
||||
if app_code is None:
|
||||
raise Unauthorized("X-App-Code header is missing.")
|
||||
|
||||
# exchange token for enterprise logined web user
|
||||
enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token)
|
||||
if enterprise_user_decoded:
|
||||
# a web user has already logged in, exchange a token for this app without redirecting to the login page
|
||||
return exchange_token_for_existing_web_user(
|
||||
app_code=app_code, enterprise_user_decoded=enterprise_user_decoded
|
||||
)
|
||||
|
||||
if system_features.webapp_auth.enabled:
|
||||
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
|
||||
if not app_settings or not app_settings.access_mode == "public":
|
||||
@ -84,6 +96,128 @@ class PassportResource(Resource):
|
||||
api.add_resource(PassportResource, "/passport")
|
||||
|
||||
|
||||
def decode_enterprise_webapp_user_id(jwt_token: str | None):
|
||||
"""
|
||||
Decode the enterprise user session from the Authorization header.
|
||||
"""
|
||||
if not jwt_token:
|
||||
return None
|
||||
|
||||
decoded = PassportService().verify(jwt_token)
|
||||
source = decoded.get("token_source")
|
||||
if not source or source != "webapp_login_token":
|
||||
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
|
||||
return decoded
|
||||
|
||||
|
||||
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict):
|
||||
"""
|
||||
Exchange a token for an existing web user session.
|
||||
"""
|
||||
user_id = enterprise_user_decoded.get("user_id")
|
||||
end_user_id = enterprise_user_decoded.get("end_user_id")
|
||||
session_id = enterprise_user_decoded.get("session_id")
|
||||
user_auth_type = enterprise_user_decoded.get("auth_type")
|
||||
if not user_auth_type:
|
||||
raise Unauthorized("Missing auth_type in the token.")
|
||||
|
||||
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
|
||||
if not site:
|
||||
raise NotFound()
|
||||
|
||||
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
||||
if not app_model or app_model.status != "normal" or not app_model.enable_site:
|
||||
raise NotFound()
|
||||
|
||||
app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
|
||||
|
||||
if app_auth_type == WebAppAuthType.PUBLIC:
|
||||
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
|
||||
elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
|
||||
raise WebAppAuthRequiredError("Please login as external user.")
|
||||
elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
|
||||
raise WebAppAuthRequiredError("Please login as internal user.")
|
||||
|
||||
end_user = None
|
||||
if end_user_id:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
if session_id:
|
||||
end_user = (
|
||||
db.session.query(EndUser)
|
||||
.filter(
|
||||
EndUser.session_id == session_id,
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.app_id == app_model.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not end_user:
|
||||
if not session_id:
|
||||
raise NotFound("Missing session_id for existing web user.")
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type="browser",
|
||||
is_anonymous=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
|
||||
exp = int(exp_dt.timestamp())
|
||||
payload = {
|
||||
"iss": site.id,
|
||||
"sub": "Web API Passport",
|
||||
"app_id": site.app_id,
|
||||
"app_code": site.code,
|
||||
"user_id": user_id,
|
||||
"end_user_id": end_user.id,
|
||||
"auth_type": user_auth_type,
|
||||
"granted_at": int(datetime.now(UTC).timestamp()),
|
||||
"token_source": "webapp",
|
||||
"exp": exp,
|
||||
}
|
||||
token: str = PassportService().issue(payload)
|
||||
return {
|
||||
"access_token": token,
|
||||
}
|
||||
|
||||
|
||||
def _exchange_for_public_app_token(app_model, site, token_decoded):
|
||||
user_id = token_decoded.get("user_id")
|
||||
end_user = None
|
||||
if user_id:
|
||||
end_user = (
|
||||
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
|
||||
)
|
||||
|
||||
if not end_user:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type="browser",
|
||||
is_anonymous=True,
|
||||
session_id=generate_session_id(),
|
||||
)
|
||||
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
payload = {
|
||||
"iss": site.app_id,
|
||||
"sub": "Web API Passport",
|
||||
"app_id": site.app_id,
|
||||
"app_code": site.code,
|
||||
"end_user_id": end_user.id,
|
||||
}
|
||||
|
||||
tk = PassportService().issue(payload)
|
||||
|
||||
return {
|
||||
"access_token": tk,
|
||||
}
|
||||
|
||||
|
||||
def generate_session_id():
|
||||
"""
|
||||
Generate a unique session ID.
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from datetime import UTC, datetime
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
@ -8,8 +9,9 @@ from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequire
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from models.model import App, EndUser, Site
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings
|
||||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
|
||||
def validate_jwt_token(view=None):
|
||||
@ -45,7 +47,8 @@ def decode_jwt_token():
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
decoded = PassportService().verify(tk)
|
||||
app_code = decoded.get("app_code")
|
||||
app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first()
|
||||
app_id = decoded.get("app_id")
|
||||
app_model = db.session.query(App).filter(App.id == app_id).first()
|
||||
site = db.session.query(Site).filter(Site.code == app_code).first()
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
@ -53,23 +56,30 @@ def decode_jwt_token():
|
||||
raise BadRequest("Site URL is no longer valid.")
|
||||
if app_model.enable_site is False:
|
||||
raise BadRequest("Site is disabled.")
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
|
||||
end_user_id = decoded.get("end_user_id")
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
if not end_user:
|
||||
raise NotFound()
|
||||
|
||||
# for enterprise webapp auth
|
||||
app_web_auth_enabled = False
|
||||
webapp_settings = None
|
||||
if system_features.webapp_auth.enabled:
|
||||
app_web_auth_enabled = (
|
||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public"
|
||||
)
|
||||
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
|
||||
if not webapp_settings:
|
||||
raise NotFound("Web app settings not found.")
|
||||
app_web_auth_enabled = webapp_settings.access_mode != "public"
|
||||
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
||||
_validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
||||
_validate_user_accessibility(
|
||||
decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled, webapp_settings
|
||||
)
|
||||
|
||||
return app_model, end_user
|
||||
except Unauthorized as e:
|
||||
if system_features.webapp_auth.enabled:
|
||||
if not app_code:
|
||||
raise Unauthorized("Please re-login to access the web app.")
|
||||
app_web_auth_enabled = (
|
||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=str(app_code)).access_mode != "public"
|
||||
)
|
||||
@ -95,15 +105,41 @@ def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_au
|
||||
raise Unauthorized("webapp token expired.")
|
||||
|
||||
|
||||
def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
|
||||
def _validate_user_accessibility(
|
||||
decoded,
|
||||
app_code,
|
||||
app_web_auth_enabled: bool,
|
||||
system_webapp_auth_enabled: bool,
|
||||
webapp_settings: WebAppSettings | None,
|
||||
):
|
||||
if system_webapp_auth_enabled and app_web_auth_enabled:
|
||||
# Check if the user is allowed to access the web app
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
raise WebAppAuthRequiredError()
|
||||
|
||||
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
|
||||
raise WebAppAuthAccessDeniedError()
|
||||
if not webapp_settings:
|
||||
raise WebAppAuthRequiredError("Web app settings not found.")
|
||||
|
||||
if WebAppAuthService.is_app_require_permission_check(access_mode=webapp_settings.access_mode):
|
||||
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
|
||||
raise WebAppAuthAccessDeniedError()
|
||||
|
||||
auth_type = decoded.get("auth_type")
|
||||
granted_at = decoded.get("granted_at")
|
||||
if not auth_type:
|
||||
raise WebAppAuthAccessDeniedError("Missing auth_type in the token.")
|
||||
if not granted_at:
|
||||
raise WebAppAuthAccessDeniedError("Missing granted_at in the token.")
|
||||
# check if sso has been updated
|
||||
if auth_type == "external":
|
||||
last_update_time = EnterpriseService.get_app_sso_settings_last_update_time()
|
||||
if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time:
|
||||
raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.")
|
||||
elif auth_type == "internal":
|
||||
last_update_time = EnterpriseService.get_workspace_sso_settings_last_update_time()
|
||||
if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time:
|
||||
raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.")
|
||||
|
||||
|
||||
class WebApiResource(Resource):
|
||||
|
||||
@ -63,7 +63,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
|
||||
@ -82,7 +82,7 @@ class AgentEntity(BaseModel):
|
||||
strategy: Strategy
|
||||
prompt: Optional[AgentPromptEntity] = None
|
||||
tools: Optional[list[AgentToolEntity]] = None
|
||||
max_iteration: int = 5
|
||||
max_iteration: int = 10
|
||||
|
||||
|
||||
class AgentInvokeMessage(ToolInvokeMessage):
|
||||
|
||||
@ -48,7 +48,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
assert app_config.agent
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
|
||||
@ -75,7 +75,7 @@ class AgentConfigManager:
|
||||
strategy=strategy,
|
||||
prompt=agent_prompt_entity,
|
||||
tools=agent_tools,
|
||||
max_iteration=agent_dict.get("max_iteration", 5),
|
||||
max_iteration=agent_dict.get("max_iteration", 10),
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@ -70,7 +70,7 @@ class ModelConfigConverter:
|
||||
if not model_mode:
|
||||
model_mode = LLMMode.CHAT.value
|
||||
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
|
||||
model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value
|
||||
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
||||
@ -27,8 +27,8 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
@ -140,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count,
|
||||
SystemVariableKey.APP_ID: app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id,
|
||||
}
|
||||
|
||||
# init variable pool
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
@ -57,26 +56,23 @@ from core.app.entities.task_entities import (
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models import Conversation, EndUser, Message, MessageFile
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -126,8 +122,14 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id,
|
||||
},
|
||||
workflow_info=CycleManagerWorkflowInfo(
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType(workflow.type),
|
||||
version=workflow.version,
|
||||
graph_data=workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
)
|
||||
@ -137,7 +139,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
)
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._message_cycle_manager = MessageCycleManage(
|
||||
self._message_cycle_manager = MessageCycleManager(
|
||||
application_generate_entity=application_generate_entity, task_state=self._task_state
|
||||
)
|
||||
|
||||
@ -158,7 +160,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
:return:
|
||||
"""
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
|
||||
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
|
||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query
|
||||
)
|
||||
|
||||
@ -302,15 +304,12 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# init workflow run
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start(
|
||||
session=session,
|
||||
workflow_id=self._workflow_id,
|
||||
)
|
||||
self._workflow_run_id = workflow_execution.id
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
|
||||
self._workflow_run_id = workflow_execution.id_
|
||||
message = self._get_message(session=session)
|
||||
if not message:
|
||||
raise ValueError(f"Message not found: {self._message_id}")
|
||||
message.workflow_run_id = workflow_execution.id
|
||||
message.workflow_run_id = workflow_execution.id_
|
||||
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
@ -550,7 +549,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.FAILED,
|
||||
status=WorkflowExecutionStatus.FAILED,
|
||||
error_message=event.error,
|
||||
conversation_id=self._conversation_id,
|
||||
trace_manager=trace_manager,
|
||||
@ -576,7 +575,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.STOPPED,
|
||||
status=WorkflowExecutionStatus.STOPPED,
|
||||
error_message=event.get_stop_reason(),
|
||||
conversation_id=self._conversation_id,
|
||||
trace_manager=trace_manager,
|
||||
@ -604,22 +603,18 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
yield self._message_end_to_stream_response()
|
||||
break
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
self._message_cycle_manager._handle_retriever_resources(event)
|
||||
self._message_cycle_manager.handle_retriever_resources(event)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message = self._get_message(session=session)
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
session.commit()
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
self._message_cycle_manager._handle_annotation_reply(event)
|
||||
self._message_cycle_manager.handle_annotation_reply(event)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message = self._get_message(session=session)
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
session.commit()
|
||||
elif isinstance(event, QueueTextChunkEvent):
|
||||
delta_text = event.text
|
||||
@ -636,12 +631,12 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
tts_publisher.publish(queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._message_cycle_manager._message_to_stream_response(
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
# published by moderation
|
||||
yield self._message_cycle_manager._message_replace_to_stream_response(
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||
answer=event.text, reason=event.reason
|
||||
)
|
||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||
@ -653,7 +648,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
)
|
||||
if output_moderation_answer:
|
||||
self._task_state.answer = output_moderation_answer
|
||||
yield self._message_cycle_manager._message_replace_to_stream_response(
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||
answer=output_moderation_answer,
|
||||
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
|
||||
)
|
||||
@ -682,9 +677,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
message = self._get_message(session=session)
|
||||
message.answer = self._task_state.answer
|
||||
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
message_files = [
|
||||
MessageFile(
|
||||
message_id=message.id,
|
||||
@ -712,9 +705,9 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
message.answer_price_unit = usage.completion_price_unit
|
||||
message.total_price = usage.total_price
|
||||
message.currency = usage.currency
|
||||
self._task_state.metadata["usage"] = jsonable_encoder(usage)
|
||||
self._task_state.metadata.usage = usage
|
||||
else:
|
||||
self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage())
|
||||
self._task_state.metadata.usage = LLMUsage.empty_usage()
|
||||
message_was_created.send(
|
||||
message,
|
||||
application_generate_entity=self._application_generate_entity,
|
||||
@ -725,18 +718,16 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
Message end to stream response.
|
||||
:return:
|
||||
"""
|
||||
extras = {}
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata.copy()
|
||||
extras = self._task_state.metadata.model_dump()
|
||||
|
||||
if "annotation_reply" in extras["metadata"]:
|
||||
del extras["metadata"]["annotation_reply"]
|
||||
if self._task_state.metadata.annotation_reply:
|
||||
del extras["annotation_reply"]
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message_id,
|
||||
files=self._recorded_files,
|
||||
metadata=extras.get("metadata", {}),
|
||||
metadata=extras,
|
||||
)
|
||||
|
||||
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
||||
|
||||
@ -44,15 +44,14 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_execution_entities import NodeExecution
|
||||
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
EndUser,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowRun,
|
||||
)
|
||||
|
||||
@ -73,11 +72,10 @@ class WorkflowResponseConverter:
|
||||
) -> WorkflowStartStreamResponse:
|
||||
return WorkflowStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution.id,
|
||||
workflow_run_id=workflow_execution.id_,
|
||||
data=WorkflowStartStreamResponse.Data(
|
||||
id=workflow_execution.id,
|
||||
id=workflow_execution.id_,
|
||||
workflow_id=workflow_execution.workflow_id,
|
||||
sequence_number=workflow_execution.sequence_number,
|
||||
inputs=workflow_execution.inputs,
|
||||
created_at=int(workflow_execution.started_at.timestamp()),
|
||||
),
|
||||
@ -91,7 +89,7 @@ class WorkflowResponseConverter:
|
||||
workflow_execution: WorkflowExecution,
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
created_by = None
|
||||
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
|
||||
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
|
||||
assert workflow_run is not None
|
||||
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
|
||||
stmt = select(Account).where(Account.id == workflow_run.created_by)
|
||||
@ -122,11 +120,10 @@ class WorkflowResponseConverter:
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution.id,
|
||||
workflow_run_id=workflow_execution.id_,
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id=workflow_execution.id,
|
||||
id=workflow_execution.id_,
|
||||
workflow_id=workflow_execution.workflow_id,
|
||||
sequence_number=workflow_execution.sequence_number,
|
||||
status=workflow_execution.status,
|
||||
outputs=workflow_execution.outputs,
|
||||
error=workflow_execution.error_message,
|
||||
@ -146,16 +143,16 @@ class WorkflowResponseConverter:
|
||||
*,
|
||||
event: QueueNodeStartedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: NodeExecution,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeStartStreamResponse]:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
if not workflow_node_execution.workflow_execution_id:
|
||||
return None
|
||||
|
||||
response = NodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_execution_id,
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
@ -196,18 +193,18 @@ class WorkflowResponseConverter:
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: NodeExecution,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
if not workflow_node_execution.workflow_execution_id:
|
||||
return None
|
||||
if not workflow_node_execution.finished_at:
|
||||
return None
|
||||
|
||||
return NodeFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_execution_id,
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
@ -239,18 +236,18 @@ class WorkflowResponseConverter:
|
||||
*,
|
||||
event: QueueNodeRetryEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: NodeExecution,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
if not workflow_node_execution.workflow_execution_id:
|
||||
return None
|
||||
if not workflow_node_execution.finished_at:
|
||||
return None
|
||||
|
||||
return NodeRetryStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_execution_id,
|
||||
data=NodeRetryStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
|
||||
@ -25,8 +25,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
@ -132,7 +132,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
trace_manager=trace_manager,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_execution_id=workflow_run_id,
|
||||
)
|
||||
|
||||
contexts.plugin_tool_providers.set({})
|
||||
@ -279,7 +279,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
@ -355,7 +355,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
@ -95,7 +95,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
SystemVariableKey.APP_ID: app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id,
|
||||
}
|
||||
|
||||
variable_pool = VariablePool(
|
||||
|
||||
@ -50,16 +50,15 @@ from core.app.entities.task_entities import (
|
||||
WorkflowAppStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole
|
||||
@ -69,7 +68,6 @@ from models.workflow import (
|
||||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -114,8 +112,14 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id,
|
||||
},
|
||||
workflow_info=CycleManagerWorkflowInfo(
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType(workflow.type),
|
||||
version=workflow.version,
|
||||
graph_data=workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
)
|
||||
@ -125,9 +129,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
)
|
||||
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_id = workflow.id
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._workflow_run_id = ""
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
@ -266,17 +268,13 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
# override graph runtime state
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# init workflow run
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start(
|
||||
session=session,
|
||||
workflow_id=self._workflow_id,
|
||||
)
|
||||
self._workflow_run_id = workflow_execution.id
|
||||
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
)
|
||||
# init workflow run
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
|
||||
self._workflow_run_id = workflow_execution.id_
|
||||
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
)
|
||||
|
||||
yield start_resp
|
||||
elif isinstance(
|
||||
@ -511,9 +509,9 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.FAILED
|
||||
status=WorkflowExecutionStatus.FAILED
|
||||
if isinstance(event, QueueWorkflowFailedEvent)
|
||||
else WorkflowRunStatus.STOPPED,
|
||||
else WorkflowExecutionStatus.STOPPED,
|
||||
error_message=event.error
|
||||
if isinstance(event, QueueWorkflowFailedEvent)
|
||||
else event.get_stop_reason(),
|
||||
@ -542,7 +540,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._text_chunk_to_stream_response(
|
||||
delta_text, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
@ -557,7 +554,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
tts_publisher.publish(None)
|
||||
|
||||
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
|
||||
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
|
||||
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
|
||||
assert workflow_run is not None
|
||||
invoke_from = self._application_generate_entity.invoke_from
|
||||
if invoke_from == InvokeFrom.SERVICE_API:
|
||||
|
||||
@ -29,8 +29,8 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
AgentLogEvent,
|
||||
GraphEngineEvent,
|
||||
@ -295,7 +295,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
inputs: Mapping[str, Any] | None = {}
|
||||
process_data: Mapping[str, Any] | None = {}
|
||||
outputs: Mapping[str, Any] | None = {}
|
||||
execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {}
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
|
||||
if node_run_result:
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
|
||||
@ -76,6 +76,8 @@ class AppGenerateEntity(BaseModel):
|
||||
App Generate Entity.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
task_id: str
|
||||
|
||||
# app config
|
||||
@ -99,9 +101,6 @@ class AppGenerateEntity(BaseModel):
|
||||
# tracing instance
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
@ -205,7 +204,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
|
||||
# app config
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
workflow_run_id: str
|
||||
workflow_execution_id: str
|
||||
|
||||
class SingleIterationRunEntity(BaseModel):
|
||||
"""
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Optional
|
||||
@ -6,7 +6,9 @@ from typing import Any, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
@ -282,7 +284,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
|
||||
retriever_resources: list[dict]
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata]
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
@ -412,7 +414,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: Optional[str] = None
|
||||
"""single iteration duration map"""
|
||||
@ -446,7 +448,7 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
retry_index: int # retry index
|
||||
@ -480,7 +482,7 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
@ -513,7 +515,7 @@ class QueueNodeInLoopFailedEvent(AppQueueEvent):
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
@ -546,7 +548,7 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
@ -579,7 +581,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
@ -2,12 +2,29 @@ from collections.abc import Mapping, Sequence
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class AnnotationReply(BaseModel):
|
||||
id: str
|
||||
account: AnnotationReplyAccount
|
||||
|
||||
|
||||
class TaskStateMetadata(BaseModel):
|
||||
annotation_reply: AnnotationReply | None = None
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list)
|
||||
usage: LLMUsage | None = None
|
||||
|
||||
|
||||
class TaskState(BaseModel):
|
||||
@ -15,7 +32,7 @@ class TaskState(BaseModel):
|
||||
TaskState entity
|
||||
"""
|
||||
|
||||
metadata: dict = {}
|
||||
metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata)
|
||||
|
||||
|
||||
class EasyUITaskState(TaskState):
|
||||
@ -189,7 +206,6 @@ class WorkflowStartStreamResponse(StreamResponse):
|
||||
|
||||
id: str
|
||||
workflow_id: str
|
||||
sequence_number: int
|
||||
inputs: Mapping[str, Any]
|
||||
created_at: int
|
||||
|
||||
@ -210,7 +226,6 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
|
||||
id: str
|
||||
workflow_id: str
|
||||
sequence_number: int
|
||||
status: str
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
@ -307,7 +322,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
status: str
|
||||
error: Optional[str] = None
|
||||
elapsed_time: float
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||
@ -376,7 +391,7 @@ class NodeRetryStreamResponse(StreamResponse):
|
||||
status: str
|
||||
error: Optional[str] = None
|
||||
elapsed_time: float
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
@ -43,7 +42,7 @@ from core.app.entities.task_entities import (
|
||||
StreamResponse,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
@ -51,7 +50,6 @@ from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
@ -63,7 +61,7 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage):
|
||||
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
"""
|
||||
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
@ -104,6 +102,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
)
|
||||
)
|
||||
|
||||
self._message_cycle_manager = MessageCycleManager(
|
||||
application_generate_entity=application_generate_entity,
|
||||
task_state=self._task_state,
|
||||
)
|
||||
|
||||
self._conversation_name_generate_thread: Optional[Thread] = None
|
||||
|
||||
def process(
|
||||
@ -115,7 +118,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
]:
|
||||
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._generate_conversation_name(
|
||||
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
|
||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
|
||||
)
|
||||
|
||||
@ -136,9 +139,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, MessageEndStreamResponse):
|
||||
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
|
||||
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata
|
||||
extras["metadata"] = self._task_state.metadata.model_dump()
|
||||
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
|
||||
if self._conversation_mode == AppMode.COMPLETION.value:
|
||||
response = CompletionAppBlockingResponse(
|
||||
@ -277,7 +280,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
)
|
||||
if output_moderation_answer:
|
||||
self._task_state.llm_result.message.content = output_moderation_answer
|
||||
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||
answer=output_moderation_answer
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Save message
|
||||
@ -286,9 +291,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
message_end_resp = self._message_end_to_stream_response()
|
||||
yield message_end_resp
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
self._handle_retriever_resources(event)
|
||||
self._message_cycle_manager.handle_retriever_resources(event)
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
annotation = self._handle_annotation_reply(event)
|
||||
annotation = self._message_cycle_manager.handle_annotation_reply(event)
|
||||
if annotation:
|
||||
self._task_state.llm_result.message.content = annotation.content
|
||||
elif isinstance(event, QueueAgentThoughtEvent):
|
||||
@ -296,7 +301,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
if agent_thought_response is not None:
|
||||
yield agent_thought_response
|
||||
elif isinstance(event, QueueMessageFileEvent):
|
||||
response = self._message_file_to_stream_response(event)
|
||||
response = self._message_cycle_manager.message_file_to_stream_response(event)
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
|
||||
@ -318,7 +323,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
self._task_state.llm_result.message.content = current_content
|
||||
|
||||
if isinstance(event, QueueLLMChunkEvent):
|
||||
yield self._message_to_stream_response(
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
message_id=self._message_id,
|
||||
)
|
||||
@ -328,7 +333,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
message_id=self._message_id,
|
||||
)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
yield self._message_replace_to_stream_response(answer=event.text)
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
|
||||
elif isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
else:
|
||||
@ -372,9 +377,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
message.total_price = usage.total_price
|
||||
message.currency = usage.currency
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
@ -423,16 +426,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
Message end to stream response.
|
||||
:return:
|
||||
"""
|
||||
self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)
|
||||
|
||||
extras = {}
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata
|
||||
|
||||
self._task_state.metadata.usage = self._task_state.llm_result.usage
|
||||
metadata_dict = self._task_state.metadata.model_dump()
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message_id,
|
||||
metadata=extras.get("metadata", {}),
|
||||
metadata=metadata_dict,
|
||||
)
|
||||
|
||||
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
|
||||
@ -455,8 +454,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
agent_thought: Optional[MessageAgentThought] = (
|
||||
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
|
||||
)
|
||||
db.session.refresh(agent_thought)
|
||||
db.session.close()
|
||||
|
||||
if agent_thought:
|
||||
return AgentThoughtStreamResponse(
|
||||
|
||||
@ -17,6 +17,8 @@ from core.app.entities.queue_entities import (
|
||||
QueueRetrieverResourcesEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AnnotationReply,
|
||||
AnnotationReplyAccount,
|
||||
EasyUITaskState,
|
||||
MessageFileStreamResponse,
|
||||
MessageReplaceStreamResponse,
|
||||
@ -30,7 +32,7 @@ from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
class MessageCycleManage:
|
||||
class MessageCycleManager:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@ -45,7 +47,7 @@ class MessageCycleManage:
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._task_state = task_state
|
||||
|
||||
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
||||
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
||||
"""
|
||||
Generate conversation name.
|
||||
:param conversation_id: conversation id
|
||||
@ -102,7 +104,7 @@ class MessageCycleManage:
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
|
||||
def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
|
||||
"""
|
||||
Handle annotation reply.
|
||||
:param event: event
|
||||
@ -111,25 +113,28 @@ class MessageCycleManage:
|
||||
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
||||
if annotation:
|
||||
account = annotation.account
|
||||
self._task_state.metadata["annotation_reply"] = {
|
||||
"id": annotation.id,
|
||||
"account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
|
||||
}
|
||||
self._task_state.metadata.annotation_reply = AnnotationReply(
|
||||
id=annotation.id,
|
||||
account=AnnotationReplyAccount(
|
||||
id=annotation.account_id,
|
||||
name=account.name if account else "Dify user",
|
||||
),
|
||||
)
|
||||
|
||||
return annotation
|
||||
|
||||
return None
|
||||
|
||||
def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
|
||||
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
|
||||
"""
|
||||
Handle retriever resources.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
|
||||
self._task_state.metadata["retriever_resources"] = event.retriever_resources
|
||||
self._task_state.metadata.retriever_resources = event.retriever_resources
|
||||
|
||||
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
|
||||
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
|
||||
"""
|
||||
Message file to stream response.
|
||||
:param event: event
|
||||
@ -166,7 +171,7 @@ class MessageCycleManage:
|
||||
|
||||
return None
|
||||
|
||||
def _message_to_stream_response(
|
||||
def message_to_stream_response(
|
||||
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
|
||||
) -> MessageStreamResponse:
|
||||
"""
|
||||
@ -182,7 +187,7 @@ class MessageCycleManage:
|
||||
from_variable_selector=from_variable_selector,
|
||||
)
|
||||
|
||||
def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
|
||||
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
|
||||
"""
|
||||
Message replace to stream response.
|
||||
:param answer: answer
|
||||
@ -1,8 +1,10 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
@ -85,7 +87,8 @@ class DatasetIndexToolCallbackHandler:
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def return_retriever_resource_info(self, resource: list):
|
||||
# TODO(-LAN-): Improve type check
|
||||
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
|
||||
"""Handle return_retriever_resource_info."""
|
||||
self._queue_manager.publish(
|
||||
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
|
||||
|
||||
@ -55,6 +55,25 @@ class ProviderModelWithStatusEntity(ProviderModel):
|
||||
status: ModelStatus
|
||||
load_balancing_enabled: bool = False
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
"""
|
||||
Check model status and raise ValueError if not active.
|
||||
|
||||
:raises ValueError: When model status is not active, with a descriptive message
|
||||
"""
|
||||
if self.status == ModelStatus.ACTIVE:
|
||||
return
|
||||
|
||||
error_messages = {
|
||||
ModelStatus.NO_CONFIGURE: "Model is not configured",
|
||||
ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded",
|
||||
ModelStatus.NO_PERMISSION: "No permission to use this model",
|
||||
ModelStatus.DISABLED: "Model is disabled",
|
||||
}
|
||||
|
||||
if self.status in error_messages:
|
||||
raise ValueError(error_messages[self.status])
|
||||
|
||||
|
||||
class ModelWithProviderEntity(ProviderModelWithStatusEntity):
|
||||
"""
|
||||
|
||||
@ -41,45 +41,53 @@ class Extensible:
|
||||
extensions = []
|
||||
position_map: dict[str, int] = {}
|
||||
|
||||
# get the path of the current class
|
||||
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
|
||||
current_dir_path = os.path.dirname(current_path)
|
||||
# Get the package name from the module path
|
||||
package_name = ".".join(cls.__module__.split(".")[:-1])
|
||||
|
||||
# traverse subdirectories
|
||||
for subdir_name in os.listdir(current_dir_path):
|
||||
if subdir_name.startswith("__"):
|
||||
continue
|
||||
try:
|
||||
# Get package directory path
|
||||
package_spec = importlib.util.find_spec(package_name)
|
||||
if not package_spec or not package_spec.origin:
|
||||
raise ImportError(f"Could not find package {package_name}")
|
||||
|
||||
subdir_path = os.path.join(current_dir_path, subdir_name)
|
||||
extension_name = subdir_name
|
||||
if os.path.isdir(subdir_path):
|
||||
package_dir = os.path.dirname(package_spec.origin)
|
||||
|
||||
# Traverse subdirectories
|
||||
for subdir_name in os.listdir(package_dir):
|
||||
if subdir_name.startswith("__"):
|
||||
continue
|
||||
|
||||
subdir_path = os.path.join(package_dir, subdir_name)
|
||||
if not os.path.isdir(subdir_path):
|
||||
continue
|
||||
|
||||
extension_name = subdir_name
|
||||
file_names = os.listdir(subdir_path)
|
||||
|
||||
# is builtin extension, builtin extension
|
||||
# in the front-end page and business logic, there are special treatments.
|
||||
# Check for extension module file
|
||||
if (extension_name + ".py") not in file_names:
|
||||
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
|
||||
continue
|
||||
|
||||
# Check for builtin flag and position
|
||||
builtin = False
|
||||
# default position is 0 can not be None for sort_to_dict_by_position_map
|
||||
position = 0
|
||||
if "__builtin__" in file_names:
|
||||
builtin = True
|
||||
|
||||
builtin_file_path = os.path.join(subdir_path, "__builtin__")
|
||||
if os.path.exists(builtin_file_path):
|
||||
position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
|
||||
position_map[extension_name] = position
|
||||
|
||||
if (extension_name + ".py") not in file_names:
|
||||
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
|
||||
continue
|
||||
|
||||
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
|
||||
py_path = os.path.join(subdir_path, extension_name + ".py")
|
||||
spec = importlib.util.spec_from_file_location(extension_name, py_path)
|
||||
# Import the extension module
|
||||
module_name = f"{package_name}.{extension_name}.{extension_name}"
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
if not spec or not spec.loader:
|
||||
raise Exception(f"Failed to load module {extension_name} from {py_path}")
|
||||
raise ImportError(f"Failed to load module {module_name}")
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
# Find extension class
|
||||
extension_class = None
|
||||
for name, obj in vars(mod).items():
|
||||
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
|
||||
@ -87,21 +95,21 @@ class Extensible:
|
||||
break
|
||||
|
||||
if not extension_class:
|
||||
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
|
||||
logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.")
|
||||
continue
|
||||
|
||||
# Load schema if not builtin
|
||||
json_data: dict[str, Any] = {}
|
||||
if not builtin:
|
||||
if "schema.json" not in file_names:
|
||||
json_path = os.path.join(subdir_path, "schema.json")
|
||||
if not os.path.exists(json_path):
|
||||
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
|
||||
continue
|
||||
|
||||
json_path = os.path.join(subdir_path, "schema.json")
|
||||
json_data = {}
|
||||
if os.path.exists(json_path):
|
||||
with open(json_path, encoding="utf-8") as f:
|
||||
json_data = json.load(f)
|
||||
with open(json_path, encoding="utf-8") as f:
|
||||
json_data = json.load(f)
|
||||
|
||||
# Create extension
|
||||
extensions.append(
|
||||
ModuleExtension(
|
||||
extension_class=extension_class,
|
||||
@ -113,6 +121,11 @@ class Extensible:
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("Error scanning extensions")
|
||||
raise
|
||||
|
||||
# Sort extensions by position
|
||||
sorted_extensions = sort_to_dict_by_position_map(
|
||||
position_map=position_map, data=extensions, name_func=lambda x: x.name
|
||||
)
|
||||
|
||||
@ -15,6 +15,7 @@ from core.helper.code_executor.python3.python3_transformer import Python3Templat
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT))
|
||||
|
||||
|
||||
class CodeExecutionError(Exception):
|
||||
@ -64,7 +65,7 @@ class CodeExecutor:
|
||||
:param code: code
|
||||
:return:
|
||||
"""
|
||||
url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run"
|
||||
url = code_execution_endpoint_url / "v1" / "sandbox" / "run"
|
||||
|
||||
headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY}
|
||||
|
||||
|
||||
@ -7,29 +7,28 @@ from configs import dify_config
|
||||
from core.helper.download import download_with_size_limit
|
||||
from core.plugin.entities.marketplace import MarketplacePluginDeclaration
|
||||
|
||||
marketplace_api_url = URL(str(dify_config.MARKETPLACE_API_URL))
|
||||
|
||||
def get_plugin_pkg_url(plugin_unique_identifier: str):
|
||||
return (URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/download").with_query(
|
||||
unique_identifier=plugin_unique_identifier
|
||||
)
|
||||
|
||||
def get_plugin_pkg_url(plugin_unique_identifier: str) -> str:
|
||||
return str((marketplace_api_url / "api/v1/plugins/download").with_query(unique_identifier=plugin_unique_identifier))
|
||||
|
||||
|
||||
def download_plugin_pkg(plugin_unique_identifier: str):
|
||||
url = str(get_plugin_pkg_url(plugin_unique_identifier))
|
||||
return download_with_size_limit(url, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
|
||||
return download_with_size_limit(get_plugin_pkg_url(plugin_unique_identifier), dify_config.PLUGIN_MAX_PACKAGE_SIZE)
|
||||
|
||||
|
||||
def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]:
|
||||
if len(plugin_ids) == 0:
|
||||
return []
|
||||
|
||||
url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/batch")
|
||||
url = str(marketplace_api_url / "api/v1/plugins/batch")
|
||||
response = requests.post(url, json={"plugin_ids": plugin_ids})
|
||||
response.raise_for_status()
|
||||
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
|
||||
|
||||
|
||||
def record_install_plugin_event(plugin_unique_identifier: str):
|
||||
url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/stats/plugins/install_count")
|
||||
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
|
||||
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
|
||||
response.raise_for_status()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
import random
|
||||
import secrets
|
||||
from typing import cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
@ -38,7 +38,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
|
||||
if len(text_chunks) == 0:
|
||||
return True
|
||||
|
||||
text_chunk = random.choice(text_chunks)
|
||||
text_chunk = secrets.choice(text_chunks)
|
||||
|
||||
try:
|
||||
model_provider_factory = ModelProviderFactory(tenant_id)
|
||||
|
||||
@ -1,61 +1,20 @@
|
||||
# Written by YORKI MINAKO🤡, Edited by Xiaoyi
|
||||
CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is.
|
||||
Notice: the language type user uses could be diverse, which can be English, Chinese, Italian, Español, Arabic, Japanese, French, and etc.
|
||||
ENSURE your output is in the SAME language as the user's input!
|
||||
Your output is restricted only to: (Input language) Intention + Subject(short as possible)
|
||||
Your output MUST be a valid JSON.
|
||||
# Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh
|
||||
CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the user’s input into two parts: “Intention” and “Subject”.
|
||||
|
||||
Tip: When the user's question is directed at you (the language model), you can add an emoji to make it more fun.
|
||||
1. Detect Input Language
|
||||
Automatically identify the language of the user’s input (e.g. English, Chinese, Italian, Español, Arabic, Japanese, French, and etc.).
|
||||
|
||||
2. Generate Title
|
||||
- Combine Intention + Subject into a single, as-short-as-possible phrase.
|
||||
- The title must be natural, friendly, and in the same language as the input.
|
||||
- If the input is a direct question to the model, you may add an emoji at the end.
|
||||
|
||||
example 1:
|
||||
User Input: hi, yesterday i had some burgers.
|
||||
3. Output Format
|
||||
Return **only** a valid JSON object with these exact keys and no additional text:
|
||||
{
|
||||
"Language Type": "The user's input is pure English",
|
||||
"Your Reasoning": "The language of my output must be pure English.",
|
||||
"Your Output": "sharing yesterday's food"
|
||||
}
|
||||
|
||||
example 2:
|
||||
User Input: hello
|
||||
{
|
||||
"Language Type": "The user's input is pure English",
|
||||
"Your Reasoning": "The language of my output must be pure English.",
|
||||
"Your Output": "Greeting myself☺️"
|
||||
}
|
||||
|
||||
|
||||
example 3:
|
||||
User Input: why mmap file: oom
|
||||
{
|
||||
"Language Type": "The user's input is written in pure English",
|
||||
"Your Reasoning": "The language of my output must be pure English.",
|
||||
"Your Output": "Asking about the reason for mmap file: oom"
|
||||
}
|
||||
|
||||
|
||||
example 4:
|
||||
User Input: www.convinceme.yesterday-you-ate-seafood.tv讲了什么?
|
||||
{
|
||||
"Language Type": "The user's input English-Chinese mixed",
|
||||
"Your Reasoning": "The English-part is an URL, the main intention is still written in Chinese, so the language of my output must be using Chinese.",
|
||||
"Your Output": "询问网站www.convinceme.yesterday-you-ate-seafood.tv"
|
||||
}
|
||||
|
||||
example 5:
|
||||
User Input: why小红的年龄is老than小明?
|
||||
{
|
||||
"Language Type": "The user's input is English-Chinese mixed",
|
||||
"Your Reasoning": "The English parts are filler words, the main intention is written in Chinese, besides, Chinese occupies a greater \"actual meaning\" than English, so the language of my output must be using Chinese.",
|
||||
"Your Output": "询问小红和小明的年龄"
|
||||
}
|
||||
|
||||
example 6:
|
||||
User Input: yo, 你今天咋样?
|
||||
{
|
||||
"Language Type": "The user's input is English-Chinese mixed",
|
||||
"Your Reasoning": "The English-part is a subjective particle, the main intention is written in Chinese, so the language of my output must be using Chinese.",
|
||||
"Your Output": "查询今日我的状态☺️"
|
||||
"Language Type": "<Detected language>",
|
||||
"Your Reasoning": "<Brief explanation in that language>",
|
||||
"Your Output": "<Intention + Subject>"
|
||||
}
|
||||
|
||||
User Input:
|
||||
|
||||
@ -17,19 +17,6 @@ class LLMMode(StrEnum):
|
||||
COMPLETION = "completion"
|
||||
CHAT = "chat"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "LLMMode":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class LLMUsage(ModelUsage):
|
||||
"""
|
||||
|
||||
@ -160,6 +160,10 @@ class ProviderModel(BaseModel):
|
||||
deprecated: bool = False
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@property
|
||||
def support_structure_output(self) -> bool:
|
||||
return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features
|
||||
|
||||
|
||||
class ParameterRule(BaseModel):
|
||||
"""
|
||||
|
||||
@ -129,17 +129,18 @@ def jsonable_encoder(
|
||||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
||||
if dataclasses.is_dataclass(obj):
|
||||
# FIXME: mypy error, try to fix it instead of using type: ignore
|
||||
obj_dict = dataclasses.asdict(obj) # type: ignore
|
||||
return jsonable_encoder(
|
||||
obj_dict,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
custom_encoder=custom_encoder,
|
||||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
||||
# Ensure obj is a dataclass instance, not a dataclass type
|
||||
if not isinstance(obj, type):
|
||||
obj_dict = dataclasses.asdict(obj)
|
||||
return jsonable_encoder(
|
||||
obj_dict,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
custom_encoder=custom_encoder,
|
||||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
||||
if isinstance(obj, Enum):
|
||||
return obj.value
|
||||
if isinstance(obj, PurePath):
|
||||
|
||||
@ -1,7 +1,11 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.ops.entities.config_entity import BaseTracingConfig
|
||||
from core.ops.entities.trace_entity import BaseTraceInfo
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, TenantAccountJoin
|
||||
|
||||
|
||||
class BaseTraceInstance(ABC):
|
||||
@ -24,3 +28,38 @@ class BaseTraceInstance(ABC):
|
||||
Subclasses must implement specific tracing logic for activities.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_service_account_with_tenant(self, app_id: str) -> Account:
|
||||
"""
|
||||
Get service account for an app and set up its tenant.
|
||||
|
||||
Args:
|
||||
app_id: The ID of the app
|
||||
|
||||
Returns:
|
||||
Account: The service account with tenant set up
|
||||
|
||||
Raises:
|
||||
ValueError: If app, creator account or tenant cannot be found
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
|
||||
current_tenant = (
|
||||
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
|
||||
)
|
||||
if not current_tenant:
|
||||
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
||||
service_account.set_tenant_id(current_tenant.tenant_id)
|
||||
|
||||
return service_account
|
||||
|
||||
@ -98,6 +98,7 @@ class WeaveConfig(BaseTracingConfig):
|
||||
entity: str | None = None
|
||||
project: str
|
||||
endpoint: str = "https://trace.wandb.ai"
|
||||
host: str | None = None
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
@ -109,6 +110,14 @@ class WeaveConfig(BaseTracingConfig):
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def validate_host(cls, v, info: ValidationInfo):
|
||||
if v is not None and v != "":
|
||||
if not v.startswith(("https://", "http://")):
|
||||
raise ValueError("host must start with https:// or http://")
|
||||
return v
|
||||
|
||||
|
||||
OPS_FILE_PATH = "ops_trace/"
|
||||
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
|
||||
|
||||
@ -3,7 +3,7 @@ from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
|
||||
|
||||
|
||||
class BaseTraceInfo(BaseModel):
|
||||
@ -24,10 +24,13 @@ class BaseTraceInfo(BaseModel):
|
||||
return v
|
||||
return ""
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat(),
|
||||
}
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_serializer("start_time", "end_time")
|
||||
def serialize_datetime(self, dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
return dt.isoformat()
|
||||
|
||||
|
||||
class WorkflowTraceInfo(BaseTraceInfo):
|
||||
|
||||
@ -4,7 +4,7 @@ from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from langfuse import Langfuse # type: ignore
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
@ -31,7 +31,7 @@ from core.ops.utils import filter_none_values
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -114,22 +114,11 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import Optional, cast
|
||||
|
||||
from langsmith import Client
|
||||
from langsmith.schemas import RunBase
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import LangSmithConfig
|
||||
@ -28,10 +28,10 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
||||
)
|
||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -139,22 +139,11 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
@ -185,7 +174,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
|
||||
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||
metadata = {str(key): value for key, value in execution_metadata.items()}
|
||||
metadata.update(
|
||||
{
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import Optional, cast
|
||||
|
||||
from opik import Opik, Trace
|
||||
from opik.id_helpers import uuid4_to_uuid7
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import OpikConfig
|
||||
@ -22,10 +22,10 @@ from core.ops.entities.trace_entity import (
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -154,22 +154,11 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
@ -246,7 +235,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
|
||||
|
||||
if not total_tokens:
|
||||
total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
|
||||
total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||
|
||||
span_data = {
|
||||
"trace_id": opik_trace_id,
|
||||
|
||||
@ -30,7 +30,7 @@ from core.ops.entities.trace_entity import (
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import get_message_data
|
||||
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||
@ -81,7 +81,7 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
|
||||
return {
|
||||
"config_class": WeaveConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "entity", "endpoint"],
|
||||
"other_keys": ["project", "entity", "endpoint", "host"],
|
||||
"trace_instance": WeaveDataTrace,
|
||||
}
|
||||
|
||||
@ -386,7 +386,7 @@ class TraceTask:
|
||||
):
|
||||
self.trace_type = trace_type
|
||||
self.message_id = message_id
|
||||
self.workflow_run_id = workflow_execution.id if workflow_execution else None
|
||||
self.workflow_run_id = workflow_execution.id_ if workflow_execution else None
|
||||
self.conversation_id = conversation_id
|
||||
self.user_id = user_id
|
||||
self.timer = timer
|
||||
@ -487,6 +487,7 @@ class TraceTask:
|
||||
"file_list": file_list,
|
||||
"triggered_from": workflow_run.triggered_from,
|
||||
"user_id": user_id,
|
||||
"app_id": workflow_run.app_id,
|
||||
}
|
||||
|
||||
workflow_trace_info = WorkflowTraceInfo(
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import Any, Optional, cast
|
||||
|
||||
import wandb
|
||||
import weave
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import WeaveConfig
|
||||
@ -23,10 +23,10 @@ from core.ops.entities.trace_entity import (
|
||||
)
|
||||
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -40,9 +40,14 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
self.weave_api_key = weave_config.api_key
|
||||
self.project_name = weave_config.project
|
||||
self.entity = weave_config.entity
|
||||
self.host = weave_config.host
|
||||
|
||||
# Login with API key first, including host if provided
|
||||
if self.host:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
|
||||
else:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
||||
|
||||
# Login with API key first
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
||||
if not login_status:
|
||||
logger.error("Failed to login to Weights & Biases with the provided API key")
|
||||
raise ValueError("Weave login failed")
|
||||
@ -133,22 +138,11 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
@ -179,7 +173,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
|
||||
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||
attributes = {str(k): v for k, v in execution_metadata.items()}
|
||||
attributes.update(
|
||||
{
|
||||
@ -397,7 +391,11 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
||||
if self.host:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
|
||||
else:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
||||
|
||||
if not login_status:
|
||||
raise ValueError("Weave login failed")
|
||||
else:
|
||||
|
||||
@ -11,14 +11,12 @@ class BaseBackwardsInvocation:
|
||||
try:
|
||||
for chunk in response:
|
||||
if isinstance(chunk, BaseModel | dict):
|
||||
yield BaseBackwardsInvocationResponse(data=chunk).model_dump_json().encode() + b"\n\n"
|
||||
elif isinstance(chunk, str):
|
||||
yield f"event: {chunk}\n\n".encode()
|
||||
yield BaseBackwardsInvocationResponse(data=chunk).model_dump_json().encode()
|
||||
except Exception as e:
|
||||
error_message = BaseBackwardsInvocationResponse(error=str(e)).model_dump_json()
|
||||
yield f"{error_message}\n\n".encode()
|
||||
yield error_message.encode()
|
||||
else:
|
||||
yield BaseBackwardsInvocationResponse(data=response).model_dump_json().encode() + b"\n\n"
|
||||
yield BaseBackwardsInvocationResponse(data=response).model_dump_json().encode()
|
||||
|
||||
|
||||
T = TypeVar("T", bound=dict | Mapping | str | bool | int | BaseModel)
|
||||
|
||||
@ -21,7 +21,7 @@ from core.plugin.entities.request import (
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
from models.account import Tenant
|
||||
|
||||
|
||||
@ -55,20 +55,21 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
def handle() -> Generator[LLMResultChunk, None, None]:
|
||||
for chunk in response:
|
||||
if chunk.delta.usage:
|
||||
LLMNode.deduct_llm_quota(
|
||||
llm_utils.deduct_llm_quota(
|
||||
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
|
||||
)
|
||||
chunk.prompt_messages = []
|
||||
yield chunk
|
||||
|
||||
return handle()
|
||||
else:
|
||||
if response.usage:
|
||||
LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
|
||||
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
model=response.model,
|
||||
prompt_messages=response.prompt_messages,
|
||||
prompt_messages=[],
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import TypeVar
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from requests.exceptions import HTTPError
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
@ -30,8 +31,7 @@ from core.plugin.impl.exc import (
|
||||
PluginUniqueIdentifierError,
|
||||
)
|
||||
|
||||
plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_DAEMON_URL
|
||||
plugin_daemon_inner_api_key = dify_config.PLUGIN_DAEMON_KEY
|
||||
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
||||
|
||||
@ -52,9 +52,9 @@ class BasePluginClient:
|
||||
"""
|
||||
Make a request to the plugin daemon inner API.
|
||||
"""
|
||||
url = URL(str(plugin_daemon_inner_api_baseurl)) / path
|
||||
url = plugin_daemon_inner_api_baseurl / path
|
||||
headers = headers or {}
|
||||
headers["X-Api-Key"] = plugin_daemon_inner_api_key
|
||||
headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
|
||||
headers["Accept-Encoding"] = "gzip, deflate, br"
|
||||
|
||||
if headers.get("Content-Type") == "application/json" and isinstance(data, dict):
|
||||
@ -136,12 +136,31 @@ class BasePluginClient:
|
||||
"""
|
||||
Make a request to the plugin daemon inner API and return the response as a model.
|
||||
"""
|
||||
response = self._request(method, path, headers, data, params, files)
|
||||
json_response = response.json()
|
||||
if transformer:
|
||||
json_response = transformer(json_response)
|
||||
try:
|
||||
response = self._request(method, path, headers, data, params, files)
|
||||
response.raise_for_status()
|
||||
except HTTPError as e:
|
||||
msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}"
|
||||
logging.exception(msg)
|
||||
raise e
|
||||
except Exception as e:
|
||||
msg = f"Failed to request plugin daemon, url: {path}"
|
||||
logging.exception(msg)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
try:
|
||||
json_response = response.json()
|
||||
if transformer:
|
||||
json_response = transformer(json_response)
|
||||
rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore
|
||||
except Exception:
|
||||
msg = (
|
||||
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}],"
|
||||
f" url: {path}"
|
||||
)
|
||||
logging.exception(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore
|
||||
if rep.code != 0:
|
||||
try:
|
||||
error = PluginDaemonError(**json.loads(rep.message))
|
||||
|
||||
@ -3,7 +3,9 @@ from collections import defaultdict
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||
@ -393,19 +395,13 @@ class ProviderManager:
|
||||
|
||||
@staticmethod
|
||||
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
|
||||
"""
|
||||
Get all provider records of the workspace.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()
|
||||
|
||||
provider_name_to_provider_records_dict = defaultdict(list)
|
||||
for provider in providers:
|
||||
# TODO: Use provider name with prefix after the data migration
|
||||
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
|
||||
providers = session.scalars(stmt)
|
||||
for provider in providers:
|
||||
# Use provider name with prefix after the data migration
|
||||
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
|
||||
return provider_name_to_provider_records_dict
|
||||
|
||||
@staticmethod
|
||||
@ -416,17 +412,12 @@ class ProviderManager:
|
||||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
# Get all provider model records of the workspace
|
||||
provider_models = (
|
||||
db.session.query(ProviderModel)
|
||||
.filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
|
||||
.all()
|
||||
)
|
||||
|
||||
provider_name_to_provider_model_records_dict = defaultdict(list)
|
||||
for provider_model in provider_models:
|
||||
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
|
||||
provider_models = session.scalars(stmt)
|
||||
for provider_model in provider_models:
|
||||
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
|
||||
return provider_name_to_provider_model_records_dict
|
||||
|
||||
@staticmethod
|
||||
@ -437,17 +428,14 @@ class ProviderManager:
|
||||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
preferred_provider_types = (
|
||||
db.session.query(TenantPreferredModelProvider)
|
||||
.filter(TenantPreferredModelProvider.tenant_id == tenant_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
provider_name_to_preferred_provider_type_records_dict = {
|
||||
preferred_provider_type.provider_name: preferred_provider_type
|
||||
for preferred_provider_type in preferred_provider_types
|
||||
}
|
||||
|
||||
provider_name_to_preferred_provider_type_records_dict = {}
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
|
||||
preferred_provider_types = session.scalars(stmt)
|
||||
provider_name_to_preferred_provider_type_records_dict = {
|
||||
preferred_provider_type.provider_name: preferred_provider_type
|
||||
for preferred_provider_type in preferred_provider_types
|
||||
}
|
||||
return provider_name_to_preferred_provider_type_records_dict
|
||||
|
||||
@staticmethod
|
||||
@ -458,18 +446,14 @@ class ProviderManager:
|
||||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
provider_model_settings = (
|
||||
db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
provider_name_to_provider_model_settings_dict = defaultdict(list)
|
||||
for provider_model_setting in provider_model_settings:
|
||||
(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
|
||||
provider_model_settings = session.scalars(stmt)
|
||||
for provider_model_setting in provider_model_settings:
|
||||
provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append(
|
||||
provider_model_setting
|
||||
)
|
||||
)
|
||||
|
||||
return provider_name_to_provider_model_settings_dict
|
||||
|
||||
@staticmethod
|
||||
@ -492,15 +476,14 @@ class ProviderManager:
|
||||
if not model_load_balancing_enabled:
|
||||
return {}
|
||||
|
||||
provider_load_balancing_configs = (
|
||||
db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
|
||||
for provider_load_balancing_config in provider_load_balancing_configs:
|
||||
provider_name_to_provider_load_balancing_model_configs_dict[
|
||||
provider_load_balancing_config.provider_name
|
||||
].append(provider_load_balancing_config)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
|
||||
provider_load_balancing_configs = session.scalars(stmt)
|
||||
for provider_load_balancing_config in provider_load_balancing_configs:
|
||||
provider_name_to_provider_load_balancing_model_configs_dict[
|
||||
provider_load_balancing_config.provider_name
|
||||
].append(provider_load_balancing_config)
|
||||
|
||||
return provider_name_to_provider_load_balancing_model_configs_dict
|
||||
|
||||
@ -626,10 +609,9 @@ class ProviderManager:
|
||||
if not cached_provider_credentials:
|
||||
try:
|
||||
# fix origin data
|
||||
if (
|
||||
custom_provider_record.encrypted_config
|
||||
and not custom_provider_record.encrypted_config.startswith("{")
|
||||
):
|
||||
if custom_provider_record.encrypted_config is None:
|
||||
raise ValueError("No credentials found")
|
||||
if not custom_provider_record.encrypted_config.startswith("{"):
|
||||
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
|
||||
else:
|
||||
provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
||||
@ -733,7 +715,7 @@ class ProviderManager:
|
||||
return SystemConfiguration(enabled=False)
|
||||
|
||||
# Convert provider_records to dict
|
||||
quota_type_to_provider_records_dict = {}
|
||||
quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
|
||||
for provider_record in provider_records:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM.value:
|
||||
continue
|
||||
@ -758,6 +740,11 @@ class ProviderManager:
|
||||
else:
|
||||
provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type]
|
||||
|
||||
if provider_record.quota_used is None:
|
||||
raise ValueError("quota_used is None")
|
||||
if provider_record.quota_limit is None:
|
||||
raise ValueError("quota_limit is None")
|
||||
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
@ -791,10 +778,9 @@ class ProviderManager:
|
||||
cached_provider_credentials = provider_credentials_cache.get()
|
||||
|
||||
if not cached_provider_credentials:
|
||||
try:
|
||||
provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
provider_credentials = {}
|
||||
provider_credentials: dict[str, Any] = {}
|
||||
if provider_records and provider_records[0].encrypted_config:
|
||||
provider_credentials = json.loads(provider_records[0].encrypted_config)
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
|
||||
@ -720,7 +720,7 @@ STOPWORDS = {
|
||||
"〉",
|
||||
"〈",
|
||||
"…",
|
||||
" ",
|
||||
" ",
|
||||
"0",
|
||||
"1",
|
||||
"2",
|
||||
@ -731,16 +731,6 @@ STOPWORDS = {
|
||||
"7",
|
||||
"8",
|
||||
"9",
|
||||
"0",
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
"4",
|
||||
"5",
|
||||
"6",
|
||||
"7",
|
||||
"8",
|
||||
"9",
|
||||
"二",
|
||||
"三",
|
||||
"四",
|
||||
|
||||
@ -85,7 +85,6 @@ class BaiduVector(BaseVector):
|
||||
end = min(start + batch_size, total_count)
|
||||
rows = []
|
||||
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
|
||||
# FIXME do you need this assert?
|
||||
for i in range(start, end, 1):
|
||||
row = Row(
|
||||
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
|
||||
|
||||
@ -142,7 +142,7 @@ class ElasticSearchVector(BaseVector):
|
||||
if score > score_threshold:
|
||||
if doc.metadata is not None:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
@ -97,6 +97,10 @@ class MilvusVector(BaseVector):
|
||||
|
||||
try:
|
||||
milvus_version = self._client.get_server_version()
|
||||
# Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
|
||||
if "Zilliz Cloud" in milvus_version:
|
||||
return True
|
||||
# For standard Milvus installations, check version number
|
||||
return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
|
||||
|
||||
@ -184,7 +184,16 @@ class OpenSearchVector(BaseVector):
|
||||
}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
query["query"] = {"terms": {"metadata.document_id": document_ids_filter}}
|
||||
query["query"] = {
|
||||
"script_score": {
|
||||
"query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID.value: document_ids_filter}}]}},
|
||||
"script": {
|
||||
"source": "knn_score",
|
||||
"lang": "knn",
|
||||
"params": {"field": Field.VECTOR.value, "query_value": query_vector, "space_type": "l2"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
response = self._client.search(index=self._collection_name.lower(), body=query)
|
||||
@ -209,10 +218,10 @@ class OpenSearchVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}
|
||||
full_text_query = {"query": {"bool": {"must": [{"match": {Field.CONTENT_KEY.value: query}}]}}}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
full_text_query["query"]["terms"] = {"metadata.document_id": document_ids_filter}
|
||||
full_text_query["query"]["bool"]["filter"] = [{"terms": {"metadata.document_id": document_ids_filter}}]
|
||||
|
||||
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
|
||||
|
||||
@ -255,7 +264,8 @@ class OpenSearchVector(BaseVector):
|
||||
Field.METADATA_KEY.value: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
|
||||
"doc_id": {"type": "keyword"}, # Map doc_id to keyword type
|
||||
"document_id": {"type": "keyword"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -261,7 +261,7 @@ class OracleVector(BaseVector):
|
||||
words = pseg.cut(query)
|
||||
current_entity = ""
|
||||
for word, pos in words:
|
||||
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名
|
||||
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名
|
||||
current_entity += word
|
||||
else:
|
||||
if current_entity:
|
||||
@ -303,7 +303,6 @@ class OracleVector(BaseVector):
|
||||
return docs
|
||||
else:
|
||||
return [Document(page_content="", metadata={})]
|
||||
return []
|
||||
|
||||
def delete(self) -> None:
|
||||
with self._get_connection() as conn:
|
||||
|
||||
@ -245,4 +245,4 @@ class TidbService:
|
||||
return cluster_infos
|
||||
else:
|
||||
response.raise_for_status()
|
||||
return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception
|
||||
return []
|
||||
|
||||
@ -139,4 +139,4 @@ class CacheEmbedding(Embeddings):
|
||||
logging.exception(f"Failed to add embedding to redis for the text '{text[:10]}...({len(text)} chars)'")
|
||||
raise ex
|
||||
|
||||
return embedding_results
|
||||
return embedding_results # type: ignore
|
||||
|
||||
23
api/core/rag/entities/citation_metadata.py
Normal file
23
api/core/rag/entities/citation_metadata.py
Normal file
@ -0,0 +1,23 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RetrievalSourceMetadata(BaseModel):
|
||||
position: Optional[int] = None
|
||||
dataset_id: Optional[str] = None
|
||||
dataset_name: Optional[str] = None
|
||||
document_id: Optional[str] = None
|
||||
document_name: Optional[str] = None
|
||||
data_source_type: Optional[str] = None
|
||||
segment_id: Optional[str] = None
|
||||
retriever_from: Optional[str] = None
|
||||
score: Optional[float] = None
|
||||
hit_count: Optional[int] = None
|
||||
word_count: Optional[int] = None
|
||||
segment_position: Optional[int] = None
|
||||
index_node_hash: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
page: Optional[int] = None
|
||||
doc_metadata: Optional[dict[str, Any]] = None
|
||||
title: Optional[str] = None
|
||||
@ -27,6 +27,8 @@ class WebsiteInfo(BaseModel):
|
||||
website import info.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
provider: str
|
||||
job_id: str
|
||||
url: str
|
||||
@ -34,12 +36,6 @@ class WebsiteInfo(BaseModel):
|
||||
tenant_id: str
|
||||
only_main_content: bool = False
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
|
||||
class ExtractSetting(BaseModel):
|
||||
"""
|
||||
|
||||
@ -104,7 +104,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
|
||||
# check file type
|
||||
if not file.filename.endswith(".csv"):
|
||||
if not file.filename or not file.filename.endswith(".csv"):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
|
||||
try:
|
||||
|
||||
@ -45,13 +45,12 @@ class BaseDocumentTransformer(ABC):
|
||||
.. code-block:: python
|
||||
|
||||
class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
embeddings: Embeddings
|
||||
similarity_fn: Callable = cosine_similarity
|
||||
similarity_threshold: float = 0.95
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def transform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
|
||||
@ -35,6 +35,7 @@ from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
@ -198,21 +199,21 @@ class DatasetRetrieval:
|
||||
|
||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||
document_context_list = []
|
||||
retrieval_resource_list = []
|
||||
document_context_list: list[DocumentContext] = []
|
||||
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
||||
# deal with external documents
|
||||
for item in external_documents:
|
||||
document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score")))
|
||||
source = {
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": invoke_from.to_source(),
|
||||
"score": item.metadata.get("score"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=item.metadata.get("dataset_id"),
|
||||
dataset_name=item.metadata.get("dataset_name"),
|
||||
document_id=item.metadata.get("document_id") or item.metadata.get("title"),
|
||||
document_name=item.metadata.get("title"),
|
||||
data_source_type="external",
|
||||
retriever_from=invoke_from.to_source(),
|
||||
score=item.metadata.get("score"),
|
||||
content=item.page_content,
|
||||
)
|
||||
retrieval_resource_list.append(source)
|
||||
# deal with dify documents
|
||||
if dify_documents:
|
||||
@ -248,32 +249,32 @@ class DatasetRetrieval:
|
||||
.first()
|
||||
)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": invoke_from.to_source(),
|
||||
"score": record.score or 0.0,
|
||||
"doc_metadata": document.doc_metadata,
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
document_id=document.id,
|
||||
document_name=document.name,
|
||||
data_source_type=document.data_source_type,
|
||||
segment_id=segment.id,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
score=record.score or 0.0,
|
||||
doc_metadata=document.doc_metadata,
|
||||
)
|
||||
|
||||
if invoke_from.to_source() == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
source.hit_count = segment.hit_count
|
||||
source.word_count = segment.word_count
|
||||
source.segment_position = segment.position
|
||||
source.index_node_hash = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
source.content = segment.content
|
||||
retrieval_resource_list.append(source)
|
||||
if hit_callback and retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True)
|
||||
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1):
|
||||
item["position"] = position
|
||||
item.position = position
|
||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||
if document_context_list:
|
||||
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||
@ -936,6 +937,9 @@ class DatasetRetrieval:
|
||||
return metadata_filter_document_ids, metadata_condition
|
||||
|
||||
def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
|
||||
if not inputs:
|
||||
return text
|
||||
|
||||
def replacer(match):
|
||||
key = match.group(1)
|
||||
return str(inputs.get(key, f"{{{{{key}}}}}"))
|
||||
|
||||
@ -9,7 +9,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.rag.retrieval.output_parser.react_output import ReactAction
|
||||
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from core.workflow.nodes.llm import LLMNode
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
|
||||
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
|
||||
|
||||
@ -165,7 +165,7 @@ class ReactMultiDatasetRouter:
|
||||
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
# deduct quota
|
||||
LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
return text, usage
|
||||
|
||||
|
||||
@ -10,12 +10,12 @@ from sqlalchemy import select
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_execution_entities import (
|
||||
from core.workflow.entities.workflow_execution import (
|
||||
WorkflowExecution,
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
@ -104,10 +104,9 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
status = WorkflowExecutionStatus(db_model.status)
|
||||
|
||||
return WorkflowExecution(
|
||||
id=db_model.id,
|
||||
id_=db_model.id,
|
||||
workflow_id=db_model.workflow_id,
|
||||
sequence_number=db_model.sequence_number,
|
||||
type=WorkflowType(db_model.type),
|
||||
workflow_type=WorkflowType(db_model.type),
|
||||
workflow_version=db_model.version,
|
||||
graph=graph,
|
||||
inputs=inputs,
|
||||
@ -140,14 +139,29 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
raise ValueError("created_by_role is required in repository constructor")
|
||||
|
||||
db_model = WorkflowRun()
|
||||
db_model.id = domain_model.id
|
||||
db_model.id = domain_model.id_
|
||||
db_model.tenant_id = self._tenant_id
|
||||
if self._app_id is not None:
|
||||
db_model.app_id = self._app_id
|
||||
db_model.workflow_id = domain_model.workflow_id
|
||||
db_model.triggered_from = self._triggered_from
|
||||
db_model.sequence_number = domain_model.sequence_number
|
||||
db_model.type = domain_model.type
|
||||
|
||||
# Check if this is a new record
|
||||
with self._session_factory() as session:
|
||||
existing = session.scalar(select(WorkflowRun).where(WorkflowRun.id == domain_model.id_))
|
||||
if not existing:
|
||||
# For new records, get the next sequence number
|
||||
stmt = select(WorkflowRun.sequence_number).where(
|
||||
WorkflowRun.app_id == self._app_id,
|
||||
WorkflowRun.tenant_id == self._tenant_id,
|
||||
)
|
||||
max_sequence = session.scalar(stmt.order_by(WorkflowRun.sequence_number.desc()))
|
||||
db_model.sequence_number = (max_sequence or 0) + 1
|
||||
else:
|
||||
# For updates, keep the existing sequence number
|
||||
db_model.sequence_number = existing.sequence_number
|
||||
|
||||
db_model.type = domain_model.workflow_type
|
||||
db_model.version = domain_model.workflow_version
|
||||
db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
|
||||
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
|
||||
|
||||
@ -12,19 +12,18 @@ from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.node_execution_entities import (
|
||||
NodeExecution,
|
||||
NodeExecutionStatus,
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
EndUser,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
)
|
||||
|
||||
@ -87,9 +86,9 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
|
||||
# Initialize in-memory cache for node executions
|
||||
# Key: node_execution_id, Value: WorkflowNodeExecution (DB model)
|
||||
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
|
||||
self._node_execution_cache: dict[str, WorkflowNodeExecutionModel] = {}
|
||||
|
||||
def _to_domain_model(self, db_model: WorkflowNodeExecution) -> NodeExecution:
|
||||
def _to_domain_model(self, db_model: WorkflowNodeExecutionModel) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Convert a database model to a domain model.
|
||||
|
||||
@ -103,16 +102,16 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
inputs = db_model.inputs_dict
|
||||
process_data = db_model.process_data_dict
|
||||
outputs = db_model.outputs_dict
|
||||
metadata = {NodeRunMetadataKey(k): v for k, v in db_model.execution_metadata_dict.items()}
|
||||
metadata = {WorkflowNodeExecutionMetadataKey(k): v for k, v in db_model.execution_metadata_dict.items()}
|
||||
|
||||
# Convert status to domain enum
|
||||
status = NodeExecutionStatus(db_model.status)
|
||||
status = WorkflowNodeExecutionStatus(db_model.status)
|
||||
|
||||
return NodeExecution(
|
||||
return WorkflowNodeExecution(
|
||||
id=db_model.id,
|
||||
node_execution_id=db_model.node_execution_id,
|
||||
workflow_id=db_model.workflow_id,
|
||||
workflow_run_id=db_model.workflow_run_id,
|
||||
workflow_execution_id=db_model.workflow_run_id,
|
||||
index=db_model.index,
|
||||
predecessor_node_id=db_model.predecessor_node_id,
|
||||
node_id=db_model.node_id,
|
||||
@ -129,7 +128,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
finished_at=db_model.finished_at,
|
||||
)
|
||||
|
||||
def to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution:
|
||||
def to_db_model(self, domain_model: WorkflowNodeExecution) -> WorkflowNodeExecutionModel:
|
||||
"""
|
||||
Convert a domain model to a database model.
|
||||
|
||||
@ -147,14 +146,14 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
if not self._creator_user_role:
|
||||
raise ValueError("created_by_role is required in repository constructor")
|
||||
|
||||
db_model = WorkflowNodeExecution()
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = domain_model.id
|
||||
db_model.tenant_id = self._tenant_id
|
||||
if self._app_id is not None:
|
||||
db_model.app_id = self._app_id
|
||||
db_model.workflow_id = domain_model.workflow_id
|
||||
db_model.triggered_from = self._triggered_from
|
||||
db_model.workflow_run_id = domain_model.workflow_run_id
|
||||
db_model.workflow_run_id = domain_model.workflow_execution_id
|
||||
db_model.index = domain_model.index
|
||||
db_model.predecessor_node_id = domain_model.predecessor_node_id
|
||||
db_model.node_execution_id = domain_model.node_execution_id
|
||||
@ -176,7 +175,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
db_model.finished_at = domain_model.finished_at
|
||||
return db_model
|
||||
|
||||
def save(self, execution: NodeExecution) -> None:
|
||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Save or update a NodeExecution domain entity to the database.
|
||||
|
||||
@ -208,7 +207,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
|
||||
self._node_execution_cache[db_model.node_execution_id] = db_model
|
||||
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve a NodeExecution by its node_execution_id.
|
||||
|
||||
@ -231,13 +230,13 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
# If not in cache, query the database
|
||||
logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.node_execution_id == node_execution_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.node_execution_id == node_execution_id,
|
||||
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
|
||||
|
||||
db_model = session.scalar(stmt)
|
||||
if db_model:
|
||||
@ -253,7 +252,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
) -> Sequence[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Retrieve all WorkflowNodeExecution database models for a specific workflow run.
|
||||
|
||||
@ -271,20 +270,20 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
A list of WorkflowNodeExecution database models
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
|
||||
|
||||
# Apply ordering if provided
|
||||
if order_config and order_config.order_by:
|
||||
order_columns: list[UnaryExpression] = []
|
||||
for field in order_config.order_by:
|
||||
column = getattr(WorkflowNodeExecution, field, None)
|
||||
column = getattr(WorkflowNodeExecutionModel, field, None)
|
||||
if not column:
|
||||
continue
|
||||
if order_config.order_direction == "desc":
|
||||
@ -308,7 +307,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
) -> Sequence[NodeExecution]:
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all NodeExecution instances for a specific workflow run.
|
||||
|
||||
@ -335,7 +334,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
|
||||
return domain_models
|
||||
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all running NodeExecution instances for a specific workflow run.
|
||||
|
||||
@ -349,15 +348,15 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
A list of running NodeExecution instances
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING,
|
||||
WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
|
||||
|
||||
db_models = session.scalars(stmt).all()
|
||||
domain_models = []
|
||||
@ -382,10 +381,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
It also clears the in-memory cache.
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
|
||||
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
|
||||
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
- audio
|
||||
- code
|
||||
- time
|
||||
- qrcode
|
||||
- webscraper
|
||||
|
||||
@ -168,7 +168,7 @@ class ApiTool(Tool):
|
||||
cookies[parameter["name"]] = value
|
||||
|
||||
elif parameter["in"] == "header":
|
||||
headers[parameter["name"]] = value
|
||||
headers[parameter["name"]] = str(value)
|
||||
|
||||
# check if there is a request body and handle it
|
||||
if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None:
|
||||
|
||||
@ -279,7 +279,6 @@ class ToolParameter(PluginParameter):
|
||||
:param options: the options of the parameter
|
||||
"""
|
||||
# convert options to ToolParameterOption
|
||||
# FIXME fix the type error
|
||||
if options:
|
||||
option_objs = [
|
||||
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
|
||||
@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.models.document import Document as RagDocument
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@ -107,7 +108,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
else:
|
||||
document_context_list.append(segment.get_sign_content())
|
||||
if self.return_resource:
|
||||
context_list = []
|
||||
context_list: list[RetrievalSourceMetadata] = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
@ -121,28 +122,28 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
.first()
|
||||
)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"position": resource_number,
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": document_score_list.get(segment.index_node_id, None),
|
||||
"doc_metadata": document.doc_metadata,
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
position=resource_number,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
document_id=document.id,
|
||||
document_name=document.name,
|
||||
data_source_type=document.data_source_type,
|
||||
segment_id=segment.id,
|
||||
retriever_from=self.retriever_from,
|
||||
score=document_score_list.get(segment.index_node_id, None),
|
||||
doc_metadata=document.doc_metadata,
|
||||
)
|
||||
|
||||
if self.retriever_from == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
source.hit_count = segment.hit_count
|
||||
source.word_count = segment.word_count
|
||||
source.segment_position = segment.position
|
||||
source.index_node_hash = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
source.content = segment.content
|
||||
context_list.append(source)
|
||||
resource_number += 1
|
||||
|
||||
@ -152,8 +153,6 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
return str("\n".join(document_context_list))
|
||||
return ""
|
||||
|
||||
raise RuntimeError("not segments found")
|
||||
|
||||
def _retriever(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
|
||||
@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
@ -14,7 +15,7 @@ from models.dataset import Dataset
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
@ -79,7 +80,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
else:
|
||||
document_ids_filter = None
|
||||
if dataset.provider == "external":
|
||||
results = []
|
||||
results: list[RetrievalDocument] = []
|
||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
@ -100,21 +101,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
# deal with external documents
|
||||
context_list = []
|
||||
context_list: list[RetrievalSourceMetadata] = []
|
||||
for position, item in enumerate(results, start=1):
|
||||
if item.metadata is not None:
|
||||
source = {
|
||||
"position": position,
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": item.metadata.get("score"),
|
||||
"title": item.metadata.get("title"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
position=position,
|
||||
dataset_id=item.metadata.get("dataset_id"),
|
||||
dataset_name=item.metadata.get("dataset_name"),
|
||||
document_id=item.metadata.get("document_id") or item.metadata.get("title"),
|
||||
document_name=item.metadata.get("title"),
|
||||
data_source_type="external",
|
||||
retriever_from=self.retriever_from,
|
||||
score=item.metadata.get("score"),
|
||||
title=item.metadata.get("title"),
|
||||
content=item.page_content,
|
||||
)
|
||||
context_list.append(source)
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
@ -125,7 +126,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
return ""
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
||||
retrieval_resource_list = []
|
||||
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
@ -163,7 +164,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
for item in documents:
|
||||
if item.metadata is not None and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
document_context_list = []
|
||||
document_context_list: list[DocumentContext] = []
|
||||
records = RetrievalService.format_retrieval_documents(documents)
|
||||
if records:
|
||||
for record in records:
|
||||
@ -197,37 +198,37 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
.first()
|
||||
)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id, # type: ignore
|
||||
"document_name": document.name, # type: ignore
|
||||
"data_source_type": document.data_source_type, # type: ignore
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": record.score or 0.0,
|
||||
"doc_metadata": document.doc_metadata, # type: ignore
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
document_id=document.id, # type: ignore
|
||||
document_name=document.name, # type: ignore
|
||||
data_source_type=document.data_source_type, # type: ignore
|
||||
segment_id=segment.id,
|
||||
retriever_from=self.retriever_from,
|
||||
score=record.score or 0.0,
|
||||
doc_metadata=document.doc_metadata, # type: ignore
|
||||
)
|
||||
|
||||
if self.retriever_from == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
source.hit_count = segment.hit_count
|
||||
source.word_count = segment.word_count
|
||||
source.segment_position = segment.position
|
||||
source.index_node_hash = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
source.content = segment.content
|
||||
retrieval_resource_list.append(source)
|
||||
|
||||
if self.return_resource and retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(
|
||||
retrieval_resource_list,
|
||||
key=lambda x: x.get("score") or 0.0,
|
||||
key=lambda x: x.score or 0.0,
|
||||
reverse=True,
|
||||
)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
|
||||
item["position"] = position # type: ignore
|
||||
item.position = position # type: ignore
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||
if document_context_list:
|
||||
|
||||
@ -32,14 +32,14 @@ class ToolFileMessageTransformer:
|
||||
try:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
tool_file_manager = ToolFileManager()
|
||||
file = tool_file_manager.create_file_by_url(
|
||||
tool_file = tool_file_manager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
file_url=message.message.text,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}"
|
||||
url = f"/files/tools/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}"
|
||||
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
@ -66,10 +66,9 @@ class ToolFileMessageTransformer:
|
||||
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
|
||||
raise ValueError("unexpected message type")
|
||||
|
||||
# FIXME: should do a type check here.
|
||||
assert isinstance(message.message.blob, bytes)
|
||||
tool_file_manager = ToolFileManager()
|
||||
file = tool_file_manager.create_file_by_raw(
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
@ -78,7 +77,7 @@ class ToolFileMessageTransformer:
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
|
||||
url = cls.get_tool_file_url(tool_file_id=tool_file.id, extension=guess_extension(tool_file.mimetype))
|
||||
|
||||
# check if file is image
|
||||
if "image" in mimetype:
|
||||
|
||||
@ -55,6 +55,13 @@ class ApiBasedToolSchemaParser:
|
||||
# convert parameters
|
||||
parameters = []
|
||||
if "parameters" in interface["operation"]:
|
||||
for i, parameter in enumerate(interface["operation"]["parameters"]):
|
||||
if "$ref" in parameter:
|
||||
root = openapi
|
||||
reference = parameter["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
root = root[ref]
|
||||
interface["operation"]["parameters"][i] = root
|
||||
for parameter in interface["operation"]["parameters"]:
|
||||
tool_parameter = ToolParameter(
|
||||
name=parameter["name"],
|
||||
|
||||
@ -1,36 +1,10 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class NodeRunMetadataKey(StrEnum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
TOTAL_PRICE = "total_price"
|
||||
CURRENCY = "currency"
|
||||
TOOL_INFO = "tool_info"
|
||||
AGENT_LOG = "agent_log"
|
||||
ITERATION_ID = "iteration_id"
|
||||
ITERATION_INDEX = "iteration_index"
|
||||
LOOP_ID = "loop_id"
|
||||
LOOP_INDEX = "loop_index"
|
||||
PARALLEL_ID = "parallel_id"
|
||||
PARALLEL_START_NODE_ID = "parallel_start_node_id"
|
||||
PARENT_PARALLEL_ID = "parent_parallel_id"
|
||||
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
|
||||
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
|
||||
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
|
||||
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
@ -43,7 +17,7 @@ class NodeRunResult(BaseModel):
|
||||
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
||||
process_data: Optional[Mapping[str, Any]] = None # process data
|
||||
outputs: Optional[Mapping[str, Any]] = None # node outputs
|
||||
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata
|
||||
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # node metadata
|
||||
llm_usage: Optional[LLMUsage] = None # llm usage
|
||||
|
||||
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||
|
||||
@ -36,12 +36,10 @@ class WorkflowExecution(BaseModel):
|
||||
user, tenant, and app attributes.
|
||||
"""
|
||||
|
||||
id: str = Field(...)
|
||||
id_: str = Field(...)
|
||||
workflow_id: str = Field(...)
|
||||
workflow_version: str = Field(...)
|
||||
sequence_number: int = Field(...)
|
||||
|
||||
type: WorkflowType = Field(...)
|
||||
workflow_type: WorkflowType = Field(...)
|
||||
graph: Mapping[str, Any] = Field(...)
|
||||
|
||||
inputs: Mapping[str, Any] = Field(...)
|
||||
@ -69,20 +67,18 @@ class WorkflowExecution(BaseModel):
|
||||
def new(
|
||||
cls,
|
||||
*,
|
||||
id: str,
|
||||
id_: str,
|
||||
workflow_id: str,
|
||||
sequence_number: int,
|
||||
type: WorkflowType,
|
||||
workflow_type: WorkflowType,
|
||||
workflow_version: str,
|
||||
graph: Mapping[str, Any],
|
||||
inputs: Mapping[str, Any],
|
||||
started_at: datetime,
|
||||
) -> "WorkflowExecution":
|
||||
return WorkflowExecution(
|
||||
id=id,
|
||||
id_=id_,
|
||||
workflow_id=workflow_id,
|
||||
sequence_number=sequence_number,
|
||||
type=type,
|
||||
workflow_type=workflow_type,
|
||||
workflow_version=workflow_version,
|
||||
graph=graph,
|
||||
inputs=inputs,
|
||||
@ -13,11 +13,35 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
class NodeExecutionStatus(StrEnum):
|
||||
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
TOTAL_PRICE = "total_price"
|
||||
CURRENCY = "currency"
|
||||
TOOL_INFO = "tool_info"
|
||||
AGENT_LOG = "agent_log"
|
||||
ITERATION_ID = "iteration_id"
|
||||
ITERATION_INDEX = "iteration_index"
|
||||
LOOP_ID = "loop_id"
|
||||
LOOP_INDEX = "loop_index"
|
||||
PARALLEL_ID = "parallel_id"
|
||||
PARALLEL_START_NODE_ID = "parallel_start_node_id"
|
||||
PARENT_PARALLEL_ID = "parent_parallel_id"
|
||||
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
|
||||
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
|
||||
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
|
||||
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
|
||||
|
||||
class WorkflowNodeExecutionStatus(StrEnum):
|
||||
"""
|
||||
Node Execution Status Enum.
|
||||
"""
|
||||
@ -29,7 +53,7 @@ class NodeExecutionStatus(StrEnum):
|
||||
RETRY = "retry"
|
||||
|
||||
|
||||
class NodeExecution(BaseModel):
|
||||
class WorkflowNodeExecution(BaseModel):
|
||||
"""
|
||||
Domain model for workflow node execution.
|
||||
|
||||
@ -46,7 +70,7 @@ class NodeExecution(BaseModel):
|
||||
id: str # Unique identifier for this execution record
|
||||
node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing
|
||||
workflow_id: str # ID of the workflow this node belongs to
|
||||
workflow_run_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
|
||||
workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
|
||||
|
||||
# Execution positioning and flow
|
||||
index: int # Sequence number for ordering in trace visualization
|
||||
@ -61,12 +85,12 @@ class NodeExecution(BaseModel):
|
||||
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
|
||||
|
||||
# Execution state
|
||||
status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status
|
||||
error: Optional[str] = None # Error message if execution failed
|
||||
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
|
||||
|
||||
# Additional metadata
|
||||
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
|
||||
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
|
||||
|
||||
# Timing information
|
||||
created_at: datetime # When execution started
|
||||
@ -77,7 +101,7 @@ class NodeExecution(BaseModel):
|
||||
inputs: Optional[Mapping[str, Any]] = None,
|
||||
process_data: Optional[Mapping[str, Any]] = None,
|
||||
outputs: Optional[Mapping[str, Any]] = None,
|
||||
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None,
|
||||
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update the model from mappings.
|
||||
@ -13,4 +13,4 @@ class SystemVariableKey(StrEnum):
|
||||
DIALOGUE_COUNT = "dialogue_count"
|
||||
APP_ID = "app_id"
|
||||
WORKFLOW_ID = "workflow_id"
|
||||
WORKFLOW_RUN_ID = "workflow_run_id"
|
||||
WORKFLOW_EXECUTION_ID = "workflow_run_id"
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes import NodeType
|
||||
@ -82,7 +83,7 @@ class NodeRunStreamChunkEvent(BaseNodeEvent):
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
|
||||
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user