mirror of
https://github.com/langgenius/dify.git
synced 2026-02-04 10:47:49 +08:00
Compare commits
1 Commits
0.8.0-beta
...
fix/toolti
| Author | SHA1 | Date | |
|---|---|---|---|
| 2474dbdff0 |
1
.github/workflows/build-push.yml
vendored
1
.github/workflows/build-push.yml
vendored
@ -125,6 +125,7 @@ jobs:
|
||||
with:
|
||||
images: ${{ env[matrix.image_name_env] }}
|
||||
tags: |
|
||||
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
|
||||
type=ref,event=branch
|
||||
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
|
||||
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
|
||||
|
||||
@ -1,54 +0,0 @@
|
||||
name: Check i18n Files and Create PR
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [closed]
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
check-and-update:
|
||||
if: github.event.pull_request.merged == true
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: web
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 2 # last 2 commits
|
||||
|
||||
- name: Check for file changes in i18n/en-US
|
||||
id: check_files
|
||||
run: |
|
||||
recent_commit_sha=$(git rev-parse HEAD)
|
||||
second_recent_commit_sha=$(git rev-parse HEAD~1)
|
||||
changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
|
||||
echo "Changed files: $changed_files"
|
||||
if [ -n "$changed_files" ]; then
|
||||
echo "FILES_CHANGED=true" >> $GITHUB_ENV
|
||||
else
|
||||
echo "FILES_CHANGED=false" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Set up Node.js
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
uses: actions/setup-node@v2
|
||||
with:
|
||||
node-version: 'lts/*'
|
||||
|
||||
- name: Install dependencies
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
run: yarn install --frozen-lockfile
|
||||
|
||||
- name: Run npm script
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
run: npm run auto-gen-i18n
|
||||
|
||||
- name: Create Pull Request
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
uses: peter-evans/create-pull-request@v6
|
||||
with:
|
||||
commit-message: Update i18n files based on en-US changes
|
||||
title: 'chore: translate i18n files'
|
||||
body: This PR was automatically created to update i18n files based on changes in en-US locale.
|
||||
branch: chore/automated-i18n-updates
|
||||
@ -8,7 +8,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
|
||||
|
||||
## Before you jump in
|
||||
|
||||
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
|
||||
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
|
||||
|
||||
### Feature requests:
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@
|
||||
|
||||
## 在开始之前
|
||||
|
||||
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:open)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
|
||||
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
|
||||
|
||||
### 功能请求:
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ Dify にコントリビュートしたいとお考えなのですね。それは
|
||||
|
||||
## 飛び込む前に
|
||||
|
||||
[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:open) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。
|
||||
[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。
|
||||
|
||||
### 機能リクエスト
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ Về vấn đề cấp phép, xin vui lòng dành chút thời gian đọc qua [
|
||||
|
||||
## Trước khi bắt đầu
|
||||
|
||||
[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:open) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại:
|
||||
[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại:
|
||||
|
||||
### Yêu cầu tính năng:
|
||||
|
||||
|
||||
@ -60,8 +60,7 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key
|
||||
ALIYUN_OSS_ENDPOINT=your-endpoint
|
||||
ALIYUN_OSS_AUTH_VERSION=v1
|
||||
ALIYUN_OSS_REGION=your-region
|
||||
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
||||
ALIYUN_OSS_PATH=your-path
|
||||
|
||||
# Google Storage configuration
|
||||
GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
|
||||
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
|
||||
|
||||
@ -55,7 +55,7 @@ RUN apt-get update \
|
||||
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
|
||||
&& apt-get update \
|
||||
# For Security
|
||||
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-2 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
|
||||
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
|
||||
&& apt-get autoremove -y \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
@ -559,9 +559,8 @@ def add_qdrant_doc_id_index(field: str):
|
||||
|
||||
@click.command("create-tenant", help="Create account and tenant.")
|
||||
@click.option("--email", prompt=True, help="The email address of the tenant account.")
|
||||
@click.option("--name", prompt=True, help="The workspace name of the tenant account.")
|
||||
@click.option("--language", prompt=True, help="Account language, default: en-US.")
|
||||
def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None):
|
||||
def create_tenant(email: str, language: Optional[str] = None):
|
||||
"""
|
||||
Create tenant account
|
||||
"""
|
||||
@ -581,15 +580,13 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
|
||||
if language not in languages:
|
||||
language = "en-US"
|
||||
|
||||
name = name.strip()
|
||||
|
||||
# generate random password
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
|
||||
# register account
|
||||
account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name)
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Optional
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
|
||||
from pydantic_settings import BaseSettings
|
||||
@ -217,17 +217,20 @@ class HttpConfig(BaseSettings):
|
||||
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
|
||||
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
|
||||
PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request")
|
||||
] = 10
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: NonNegativeInt = Field(
|
||||
description="",
|
||||
default=300,
|
||||
)
|
||||
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
|
||||
PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request")
|
||||
] = 60
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT: NonNegativeInt = Field(
|
||||
description="",
|
||||
default=600,
|
||||
)
|
||||
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
|
||||
PositiveInt, Field(ge=10, description="read timeout in seconds for HTTP request")
|
||||
] = 20
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT: NonNegativeInt = Field(
|
||||
description="",
|
||||
default=600,
|
||||
)
|
||||
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
|
||||
description="",
|
||||
|
||||
@ -38,8 +38,3 @@ class AliyunOSSStorageConfig(BaseSettings):
|
||||
description="Aliyun OSS authentication version",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_PATH: Optional[str] = Field(
|
||||
description="Aliyun OSS path",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="0.8.0-beta1",
|
||||
default="0.7.2",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
||||
@ -174,7 +174,6 @@ class AppApi(Resource):
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("max_active_requests", type=int, location="json")
|
||||
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
@ -34,7 +34,6 @@ def parse_app_site_args():
|
||||
)
|
||||
parser.add_argument("prompt_public", type=bool, required=False, location="json")
|
||||
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
||||
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -69,7 +68,6 @@ class AppSite(Resource):
|
||||
"customize_token_strategy",
|
||||
"prompt_public",
|
||||
"show_workflow_steps",
|
||||
"use_icon_as_answer_icon",
|
||||
]:
|
||||
value = args.get(attr_name)
|
||||
if value is not None:
|
||||
|
||||
@ -122,7 +122,6 @@ class DatasetListApi(Resource):
|
||||
name=args["name"],
|
||||
indexing_technique=args["indexing_technique"],
|
||||
account=current_user,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
@ -39,7 +39,7 @@ class FileApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(file_fields)
|
||||
@cloud_edition_billing_resource_check("documents")
|
||||
@cloud_edition_billing_resource_check(resource="documents")
|
||||
def post(self):
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
|
||||
@ -35,7 +35,6 @@ class InstalledAppsListApi(Resource):
|
||||
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
|
||||
}
|
||||
for installed_app in installed_apps
|
||||
if installed_app.app is not None
|
||||
]
|
||||
installed_apps.sort(
|
||||
key=lambda app: (
|
||||
|
||||
@ -46,7 +46,9 @@ def only_edition_self_hosted(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str):
|
||||
def cloud_edition_billing_resource_check(
|
||||
resource: str, error_msg: str = "You have reached the limit of your subscription."
|
||||
):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
@ -58,22 +60,22 @@ def cloud_edition_billing_resource_check(resource: str):
|
||||
documents_upload_quota = features.documents_upload_quota
|
||||
annotation_quota_limit = features.annotation_quota_limit
|
||||
if resource == "members" and 0 < members.limit <= members.size:
|
||||
abort(403, "The number of members has reached the limit of your subscription.")
|
||||
abort(403, error_msg)
|
||||
elif resource == "apps" and 0 < apps.limit <= apps.size:
|
||||
abort(403, "The number of apps has reached the limit of your subscription.")
|
||||
abort(403, error_msg)
|
||||
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||
abort(403, "The capacity of the vector space has reached the limit of your subscription.")
|
||||
abort(403, error_msg)
|
||||
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
# The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
|
||||
source = request.args.get("source")
|
||||
if source == "datasets":
|
||||
abort(403, "The number of documents has reached the limit of your subscription.")
|
||||
abort(403, error_msg)
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
elif resource == "workspace_custom" and not features.can_replace_logo:
|
||||
abort(403, "The workspace custom feature has reached the limit of your subscription.")
|
||||
abort(403, error_msg)
|
||||
elif resource == "annotation" and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
|
||||
abort(403, "The annotation quota has reached the limit of your subscription.")
|
||||
abort(403, error_msg)
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -84,7 +86,10 @@ def cloud_edition_billing_resource_check(resource: str):
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
def cloud_edition_billing_knowledge_limit_check(
|
||||
resource: str,
|
||||
error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
|
||||
):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
@ -92,10 +97,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
abort(
|
||||
403,
|
||||
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
|
||||
)
|
||||
abort(403, error_msg)
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
||||
@ -36,10 +36,6 @@ class SegmentApi(DatasetApiResource):
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if document.indexing_status != "completed":
|
||||
raise NotFound("Document is already completed.")
|
||||
if not document.enabled:
|
||||
raise NotFound("Document is disabled.")
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
try:
|
||||
|
||||
@ -83,7 +83,9 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
def cloud_edition_billing_resource_check(
|
||||
resource: str, api_token_type: str, error_msg: str = "You have reached the limit of your subscription."
|
||||
):
|
||||
def interceptor(view):
|
||||
def decorated(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
@ -96,13 +98,13 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
documents_upload_quota = features.documents_upload_quota
|
||||
|
||||
if resource == "members" and 0 < members.limit <= members.size:
|
||||
raise Forbidden("The number of members has reached the limit of your subscription.")
|
||||
raise Forbidden(error_msg)
|
||||
elif resource == "apps" and 0 < apps.limit <= apps.size:
|
||||
raise Forbidden("The number of apps has reached the limit of your subscription.")
|
||||
raise Forbidden(error_msg)
|
||||
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||
raise Forbidden("The capacity of the vector space has reached the limit of your subscription.")
|
||||
raise Forbidden(error_msg)
|
||||
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
raise Forbidden("The number of documents has reached the limit of your subscription.")
|
||||
raise Forbidden(error_msg)
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -113,7 +115,11 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
|
||||
def cloud_edition_billing_knowledge_limit_check(
|
||||
resource: str,
|
||||
api_token_type: str,
|
||||
error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
|
||||
):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
@ -122,9 +128,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
raise Forbidden(
|
||||
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
|
||||
)
|
||||
raise Forbidden(error_msg)
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
||||
@ -39,7 +39,6 @@ class AppSiteApi(WebApiResource):
|
||||
"default_language": fields.String,
|
||||
"prompt_public": fields.Boolean,
|
||||
"show_workflow_steps": fields.Boolean,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
}
|
||||
|
||||
app_fields = {
|
||||
|
||||
@ -93,7 +93,7 @@ class DatasetConfigManager:
|
||||
reranking_model=dataset_configs.get('reranking_model'),
|
||||
weights=dataset_configs.get('weights'),
|
||||
reranking_enabled=dataset_configs.get('reranking_enabled', True),
|
||||
rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'),
|
||||
rerank_mode=dataset_configs.get('rerank_mode', 'reranking_model'),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -4,10 +4,12 @@ import os
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
from typing import Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import contexts
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
@ -18,49 +20,33 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import Workflow
|
||||
from models.workflow import ConversationVariable, Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[True] = True,
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
) -> dict: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
stream: bool = True,
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -148,8 +134,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
node_id: str,
|
||||
user: Account,
|
||||
args: dict,
|
||||
stream: bool = True) \
|
||||
-> dict[str, Any] | Generator[str, Any, None]:
|
||||
stream: bool = True):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -166,6 +151,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
if args.get('inputs') is None:
|
||||
raise ValueError('inputs is required')
|
||||
|
||||
extras = {
|
||||
"auto_generate_conversation_name": False
|
||||
}
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
conversation_id = args.get('conversation_id')
|
||||
if conversation_id:
|
||||
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
|
||||
|
||||
# convert to app config
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
@ -176,16 +171,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
conversation_id=None,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs={},
|
||||
query='',
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={
|
||||
"auto_generate_conversation_name": False
|
||||
},
|
||||
extras=extras,
|
||||
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id,
|
||||
inputs=args['inputs']
|
||||
@ -198,28 +191,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=None,
|
||||
conversation=conversation,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
def _generate(self, *,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
conversation: Optional[Conversation] = None,
|
||||
stream: bool = True) \
|
||||
-> dict[str, Any] | Generator[str, Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param invoke_from: invoke from source
|
||||
:param application_generate_entity: application generate entity
|
||||
:param conversation: conversation
|
||||
:param stream: is stream
|
||||
"""
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
conversation: Conversation | None = None,
|
||||
stream: bool = True):
|
||||
is_first_conversation = False
|
||||
if not conversation:
|
||||
is_first_conversation = True
|
||||
@ -234,7 +216,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
# update conversation features
|
||||
conversation.override_model_configs = workflow.features
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
# db.session.refresh(conversation)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
@ -246,12 +228,67 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
message_id=message.id
|
||||
)
|
||||
|
||||
# Init conversation variables
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
conversation_variables = session.scalars(stmt).all()
|
||||
if not conversation_variables:
|
||||
# Create conversation variables if they don't exist.
|
||||
conversation_variables = [
|
||||
ConversationVariable.from_variable(
|
||||
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
|
||||
)
|
||||
for variable in workflow.conversation_variables
|
||||
]
|
||||
session.add_all(conversation_variables)
|
||||
# Convert database entities to variables.
|
||||
conversation_variables = [item.to_variable() for item in conversation_variables]
|
||||
|
||||
session.commit()
|
||||
|
||||
# Increment dialogue count.
|
||||
conversation.dialogue_count += 1
|
||||
|
||||
conversation_id = conversation.id
|
||||
conversation_dialogue_count = conversation.dialogue_count
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
query = application_generate_entity.query
|
||||
files = application_generate_entity.files
|
||||
|
||||
user_id = None
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = application_generate_entity.user_id
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariableKey.QUERY: query,
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.CONVERSATION_ID: conversation_id,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
|
||||
}
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
contexts.workflow_variable_pool.set(variable_pool)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(), # type: ignore
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager,
|
||||
'conversation_id': conversation.id,
|
||||
'message_id': message.id,
|
||||
'context': contextvars.copy_context(),
|
||||
})
|
||||
@ -277,7 +314,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
context: contextvars.Context) -> None:
|
||||
"""
|
||||
@ -293,19 +329,28 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
var.set(val)
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
runner = AdvancedChatAppRunner()
|
||||
if application_generate_entity.single_iteration_run:
|
||||
single_iteration_run = application_generate_entity.single_iteration_run
|
||||
runner.single_iteration_run(
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
workflow_id=application_generate_entity.app_config.workflow_id,
|
||||
queue_manager=queue_manager,
|
||||
inputs=single_iteration_run.inputs,
|
||||
node_id=single_iteration_run.node_id,
|
||||
user_id=application_generate_entity.user_id
|
||||
)
|
||||
else:
|
||||
# get message
|
||||
message = self._get_message(message_id)
|
||||
|
||||
# chatbot app
|
||||
runner = AdvancedChatAppRunner(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message
|
||||
)
|
||||
|
||||
runner.run()
|
||||
# chatbot app
|
||||
runner = AdvancedChatAppRunner()
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
message=message
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
|
||||
@ -1,67 +1,49 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
|
||||
from core.moderation.base import ModerationException
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import ConversationVariable, WorkflowType
|
||||
from models import App, Message, Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
class AdvancedChatAppRunner(AppRunner):
|
||||
"""
|
||||
AdvancedChat Application Runner
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message
|
||||
def run(
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
message: Message,
|
||||
) -> None:
|
||||
"""
|
||||
Run application
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:param conversation: conversation
|
||||
:param message: message
|
||||
"""
|
||||
super().__init__(queue_manager)
|
||||
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.conversation = conversation
|
||||
self.message = message
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Run application
|
||||
:return:
|
||||
"""
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
@ -72,133 +54,101 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
if not workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
user_id = None
|
||||
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = self.application_generate_entity.user_id
|
||||
inputs = application_generate_entity.inputs
|
||||
query = application_generate_entity.query
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
queue_manager=queue_manager,
|
||||
app_record=app_record,
|
||||
app_generate_entity=application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message.id,
|
||||
):
|
||||
return
|
||||
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
query = self.application_generate_entity.query
|
||||
files = self.application_generate_entity.files
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
app_record=app_record,
|
||||
app_generate_entity=self.application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=self.message.id
|
||||
):
|
||||
return
|
||||
|
||||
# annotation reply
|
||||
if self.handle_annotation_reply(
|
||||
app_record=app_record,
|
||||
message=self.message,
|
||||
query=query,
|
||||
app_generate_entity=self.application_generate_entity
|
||||
):
|
||||
return
|
||||
|
||||
# Init conversation variables
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
conversation_variables = session.scalars(stmt).all()
|
||||
if not conversation_variables:
|
||||
# Create conversation variables if they don't exist.
|
||||
conversation_variables = [
|
||||
ConversationVariable.from_variable(
|
||||
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
|
||||
)
|
||||
for variable in workflow.conversation_variables
|
||||
]
|
||||
session.add_all(conversation_variables)
|
||||
# Convert database entities to variables.
|
||||
conversation_variables = [item.to_variable() for item in conversation_variables]
|
||||
|
||||
session.commit()
|
||||
|
||||
# Increment dialogue count.
|
||||
self.conversation.dialogue_count += 1
|
||||
|
||||
conversation_dialogue_count = self.conversation.dialogue_count
|
||||
db.session.commit()
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariableKey.QUERY: query,
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
|
||||
}
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=workflow.graph_dict)
|
||||
# annotation reply
|
||||
if self.handle_annotation_reply(
|
||||
app_record=app_record,
|
||||
message=message,
|
||||
query=query,
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
):
|
||||
return
|
||||
|
||||
db.session.close()
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = [
|
||||
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
|
||||
]
|
||||
|
||||
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType.value_of(workflow.type),
|
||||
graph=graph,
|
||||
graph_config=workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=(
|
||||
UserFrom.ACCOUNT
|
||||
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
else UserFrom.END_USER
|
||||
),
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run(
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.run_workflow(
|
||||
workflow=workflow,
|
||||
user_id=application_generate_entity.user_id,
|
||||
user_from=UserFrom.ACCOUNT
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
else UserFrom.END_USER,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
callbacks=workflow_callbacks,
|
||||
call_depth=application_generate_entity.call_depth,
|
||||
)
|
||||
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
def single_iteration_run(
|
||||
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Single iteration run
|
||||
"""
|
||||
app_record = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError('App not found')
|
||||
|
||||
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
|
||||
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.single_step_run_iteration_workflow_node(
|
||||
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
|
||||
)
|
||||
|
||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
def handle_input_moderation(
|
||||
self,
|
||||
app_record: App,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
message_id: str
|
||||
self,
|
||||
queue_manager: AppQueueManager,
|
||||
app_record: App,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
message_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Handle input moderation
|
||||
:param queue_manager: application queue manager
|
||||
:param app_record: app record
|
||||
:param app_generate_entity: application generate entity
|
||||
:param inputs: inputs
|
||||
@ -217,23 +167,30 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
message_id=message_id,
|
||||
)
|
||||
except ModerationException as e:
|
||||
self._complete_with_stream_output(
|
||||
self._stream_output(
|
||||
queue_manager=queue_manager,
|
||||
text=str(e),
|
||||
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
|
||||
stream=app_generate_entity.stream,
|
||||
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def handle_annotation_reply(self, app_record: App,
|
||||
message: Message,
|
||||
query: str,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
|
||||
def handle_annotation_reply(
|
||||
self,
|
||||
app_record: App,
|
||||
message: Message,
|
||||
query: str,
|
||||
queue_manager: AppQueueManager,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
) -> bool:
|
||||
"""
|
||||
Handle annotation reply
|
||||
:param app_record: app record
|
||||
:param message: message
|
||||
:param query: query
|
||||
:param queue_manager: application queue manager
|
||||
:param app_generate_entity: application generate entity
|
||||
"""
|
||||
# annotation reply
|
||||
@ -246,32 +203,37 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
if annotation_reply:
|
||||
self._publish_event(
|
||||
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)
|
||||
queue_manager.publish(
|
||||
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
self._complete_with_stream_output(
|
||||
self._stream_output(
|
||||
queue_manager=queue_manager,
|
||||
text=annotation_reply.content,
|
||||
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
|
||||
stream=app_generate_entity.stream,
|
||||
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _complete_with_stream_output(self,
|
||||
text: str,
|
||||
stopped_by: QueueStopEvent.StopBy) -> None:
|
||||
def _stream_output(
|
||||
self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
|
||||
) -> None:
|
||||
"""
|
||||
Direct output
|
||||
:param queue_manager: application queue manager
|
||||
:param text: text
|
||||
:param stream: stream
|
||||
:return:
|
||||
"""
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
text=text
|
||||
)
|
||||
)
|
||||
if stream:
|
||||
index = 0
|
||||
for token in text:
|
||||
queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
|
||||
index += 1
|
||||
time.sleep(0.01)
|
||||
else:
|
||||
queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
self._publish_event(
|
||||
QueueStopEvent(stopped_by=stopped_by)
|
||||
)
|
||||
queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
@ -2,8 +2,9 @@ import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
import contexts
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@ -21,9 +22,6 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
@ -33,28 +31,34 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatTaskState,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ChatflowStreamGenerateRoute,
|
||||
ErrorStreamResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
StreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||
from core.file.file_obj import FileVar
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import Conversation, EndUser, Message
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
@ -65,15 +69,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
"""
|
||||
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
_task_state: WorkflowTaskState
|
||||
_task_state: AdvancedChatTaskState
|
||||
_application_generate_entity: AdvancedChatAppGenerateEntity
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
# Deprecated
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_iteration_nested_relations: dict[str, list[str]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
@ -101,6 +106,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
self._workflow = workflow
|
||||
self._conversation = conversation
|
||||
self._message = message
|
||||
# Deprecated
|
||||
self._workflow_system_variables = {
|
||||
SystemVariableKey.QUERY: message.query,
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
@ -108,8 +114,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._task_state = AdvancedChatTaskState(
|
||||
usage=LLMUsage.empty_usage()
|
||||
)
|
||||
|
||||
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
|
||||
self._stream_generate_routes = self._get_stream_generate_routes()
|
||||
self._conversation_name_generate_thread = None
|
||||
|
||||
def process(self):
|
||||
@ -130,7 +140,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
generator = self._wrapper_process_stream_response(
|
||||
trace_manager=self._application_generate_entity.trace_manager
|
||||
)
|
||||
|
||||
if self._stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
@ -190,18 +199,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
|
||||
Generator[StreamResponse, None, None]:
|
||||
|
||||
tts_publisher = None
|
||||
publisher = None
|
||||
task_id = self._application_generate_entity.task_id
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
features_dict = self._workflow.features_dict
|
||||
|
||||
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
|
||||
'text_to_speech'].get('autoPlay') == 'enabled':
|
||||
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
|
||||
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
@ -212,9 +220,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
# timeout
|
||||
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
||||
try:
|
||||
if not tts_publisher:
|
||||
if not publisher:
|
||||
break
|
||||
audio_trunk = tts_publisher.checkAndGetAudio()
|
||||
audio_trunk = publisher.checkAndGetAudio()
|
||||
if audio_trunk is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
@ -232,34 +240,34 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
def _process_stream_response(
|
||||
self,
|
||||
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
|
||||
publisher: AppGeneratorTTSPublisher,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
Process stream response.
|
||||
:return:
|
||||
"""
|
||||
# init fake graph runtime state
|
||||
graph_runtime_state = None
|
||||
workflow_run = None
|
||||
for message in self._queue_manager.listen():
|
||||
if (message.event
|
||||
and getattr(message.event, 'metadata', None)
|
||||
and message.event.metadata.get('is_answer_previous_node', False)
|
||||
and publisher):
|
||||
publisher.publish(message=message)
|
||||
elif (hasattr(message.event, 'execution_metadata')
|
||||
and message.event.execution_metadata
|
||||
and message.event.execution_metadata.get('is_answer_previous_node', False)
|
||||
and publisher):
|
||||
publisher.publish(message=message)
|
||||
event = message.event
|
||||
|
||||
for queue_message in self._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
if isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
elif isinstance(event, QueueErrorEvent):
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
err = self._handle_error(event, self._message)
|
||||
yield self._error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||
# override graph runtime state
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
workflow_run = self._handle_workflow_start()
|
||||
|
||||
# init workflow run
|
||||
workflow_run = self._handle_workflow_run_start()
|
||||
|
||||
self._refetch_message()
|
||||
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
|
||||
self._message.workflow_run_id = workflow_run.id
|
||||
|
||||
db.session.commit()
|
||||
@ -271,242 +279,133 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
workflow_run=workflow_run
|
||||
)
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
workflow_node_execution = self._handle_node_start(event)
|
||||
|
||||
workflow_node_execution = self._handle_node_execution_start(
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
# search stream_generate_routes if node id is answer start at node
|
||||
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
|
||||
self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
|
||||
# reset current route position to 0
|
||||
self._task_state.current_stream_generate_state.current_route_position = 0
|
||||
|
||||
response = self._workflow_node_start_to_stream_response(
|
||||
# generate stream outputs when node started
|
||||
yield from self._generate_stream_outputs_when_node_started()
|
||||
|
||||
yield self._workflow_node_start_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
)
|
||||
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
|
||||
workflow_node_execution = self._handle_node_finished(event)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeSucceededEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_success(event)
|
||||
# stream outputs when node finished
|
||||
generator = self._generate_stream_outputs_when_node_finished()
|
||||
if generator:
|
||||
yield from generator
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
yield self._workflow_node_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeFailedEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_parallel_branch_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_parallel_branch_finished_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationStartEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_iteration_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_iteration_next_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowSucceededEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=json.dumps(event.outputs) if event.outputs else None,
|
||||
conversation_id=self._conversation.id,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
)
|
||||
|
||||
self._queue_manager.publish(
|
||||
QueueAdvancedChatMessageEndEvent(),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.FAILED,
|
||||
error=event.error,
|
||||
conversation_id=self._conversation.id,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
)
|
||||
|
||||
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
|
||||
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
|
||||
break
|
||||
elif isinstance(event, QueueStopEvent):
|
||||
if workflow_run and graph_runtime_state:
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.STOPPED,
|
||||
error=event.get_stop_reason(),
|
||||
conversation_id=self._conversation.id,
|
||||
trace_manager=trace_manager,
|
||||
if isinstance(event, QueueNodeFailedEvent):
|
||||
yield from self._handle_iteration_exception(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
error=f'Child node failed: {event.error}'
|
||||
)
|
||||
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
|
||||
if isinstance(event, QueueIterationNextEvent):
|
||||
# clear ran node execution infos of current iteration
|
||||
iteration_relations = self._iteration_nested_relations.get(event.node_id)
|
||||
if iteration_relations:
|
||||
for node_id in iteration_relations:
|
||||
self._task_state.ran_node_execution_infos.pop(node_id, None)
|
||||
|
||||
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
|
||||
self._handle_iteration_operation(event)
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
|
||||
workflow_run = self._handle_workflow_finished(
|
||||
event, conversation_id=self._conversation.id, trace_manager=trace_manager
|
||||
)
|
||||
if workflow_run:
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
)
|
||||
|
||||
if workflow_run.status == WorkflowRunStatus.FAILED.value:
|
||||
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
|
||||
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
|
||||
break
|
||||
|
||||
if isinstance(event, QueueStopEvent):
|
||||
# Save message
|
||||
self._save_message()
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
break
|
||||
else:
|
||||
self._queue_manager.publish(
|
||||
QueueAdvancedChatMessageEndEvent(),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
|
||||
if output_moderation_answer:
|
||||
self._task_state.answer = output_moderation_answer
|
||||
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
|
||||
|
||||
# Save message
|
||||
self._save_message(graph_runtime_state=graph_runtime_state)
|
||||
self._save_message()
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
break
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
self._handle_retriever_resources(event)
|
||||
|
||||
self._refetch_message()
|
||||
|
||||
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
|
||||
if self._task_state.metadata else None
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(self._message)
|
||||
db.session.close()
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
self._handle_annotation_reply(event)
|
||||
|
||||
self._refetch_message()
|
||||
|
||||
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
|
||||
if self._task_state.metadata else None
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(self._message)
|
||||
db.session.close()
|
||||
elif isinstance(event, QueueTextChunkEvent):
|
||||
delta_text = event.text
|
||||
if delta_text is None:
|
||||
continue
|
||||
|
||||
if not self._is_stream_out_support(
|
||||
event=event
|
||||
):
|
||||
continue
|
||||
|
||||
# handle output moderation chunk
|
||||
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
|
||||
if should_direct_answer:
|
||||
continue
|
||||
|
||||
# only publish tts message at text chunk streaming
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(message=queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._message_to_stream_response(delta_text, self._message.id)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
# published by moderation
|
||||
yield self._message_replace_to_stream_response(answer=event.text)
|
||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
|
||||
if output_moderation_answer:
|
||||
self._task_state.answer = output_moderation_answer
|
||||
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
|
||||
|
||||
# Save message
|
||||
self._save_message(graph_runtime_state=graph_runtime_state)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
elif isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
else:
|
||||
continue
|
||||
|
||||
# publish None when task finished
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(None)
|
||||
|
||||
if publisher:
|
||||
publisher.publish(None)
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
|
||||
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
||||
def _save_message(self) -> None:
|
||||
"""
|
||||
Save message.
|
||||
:return:
|
||||
"""
|
||||
self._refetch_message()
|
||||
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
|
||||
|
||||
self._message.answer = self._task_state.answer
|
||||
self._message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
|
||||
if self._task_state.metadata else None
|
||||
|
||||
if graph_runtime_state and graph_runtime_state.llm_usage:
|
||||
usage = graph_runtime_state.llm_usage
|
||||
if self._task_state.metadata and self._task_state.metadata.get('usage'):
|
||||
usage = LLMUsage(**self._task_state.metadata['usage'])
|
||||
|
||||
self._message.message_tokens = usage.prompt_tokens
|
||||
self._message.message_unit_price = usage.prompt_unit_price
|
||||
self._message.message_price_unit = usage.prompt_price_unit
|
||||
@ -533,10 +432,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
"""
|
||||
extras = {}
|
||||
if self._task_state.metadata:
|
||||
extras['metadata'] = self._task_state.metadata.copy()
|
||||
|
||||
if 'annotation_reply' in extras['metadata']:
|
||||
del extras['metadata']['annotation_reply']
|
||||
extras['metadata'] = self._task_state.metadata
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@ -544,6 +440,323 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
**extras
|
||||
)
|
||||
|
||||
def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
|
||||
"""
|
||||
Get stream generate routes.
|
||||
:return:
|
||||
"""
|
||||
# find all answer nodes
|
||||
graph = self._workflow.graph_dict
|
||||
answer_node_configs = [
|
||||
node for node in graph['nodes']
|
||||
if node.get('data', {}).get('type') == NodeType.ANSWER.value
|
||||
]
|
||||
|
||||
# parse stream output node value selectors of answer nodes
|
||||
stream_generate_routes = {}
|
||||
for node_config in answer_node_configs:
|
||||
# get generate route for stream output
|
||||
answer_node_id = node_config['id']
|
||||
generate_route = AnswerNode.extract_generate_route_selectors(node_config)
|
||||
start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id)
|
||||
if not start_node_ids:
|
||||
continue
|
||||
|
||||
for start_node_id in start_node_ids:
|
||||
stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
|
||||
answer_node_id=answer_node_id,
|
||||
generate_route=generate_route
|
||||
)
|
||||
|
||||
return stream_generate_routes
|
||||
|
||||
def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
|
||||
-> list[str]:
|
||||
"""
|
||||
Get answer start at node id.
|
||||
:param graph: graph
|
||||
:param target_node_id: target node ID
|
||||
:return:
|
||||
"""
|
||||
nodes = graph.get('nodes')
|
||||
edges = graph.get('edges')
|
||||
|
||||
# fetch all ingoing edges from source node
|
||||
ingoing_edges = []
|
||||
for edge in edges:
|
||||
if edge.get('target') == target_node_id:
|
||||
ingoing_edges.append(edge)
|
||||
|
||||
if not ingoing_edges:
|
||||
# check if it's the first node in the iteration
|
||||
target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
|
||||
if not target_node:
|
||||
return []
|
||||
|
||||
node_iteration_id = target_node.get('data', {}).get('iteration_id')
|
||||
# get iteration start node id
|
||||
for node in nodes:
|
||||
if node.get('id') == node_iteration_id:
|
||||
if node.get('data', {}).get('start_node_id') == target_node_id:
|
||||
return [target_node_id]
|
||||
|
||||
return []
|
||||
|
||||
start_node_ids = []
|
||||
for ingoing_edge in ingoing_edges:
|
||||
source_node_id = ingoing_edge.get('source')
|
||||
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
|
||||
if not source_node:
|
||||
continue
|
||||
|
||||
node_type = source_node.get('data', {}).get('type')
|
||||
node_iteration_id = source_node.get('data', {}).get('iteration_id')
|
||||
iteration_start_node_id = None
|
||||
if node_iteration_id:
|
||||
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
|
||||
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
|
||||
|
||||
if node_type in [
|
||||
NodeType.ANSWER.value,
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER.value,
|
||||
NodeType.ITERATION.value,
|
||||
NodeType.LOOP.value
|
||||
]:
|
||||
start_node_id = target_node_id
|
||||
start_node_ids.append(start_node_id)
|
||||
elif node_type == NodeType.START.value or \
|
||||
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
|
||||
start_node_id = source_node_id
|
||||
start_node_ids.append(start_node_id)
|
||||
else:
|
||||
sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id)
|
||||
if sub_start_node_ids:
|
||||
start_node_ids.extend(sub_start_node_ids)
|
||||
|
||||
return start_node_ids
|
||||
|
||||
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get iteration nested relations.
|
||||
:param graph: graph
|
||||
:return:
|
||||
"""
|
||||
nodes = graph.get('nodes')
|
||||
|
||||
iteration_ids = [node.get('id') for node in nodes
|
||||
if node.get('data', {}).get('type') in [
|
||||
NodeType.ITERATION.value,
|
||||
NodeType.LOOP.value,
|
||||
]]
|
||||
|
||||
return {
|
||||
iteration_id: [
|
||||
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
|
||||
] for iteration_id in iteration_ids
|
||||
}
|
||||
|
||||
def _generate_stream_outputs_when_node_started(self) -> Generator:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
:return:
|
||||
"""
|
||||
if self._task_state.current_stream_generate_state:
|
||||
route_chunks = self._task_state.current_stream_generate_state.generate_route[
|
||||
self._task_state.current_stream_generate_state.current_route_position:
|
||||
]
|
||||
|
||||
for route_chunk in route_chunks:
|
||||
if route_chunk.type == 'text':
|
||||
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
|
||||
|
||||
# handle output moderation chunk
|
||||
should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
|
||||
if should_direct_answer:
|
||||
continue
|
||||
|
||||
self._task_state.answer += route_chunk.text
|
||||
yield self._message_to_stream_response(route_chunk.text, self._message.id)
|
||||
else:
|
||||
break
|
||||
|
||||
self._task_state.current_stream_generate_state.current_route_position += 1
|
||||
|
||||
# all route chunks are generated
|
||||
if self._task_state.current_stream_generate_state.current_route_position == len(
|
||||
self._task_state.current_stream_generate_state.generate_route
|
||||
):
|
||||
self._task_state.current_stream_generate_state = None
|
||||
|
||||
def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
:return:
|
||||
"""
|
||||
if not self._task_state.current_stream_generate_state:
|
||||
return
|
||||
|
||||
route_chunks = self._task_state.current_stream_generate_state.generate_route[
|
||||
self._task_state.current_stream_generate_state.current_route_position:]
|
||||
|
||||
for route_chunk in route_chunks:
|
||||
if route_chunk.type == 'text':
|
||||
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
|
||||
self._task_state.answer += route_chunk.text
|
||||
yield self._message_to_stream_response(route_chunk.text, self._message.id)
|
||||
else:
|
||||
value = None
|
||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||
value_selector = route_chunk.value_selector
|
||||
if not value_selector:
|
||||
self._task_state.current_stream_generate_state.current_route_position += 1
|
||||
continue
|
||||
|
||||
route_chunk_node_id = value_selector[0]
|
||||
|
||||
if route_chunk_node_id == 'sys':
|
||||
# system variable
|
||||
value = contexts.workflow_variable_pool.get().get(value_selector)
|
||||
if value:
|
||||
value = value.text
|
||||
elif route_chunk_node_id in self._iteration_nested_relations:
|
||||
# it's a iteration variable
|
||||
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
|
||||
continue
|
||||
iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
|
||||
iterator = iteration_state.inputs
|
||||
if not iterator:
|
||||
continue
|
||||
iterator_selector = iterator.get('iterator_selector', [])
|
||||
if value_selector[1] == 'index':
|
||||
value = iteration_state.current_index
|
||||
elif value_selector[1] == 'item':
|
||||
value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
|
||||
iterator_selector
|
||||
) else None
|
||||
else:
|
||||
# check chunk node id is before current node id or equal to current node id
|
||||
if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
|
||||
break
|
||||
|
||||
latest_node_execution_info = self._task_state.latest_node_execution_info
|
||||
|
||||
# get route chunk node execution info
|
||||
route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
|
||||
if (route_chunk_node_execution_info.node_type == NodeType.LLM
|
||||
and latest_node_execution_info.node_type == NodeType.LLM):
|
||||
# only LLM support chunk stream output
|
||||
self._task_state.current_stream_generate_state.current_route_position += 1
|
||||
continue
|
||||
|
||||
# get route chunk node execution
|
||||
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id
|
||||
).first()
|
||||
|
||||
outputs = route_chunk_node_execution.outputs_dict
|
||||
|
||||
# get value from outputs
|
||||
value = None
|
||||
for key in value_selector[1:]:
|
||||
if not value:
|
||||
value = outputs.get(key) if outputs else None
|
||||
else:
|
||||
value = value.get(key)
|
||||
|
||||
if value is not None:
|
||||
text = ''
|
||||
if isinstance(value, str | int | float):
|
||||
text = str(value)
|
||||
elif isinstance(value, FileVar):
|
||||
# convert file to markdown
|
||||
text = value.to_markdown()
|
||||
elif isinstance(value, dict):
|
||||
# handle files
|
||||
file_vars = self._fetch_files_from_variable_value(value)
|
||||
if file_vars:
|
||||
file_var = file_vars[0]
|
||||
try:
|
||||
file_var_obj = FileVar(**file_var)
|
||||
|
||||
# convert file to markdown
|
||||
text = file_var_obj.to_markdown()
|
||||
except Exception as e:
|
||||
logger.error(f'Error creating file var: {e}')
|
||||
|
||||
if not text:
|
||||
# other types
|
||||
text = json.dumps(value, ensure_ascii=False)
|
||||
elif isinstance(value, list):
|
||||
# handle files
|
||||
file_vars = self._fetch_files_from_variable_value(value)
|
||||
for file_var in file_vars:
|
||||
try:
|
||||
file_var_obj = FileVar(**file_var)
|
||||
except Exception as e:
|
||||
logger.error(f'Error creating file var: {e}')
|
||||
continue
|
||||
|
||||
# convert file to markdown
|
||||
text = file_var_obj.to_markdown() + ' '
|
||||
|
||||
text = text.strip()
|
||||
|
||||
if not text and value:
|
||||
# other types
|
||||
text = json.dumps(value, ensure_ascii=False)
|
||||
|
||||
if text:
|
||||
self._task_state.answer += text
|
||||
yield self._message_to_stream_response(text, self._message.id)
|
||||
|
||||
self._task_state.current_stream_generate_state.current_route_position += 1
|
||||
|
||||
# all route chunks are generated
|
||||
if self._task_state.current_stream_generate_state.current_route_position == len(
|
||||
self._task_state.current_stream_generate_state.generate_route
|
||||
):
|
||||
self._task_state.current_stream_generate_state = None
|
||||
|
||||
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
|
||||
"""
|
||||
Is stream out support
|
||||
:param event: queue text chunk event
|
||||
:return:
|
||||
"""
|
||||
if not event.metadata:
|
||||
return True
|
||||
|
||||
if 'node_id' not in event.metadata:
|
||||
return True
|
||||
|
||||
node_type = event.metadata.get('node_type')
|
||||
stream_output_value_selector = event.metadata.get('value_selector')
|
||||
if not stream_output_value_selector:
|
||||
return False
|
||||
|
||||
if not self._task_state.current_stream_generate_state:
|
||||
return False
|
||||
|
||||
route_chunk = self._task_state.current_stream_generate_state.generate_route[
|
||||
self._task_state.current_stream_generate_state.current_route_position]
|
||||
|
||||
if route_chunk.type != 'var':
|
||||
return False
|
||||
|
||||
if node_type != NodeType.LLM:
|
||||
# only LLM support chunk stream output
|
||||
return False
|
||||
|
||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||
value_selector = route_chunk.value_selector
|
||||
|
||||
# check chunk node id is before current node id or equal to current node id
|
||||
if value_selector != stream_output_value_selector:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
||||
"""
|
||||
Handle output moderation chunk.
|
||||
@ -569,12 +782,3 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
self._output_moderation_handler.append_new_token(text)
|
||||
|
||||
return False
|
||||
|
||||
def _refetch_message(self) -> None:
|
||||
"""
|
||||
Refetch message.
|
||||
:return:
|
||||
"""
|
||||
message = db.session.query(Message).filter(Message.id == self._message.id).first()
|
||||
if message:
|
||||
self._message = message
|
||||
|
||||
@ -0,0 +1,203 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowEventTriggerCallback(WorkflowCallback):
|
||||
|
||||
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||
self._queue_manager = queue_manager
|
||||
|
||||
def on_workflow_run_started(self) -> None:
|
||||
"""
|
||||
Workflow run started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowStartedEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_succeeded(self) -> None:
|
||||
"""
|
||||
Workflow run succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowSucceededEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_failed(self, error: str) -> None:
|
||||
"""
|
||||
Workflow run failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowFailedEvent(
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_started(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeStartedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
node_run_index=node_run_index,
|
||||
predecessor_node_id=predecessor_node_id
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_succeeded(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeSucceededEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
execution_metadata=execution_metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_failed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
error: str,
|
||||
inputs: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeFailedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
process_data=process_data,
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=text,
|
||||
metadata={
|
||||
"node_id": node_id,
|
||||
**metadata
|
||||
}
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_started(self,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: dict = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationStartEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
metadata=metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_next(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
index: int,
|
||||
node_run_index: int,
|
||||
output: Optional[Any]) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
self._queue_manager._publish(
|
||||
QueueIterationNextEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
index=index,
|
||||
node_run_index=node_run_index,
|
||||
output=output
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_completed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int,
|
||||
outputs: dict) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
self._queue_manager._publish(
|
||||
QueueIterationCompletedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
outputs=outputs
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
event,
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
@ -3,7 +3,7 @@ import os
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, Union, overload
|
||||
from typing import Any, Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@ -28,24 +28,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[True] = True,
|
||||
) -> Generator[dict, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
) -> dict: ...
|
||||
|
||||
def generate(self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
|
||||
@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
def convert(cls, response: Union[
|
||||
AppBlockingResponse,
|
||||
Generator[AppStreamResponse, Any, None]
|
||||
], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
], invoke_from: InvokeFrom):
|
||||
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_full_response(response)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@ -347,7 +347,7 @@ class AppRunner:
|
||||
self, app_id: str,
|
||||
tenant_id: str,
|
||||
app_generate_entity: AppGenerateEntity,
|
||||
inputs: Mapping[str, Any],
|
||||
inputs: dict,
|
||||
query: str,
|
||||
message_id: str,
|
||||
) -> tuple[bool, dict, str]:
|
||||
|
||||
@ -3,7 +3,7 @@ import os
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, Union, overload
|
||||
from typing import Any, Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@ -28,31 +28,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[True] = True,
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
) -> dict: ...
|
||||
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> Union[dict, Generator[str, None, None]]:
|
||||
) -> Union[dict, Generator[dict, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import os
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, Union, overload
|
||||
from typing import Any, Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@ -30,30 +30,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[True] = True,
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
) -> dict: ...
|
||||
|
||||
def generate(self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator[str, None, None]]:
|
||||
-> Union[dict, Generator[dict, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -221,7 +203,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator[str, None, None]]:
|
||||
-> Union[dict, Generator[dict, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ import os
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
from typing import Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@ -32,40 +32,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerator(BaseAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[True] = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
) -> dict: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
@ -77,7 +51,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
:param call_depth: call depth
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
inputs = args['inputs']
|
||||
|
||||
@ -125,19 +98,16 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
stream=stream,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id
|
||||
)
|
||||
|
||||
def _generate(
|
||||
self, *,
|
||||
app_model: App,
|
||||
self, app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
) -> dict[str, Any] | Generator[str, None, None]:
|
||||
) -> Union[dict, Generator[dict, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -147,7 +117,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
:param application_generate_entity: application generate entity
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
@ -159,11 +128,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(), # type: ignore
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager,
|
||||
'context': contextvars.copy_context(),
|
||||
'workflow_thread_pool_id': workflow_thread_pool_id
|
||||
'context': contextvars.copy_context()
|
||||
})
|
||||
|
||||
worker_thread.start()
|
||||
@ -187,7 +155,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
node_id: str,
|
||||
user: Account,
|
||||
args: dict,
|
||||
stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
stream: bool = True):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -204,6 +172,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
if args.get('inputs') is None:
|
||||
raise ValueError('inputs is required')
|
||||
|
||||
extras = {
|
||||
"auto_generate_conversation_name": False
|
||||
}
|
||||
|
||||
# convert to app config
|
||||
app_config = WorkflowAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
@ -219,9 +191,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={
|
||||
"auto_generate_conversation_name": False
|
||||
},
|
||||
extras=extras,
|
||||
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id,
|
||||
inputs=args['inputs']
|
||||
@ -241,14 +211,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context,
|
||||
workflow_thread_pool_id: Optional[str] = None) -> None:
|
||||
context: contextvars.Context) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
:return:
|
||||
"""
|
||||
for var, val in context.items():
|
||||
@ -256,13 +224,22 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# workflow app
|
||||
runner = WorkflowAppRunner(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id
|
||||
)
|
||||
|
||||
runner.run()
|
||||
runner = WorkflowAppRunner()
|
||||
if application_generate_entity.single_iteration_run:
|
||||
single_iteration_run = application_generate_entity.single_iteration_run
|
||||
runner.single_iteration_run(
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
workflow_id=application_generate_entity.app_config.workflow_id,
|
||||
queue_manager=queue_manager,
|
||||
inputs=single_iteration_run.inputs,
|
||||
node_id=single_iteration_run.node_id,
|
||||
user_id=application_generate_entity.user_id
|
||||
)
|
||||
else:
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
@ -274,14 +251,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true':
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
finally:
|
||||
db.session.close()
|
||||
db.session.remove()
|
||||
|
||||
def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
|
||||
@ -4,61 +4,46 @@ from typing import Optional, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import WorkflowType
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
class WorkflowAppRunner:
|
||||
"""
|
||||
Workflow Application Runner
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.queue_manager = queue_manager
|
||||
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||
|
||||
def run(self) -> None:
|
||||
def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None:
|
||||
"""
|
||||
Run application
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:return:
|
||||
"""
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
user_id = None
|
||||
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = self.application_generate_entity.user_id
|
||||
user_id = application_generate_entity.user_id
|
||||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
@ -68,64 +53,80 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
if not workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
files = application_generate_entity.files
|
||||
|
||||
db.session.close()
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
workflow_callbacks: list[WorkflowCallback] = [
|
||||
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
|
||||
]
|
||||
|
||||
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs
|
||||
)
|
||||
else:
|
||||
|
||||
inputs = self.application_generate_entity.inputs
|
||||
files = self.application_generate_entity.files
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
}
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=workflow.graph_dict)
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
}
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType.value_of(workflow.type),
|
||||
graph=graph,
|
||||
graph_config=workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=(
|
||||
UserFrom.ACCOUNT
|
||||
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
else UserFrom.END_USER
|
||||
),
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.run_workflow(
|
||||
workflow=workflow,
|
||||
user_id=application_generate_entity.user_id,
|
||||
user_from=UserFrom.ACCOUNT
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
else UserFrom.END_USER,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
callbacks=workflow_callbacks,
|
||||
call_depth=application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
thread_pool_id=self.workflow_thread_pool_id
|
||||
)
|
||||
|
||||
generator = workflow_entry.run(
|
||||
callbacks=workflow_callbacks
|
||||
def single_iteration_run(
|
||||
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Single iteration run
|
||||
"""
|
||||
app_record = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError('App not found')
|
||||
|
||||
if not app_record.workflow_id:
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
|
||||
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.single_step_run_iteration_workflow_node(
|
||||
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
|
||||
)
|
||||
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
@ -16,12 +15,10 @@ from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
@ -35,16 +32,19 @@ from core.app.entities.task_entities import (
|
||||
MessageAudioStreamResponse,
|
||||
StreamResponse,
|
||||
TextChunkStreamResponse,
|
||||
TextReplaceStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowStreamGenerateNodes,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
@ -52,8 +52,8 @@ from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -68,6 +68,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
_task_state: WorkflowTaskState
|
||||
_application_generate_entity: WorkflowAppGenerateEntity
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_iteration_nested_relations: dict[str, list[str]]
|
||||
|
||||
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
@ -95,7 +96,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
SystemVariableKey.USER_ID: user_id
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._task_state = WorkflowTaskState(
|
||||
iteration_nested_node_ids=[]
|
||||
)
|
||||
self._stream_generate_nodes = self._get_stream_generate_nodes()
|
||||
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
@ -124,20 +129,23 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, WorkflowFinishStreamResponse):
|
||||
workflow_run = db.session.query(WorkflowRun).filter(
|
||||
WorkflowRun.id == self._task_state.workflow_run_id).first()
|
||||
|
||||
response = WorkflowAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=stream_response.data.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
id=stream_response.data.id,
|
||||
workflow_id=stream_response.data.workflow_id,
|
||||
status=stream_response.data.status,
|
||||
outputs=stream_response.data.outputs,
|
||||
error=stream_response.data.error,
|
||||
elapsed_time=stream_response.data.elapsed_time,
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
created_at=int(stream_response.data.created_at),
|
||||
finished_at=int(stream_response.data.finished_at)
|
||||
id=workflow_run.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
status=workflow_run.status,
|
||||
outputs=workflow_run.outputs_dict,
|
||||
error=workflow_run.error,
|
||||
elapsed_time=workflow_run.elapsed_time,
|
||||
total_tokens=workflow_run.total_tokens,
|
||||
total_steps=workflow_run.total_steps,
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(workflow_run.finished_at.timestamp())
|
||||
)
|
||||
)
|
||||
|
||||
@ -153,13 +161,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
To stream response.
|
||||
:return:
|
||||
"""
|
||||
workflow_run_id = None
|
||||
for stream_response in generator:
|
||||
if isinstance(stream_response, WorkflowStartStreamResponse):
|
||||
workflow_run_id = stream_response.workflow_run_id
|
||||
|
||||
yield WorkflowAppStreamResponse(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
stream_response=stream_response
|
||||
)
|
||||
|
||||
@ -174,18 +178,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
|
||||
Generator[StreamResponse, None, None]:
|
||||
|
||||
tts_publisher = None
|
||||
publisher = None
|
||||
task_id = self._application_generate_entity.task_id
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
features_dict = self._workflow.features_dict
|
||||
|
||||
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
|
||||
'text_to_speech'].get('autoPlay') == 'enabled':
|
||||
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
|
||||
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
@ -195,9 +198,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
start_listener_time = time.time()
|
||||
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
||||
try:
|
||||
if not tts_publisher:
|
||||
if not publisher:
|
||||
break
|
||||
audio_trunk = tts_publisher.checkAndGetAudio()
|
||||
audio_trunk = publisher.checkAndGetAudio()
|
||||
if audio_trunk is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
@ -215,159 +218,69 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
|
||||
def _process_stream_response(
|
||||
self,
|
||||
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
|
||||
publisher: AppGeneratorTTSPublisher,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
Process stream response.
|
||||
:return:
|
||||
"""
|
||||
graph_runtime_state = None
|
||||
workflow_run = None
|
||||
for message in self._queue_manager.listen():
|
||||
if publisher:
|
||||
publisher.publish(message=message)
|
||||
event = message.event
|
||||
|
||||
for queue_message in self._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
if isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
elif isinstance(event, QueueErrorEvent):
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
err = self._handle_error(event)
|
||||
yield self._error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||
# override graph runtime state
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
|
||||
# init workflow run
|
||||
workflow_run = self._handle_workflow_run_start()
|
||||
workflow_run = self._handle_workflow_start()
|
||||
yield self._workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
)
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
workflow_node_execution = self._handle_node_start(event)
|
||||
|
||||
workflow_node_execution = self._handle_node_execution_start(
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
# search stream_generate_routes if node id is answer start at node
|
||||
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes:
|
||||
self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id]
|
||||
|
||||
response = self._workflow_node_start_to_stream_response(
|
||||
# generate stream outputs when node started
|
||||
yield from self._generate_stream_outputs_when_node_started()
|
||||
|
||||
yield self._workflow_node_start_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
)
|
||||
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
|
||||
workflow_node_execution = self._handle_node_finished(event)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeSucceededEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_success(event)
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
yield self._workflow_node_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeFailedEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
|
||||
if isinstance(event, QueueNodeFailedEvent):
|
||||
yield from self._handle_iteration_exception(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
error=f'Child node failed: {event.error}'
|
||||
)
|
||||
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
|
||||
if isinstance(event, QueueIterationNextEvent):
|
||||
# clear ran node execution infos of current iteration
|
||||
iteration_relations = self._iteration_nested_relations.get(event.node_id)
|
||||
if iteration_relations:
|
||||
for node_id in iteration_relations:
|
||||
self._task_state.ran_node_execution_infos.pop(node_id, None)
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_parallel_branch_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_parallel_branch_finished_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationStartEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_iteration_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_iteration_next_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowSucceededEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None,
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(workflow_run)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED,
|
||||
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
|
||||
self._handle_iteration_operation(event)
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
|
||||
workflow_run = self._handle_workflow_finished(
|
||||
event, trace_manager=trace_manager
|
||||
)
|
||||
|
||||
# save workflow app log
|
||||
@ -382,17 +295,22 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if delta_text is None:
|
||||
continue
|
||||
|
||||
# only publish tts message at text chunk streaming
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(message=queue_message)
|
||||
if not self._is_stream_out_support(
|
||||
event=event
|
||||
):
|
||||
continue
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._text_chunk_to_stream_response(delta_text)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
yield self._text_replace_to_stream_response(event.text)
|
||||
elif isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
else:
|
||||
continue
|
||||
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(None)
|
||||
if publisher:
|
||||
publisher.publish(None)
|
||||
|
||||
|
||||
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
|
||||
@ -411,15 +329,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
# not save log for debugging
|
||||
return
|
||||
|
||||
workflow_app_log = WorkflowAppLog()
|
||||
workflow_app_log.tenant_id = workflow_run.tenant_id
|
||||
workflow_app_log.app_id = workflow_run.app_id
|
||||
workflow_app_log.workflow_id = workflow_run.workflow_id
|
||||
workflow_app_log.workflow_run_id = workflow_run.id
|
||||
workflow_app_log.created_from = created_from.value
|
||||
workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user'
|
||||
workflow_app_log.created_by = self._user.id
|
||||
|
||||
workflow_app_log = WorkflowAppLog(
|
||||
tenant_id=workflow_run.tenant_id,
|
||||
app_id=workflow_run.app_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from=created_from.value,
|
||||
created_by_role=('account' if isinstance(self._user, Account) else 'end_user'),
|
||||
created_by=self._user.id,
|
||||
)
|
||||
db.session.add(workflow_app_log)
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
@ -436,3 +354,180 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse:
|
||||
"""
|
||||
Text replace to stream response.
|
||||
:param text: text
|
||||
:return:
|
||||
"""
|
||||
return TextReplaceStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
text=TextReplaceStreamResponse.Data(text=text)
|
||||
)
|
||||
|
||||
def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]:
|
||||
"""
|
||||
Get stream generate nodes.
|
||||
:return:
|
||||
"""
|
||||
# find all answer nodes
|
||||
graph = self._workflow.graph_dict
|
||||
end_node_configs = [
|
||||
node for node in graph['nodes']
|
||||
if node.get('data', {}).get('type') == NodeType.END.value
|
||||
]
|
||||
|
||||
# parse stream output node value selectors of end nodes
|
||||
stream_generate_routes = {}
|
||||
for node_config in end_node_configs:
|
||||
# get generate route for stream output
|
||||
end_node_id = node_config['id']
|
||||
generate_nodes = EndNode.extract_generate_nodes(graph, node_config)
|
||||
start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id)
|
||||
if not start_node_ids:
|
||||
continue
|
||||
|
||||
for start_node_id in start_node_ids:
|
||||
stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes(
|
||||
end_node_id=end_node_id,
|
||||
stream_node_ids=generate_nodes
|
||||
)
|
||||
|
||||
return stream_generate_routes
|
||||
|
||||
def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \
|
||||
-> list[str]:
|
||||
"""
|
||||
Get end start at node id.
|
||||
:param graph: graph
|
||||
:param target_node_id: target node ID
|
||||
:return:
|
||||
"""
|
||||
nodes = graph.get('nodes')
|
||||
edges = graph.get('edges')
|
||||
|
||||
# fetch all ingoing edges from source node
|
||||
ingoing_edges = []
|
||||
for edge in edges:
|
||||
if edge.get('target') == target_node_id:
|
||||
ingoing_edges.append(edge)
|
||||
|
||||
if not ingoing_edges:
|
||||
return []
|
||||
|
||||
start_node_ids = []
|
||||
for ingoing_edge in ingoing_edges:
|
||||
source_node_id = ingoing_edge.get('source')
|
||||
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
|
||||
if not source_node:
|
||||
continue
|
||||
|
||||
node_type = source_node.get('data', {}).get('type')
|
||||
node_iteration_id = source_node.get('data', {}).get('iteration_id')
|
||||
iteration_start_node_id = None
|
||||
if node_iteration_id:
|
||||
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
|
||||
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
|
||||
|
||||
if node_type in [
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER.value
|
||||
]:
|
||||
start_node_id = target_node_id
|
||||
start_node_ids.append(start_node_id)
|
||||
elif node_type == NodeType.START.value or \
|
||||
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
|
||||
start_node_id = source_node_id
|
||||
start_node_ids.append(start_node_id)
|
||||
else:
|
||||
sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id)
|
||||
if sub_start_node_ids:
|
||||
start_node_ids.extend(sub_start_node_ids)
|
||||
|
||||
return start_node_ids
|
||||
|
||||
def _generate_stream_outputs_when_node_started(self) -> Generator:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
:return:
|
||||
"""
|
||||
if self._task_state.current_stream_generate_state:
|
||||
stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids
|
||||
|
||||
for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items():
|
||||
if node_id not in stream_node_ids:
|
||||
continue
|
||||
|
||||
node_execution_info = self._task_state.ran_node_execution_infos[node_id]
|
||||
|
||||
# get chunk node execution
|
||||
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first()
|
||||
|
||||
if not route_chunk_node_execution:
|
||||
continue
|
||||
|
||||
outputs = route_chunk_node_execution.outputs_dict
|
||||
|
||||
if not outputs:
|
||||
continue
|
||||
|
||||
# get value from outputs
|
||||
text = outputs.get('text')
|
||||
|
||||
if text:
|
||||
self._task_state.answer += text
|
||||
yield self._text_chunk_to_stream_response(text)
|
||||
|
||||
db.session.close()
|
||||
|
||||
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
|
||||
"""
|
||||
Is stream out support
|
||||
:param event: queue text chunk event
|
||||
:return:
|
||||
"""
|
||||
if not event.metadata:
|
||||
return False
|
||||
|
||||
if 'node_id' not in event.metadata:
|
||||
return False
|
||||
|
||||
node_id = event.metadata.get('node_id')
|
||||
node_type = event.metadata.get('node_type')
|
||||
stream_output_value_selector = event.metadata.get('value_selector')
|
||||
if not stream_output_value_selector:
|
||||
return False
|
||||
|
||||
if not self._task_state.current_stream_generate_state:
|
||||
return False
|
||||
|
||||
if node_id not in self._task_state.current_stream_generate_state.stream_node_ids:
|
||||
return False
|
||||
|
||||
if node_type != NodeType.LLM:
|
||||
# only LLM support chunk stream output
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get iteration nested relations.
|
||||
:param graph: graph
|
||||
:return:
|
||||
"""
|
||||
nodes = graph.get('nodes')
|
||||
|
||||
iteration_ids = [node.get('id') for node in nodes
|
||||
if node.get('data', {}).get('type') in [
|
||||
NodeType.ITERATION.value,
|
||||
NodeType.LOOP.value,
|
||||
]]
|
||||
|
||||
return {
|
||||
iteration_id: [
|
||||
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
|
||||
] for iteration_id in iteration_ids
|
||||
}
|
||||
|
||||
200
api/core/app/apps/workflow/workflow_event_trigger_callback.py
Normal file
200
api/core/app/apps/workflow/workflow_event_trigger_callback.py
Normal file
@ -0,0 +1,200 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowEventTriggerCallback(WorkflowCallback):
|
||||
|
||||
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||
self._queue_manager = queue_manager
|
||||
|
||||
def on_workflow_run_started(self) -> None:
|
||||
"""
|
||||
Workflow run started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowStartedEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_succeeded(self) -> None:
|
||||
"""
|
||||
Workflow run succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowSucceededEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_failed(self, error: str) -> None:
|
||||
"""
|
||||
Workflow run failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowFailedEvent(
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_started(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeStartedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
node_run_index=node_run_index,
|
||||
predecessor_node_id=predecessor_node_id
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_succeeded(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeSucceededEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
execution_metadata=execution_metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_failed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
error: str,
|
||||
inputs: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeFailedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
process_data=process_data,
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=text,
|
||||
metadata={
|
||||
"node_id": node_id,
|
||||
**metadata
|
||||
}
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_started(self,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: dict = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationStartEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
metadata=metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_next(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
index: int,
|
||||
node_run_index: int,
|
||||
output: Optional[Any]) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationNextEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
index=index,
|
||||
node_run_index=node_run_index,
|
||||
output=output
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_completed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int,
|
||||
outputs: dict) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationCompletedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
outputs=outputs
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
"""
|
||||
pass
|
||||
@ -1,374 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
IterationRunFailedEvent,
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowBasedAppRunner(AppRunner):
|
||||
def __init__(self, queue_manager: AppQueueManager):
|
||||
self.queue_manager = queue_manager
|
||||
|
||||
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
|
||||
"""
|
||||
Init graph
|
||||
"""
|
||||
if 'nodes' not in graph_config or 'edges' not in graph_config:
|
||||
raise ValueError('nodes or edges not found in workflow graph')
|
||||
|
||||
if not isinstance(graph_config.get('nodes'), list):
|
||||
raise ValueError('nodes in workflow graph must be a list')
|
||||
|
||||
if not isinstance(graph_config.get('edges'), list):
|
||||
raise ValueError('edges in workflow graph must be a list')
|
||||
# init graph
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config
|
||||
)
|
||||
|
||||
if not graph:
|
||||
raise ValueError('graph not found in workflow')
|
||||
|
||||
return graph
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
if not graph_config:
|
||||
raise ValueError('workflow graph not found')
|
||||
|
||||
graph_config = cast(dict[str, Any], graph_config)
|
||||
|
||||
if 'nodes' not in graph_config or 'edges' not in graph_config:
|
||||
raise ValueError('nodes or edges not found in workflow graph')
|
||||
|
||||
if not isinstance(graph_config.get('nodes'), list):
|
||||
raise ValueError('nodes in workflow graph must be a list')
|
||||
|
||||
if not isinstance(graph_config.get('edges'), list):
|
||||
raise ValueError('edges in workflow graph must be a list')
|
||||
|
||||
# filter nodes only in iteration
|
||||
node_configs = [
|
||||
node for node in graph_config.get('nodes', [])
|
||||
if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
|
||||
]
|
||||
|
||||
graph_config['nodes'] = node_configs
|
||||
|
||||
node_ids = [node.get('id') for node in node_configs]
|
||||
|
||||
# filter edges only in iteration
|
||||
edge_configs = [
|
||||
edge for edge in graph_config.get('edges', [])
|
||||
if (edge.get('source') is None or edge.get('source') in node_ids)
|
||||
and (edge.get('target') is None or edge.get('target') in node_ids)
|
||||
]
|
||||
|
||||
graph_config['edges'] = edge_configs
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
root_node_id=node_id
|
||||
)
|
||||
|
||||
if not graph:
|
||||
raise ValueError('graph not found in workflow')
|
||||
|
||||
# fetch node config from node id
|
||||
iteration_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get('id') == node_id:
|
||||
iteration_node_config = node
|
||||
break
|
||||
|
||||
if not iteration_node_config:
|
||||
raise ValueError('iteration node id not found in workflow graph')
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
|
||||
node_cls = node_classes.get(node_type)
|
||||
node_cls = cast(type[BaseNode], node_cls)
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict,
|
||||
config=iteration_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
node_type=node_type,
|
||||
node_data=IterationNodeData(**iteration_node_config.get('data', {}))
|
||||
)
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Handle event
|
||||
:param workflow_entry: workflow entry
|
||||
:param event: event
|
||||
"""
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowStartedEvent(
|
||||
graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state
|
||||
)
|
||||
)
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowSucceededEvent(outputs=event.outputs)
|
||||
)
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowFailedEvent(error=event.error)
|
||||
)
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeStartedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
node_run_index=event.route_node_state.index,
|
||||
predecessor_node_id=event.predecessor_node_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueNodeSucceededEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result else {},
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result
|
||||
and event.route_node_state.node_run_result.error
|
||||
else "Unknown error"
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
text=event.chunk_content,
|
||||
from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetrieverResourceEvent):
|
||||
self._publish_event(
|
||||
QueueRetrieverResourcesEvent(
|
||||
retriever_resources=event.retriever_resources
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunStartedEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunSucceededEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunFailedEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
error=event.error
|
||||
)
|
||||
)
|
||||
elif isinstance(event, IterationRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueIterationStartEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
metadata=event.metadata
|
||||
)
|
||||
)
|
||||
elif isinstance(event, IterationRunNextEvent):
|
||||
self._publish_event(
|
||||
QueueIterationNextEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_iteration_output,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
|
||||
self._publish_event(
|
||||
QueueIterationCompletedEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, IterationRunFailedEvent) else None
|
||||
)
|
||||
)
|
||||
|
||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
def _publish_event(self, event: AppQueueEvent) -> None:
|
||||
self.queue_manager.publish(
|
||||
event,
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
@ -1,24 +1,10 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.app.entities.queue_entities import AppQueueEvent
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
IterationRunFailedEvent,
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
"blue": "36;1",
|
||||
@ -34,203 +20,127 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
def __init__(self) -> None:
|
||||
self.current_node_id = None
|
||||
|
||||
def on_event(
|
||||
self,
|
||||
event: GraphEngineEvent
|
||||
) -> None:
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self.print_text("\n[GraphRunStartedEvent]", color='pink')
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self.print_text("\n[GraphRunSucceededEvent]", color='green')
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red')
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self.on_workflow_node_execute_started(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
self.on_workflow_node_execute_succeeded(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self.on_workflow_node_execute_failed(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self.on_node_text_chunk(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
self.on_workflow_parallel_started(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
|
||||
self.on_workflow_parallel_completed(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, IterationRunStartedEvent):
|
||||
self.on_workflow_iteration_started(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, IterationRunNextEvent):
|
||||
self.on_workflow_iteration_next(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
|
||||
self.on_workflow_iteration_completed(
|
||||
event=event
|
||||
)
|
||||
else:
|
||||
self.print_text(f"\n[{event.__class__.__name__}]", color='blue')
|
||||
def on_workflow_run_started(self) -> None:
|
||||
"""
|
||||
Workflow run started
|
||||
"""
|
||||
self.print_text("\n[on_workflow_run_started]", color='pink')
|
||||
|
||||
def on_workflow_node_execute_started(
|
||||
self,
|
||||
event: NodeRunStartedEvent
|
||||
) -> None:
|
||||
def on_workflow_run_succeeded(self) -> None:
|
||||
"""
|
||||
Workflow run succeeded
|
||||
"""
|
||||
self.print_text("\n[on_workflow_run_succeeded]", color='green')
|
||||
|
||||
def on_workflow_run_failed(self, error: str) -> None:
|
||||
"""
|
||||
Workflow run failed
|
||||
"""
|
||||
self.print_text("\n[on_workflow_run_failed]", color='red')
|
||||
|
||||
def on_workflow_node_execute_started(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
self.print_text("\n[NodeRunStartedEvent]", color='yellow')
|
||||
self.print_text(f"Node ID: {event.node_id}", color='yellow')
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color='yellow')
|
||||
self.print_text(f"Type: {event.node_type.value}", color='yellow')
|
||||
self.print_text("\n[on_workflow_node_execute_started]", color='yellow')
|
||||
self.print_text(f"Node ID: {node_id}", color='yellow')
|
||||
self.print_text(f"Type: {node_type.value}", color='yellow')
|
||||
self.print_text(f"Index: {node_run_index}", color='yellow')
|
||||
if predecessor_node_id:
|
||||
self.print_text(f"Predecessor Node ID: {predecessor_node_id}", color='yellow')
|
||||
|
||||
def on_workflow_node_execute_succeeded(
|
||||
self,
|
||||
event: NodeRunSucceededEvent
|
||||
) -> None:
|
||||
def on_workflow_node_execute_succeeded(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
route_node_state = event.route_node_state
|
||||
self.print_text("\n[on_workflow_node_execute_succeeded]", color='green')
|
||||
self.print_text(f"Node ID: {node_id}", color='green')
|
||||
self.print_text(f"Type: {node_type.value}", color='green')
|
||||
self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='green')
|
||||
self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='green')
|
||||
self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='green')
|
||||
self.print_text(f"Metadata: {jsonable_encoder(execution_metadata) if execution_metadata else ''}",
|
||||
color='green')
|
||||
|
||||
self.print_text("\n[NodeRunSucceededEvent]", color='green')
|
||||
self.print_text(f"Node ID: {event.node_id}", color='green')
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color='green')
|
||||
self.print_text(f"Type: {event.node_type.value}", color='green')
|
||||
|
||||
if route_node_state.node_run_result:
|
||||
node_run_result = route_node_state.node_run_result
|
||||
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
|
||||
color='green')
|
||||
self.print_text(
|
||||
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
|
||||
color='green')
|
||||
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
|
||||
color='green')
|
||||
self.print_text(
|
||||
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
|
||||
color='green')
|
||||
|
||||
def on_workflow_node_execute_failed(
|
||||
self,
|
||||
event: NodeRunFailedEvent
|
||||
) -> None:
|
||||
def on_workflow_node_execute_failed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
error: str,
|
||||
inputs: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
route_node_state = event.route_node_state
|
||||
self.print_text("\n[on_workflow_node_execute_failed]", color='red')
|
||||
self.print_text(f"Node ID: {node_id}", color='red')
|
||||
self.print_text(f"Type: {node_type.value}", color='red')
|
||||
self.print_text(f"Error: {error}", color='red')
|
||||
self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='red')
|
||||
self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='red')
|
||||
self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='red')
|
||||
|
||||
self.print_text("\n[NodeRunFailedEvent]", color='red')
|
||||
self.print_text(f"Node ID: {event.node_id}", color='red')
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color='red')
|
||||
self.print_text(f"Type: {event.node_type.value}", color='red')
|
||||
|
||||
if route_node_state.node_run_result:
|
||||
node_run_result = route_node_state.node_run_result
|
||||
self.print_text(f"Error: {node_run_result.error}", color='red')
|
||||
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
|
||||
color='red')
|
||||
self.print_text(
|
||||
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
|
||||
color='red')
|
||||
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
|
||||
color='red')
|
||||
|
||||
def on_node_text_chunk(
|
||||
self,
|
||||
event: NodeRunStreamChunkEvent
|
||||
) -> None:
|
||||
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
route_node_state = event.route_node_state
|
||||
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
|
||||
self.current_node_id = route_node_state.node_id
|
||||
self.print_text('\n[NodeRunStreamChunkEvent]')
|
||||
self.print_text(f"Node ID: {route_node_state.node_id}")
|
||||
if not self.current_node_id or self.current_node_id != node_id:
|
||||
self.current_node_id = node_id
|
||||
self.print_text('\n[on_node_text_chunk]')
|
||||
self.print_text(f"Node ID: {node_id}")
|
||||
self.print_text(f"Metadata: {jsonable_encoder(metadata) if metadata else ''}")
|
||||
|
||||
node_run_result = route_node_state.node_run_result
|
||||
if node_run_result:
|
||||
self.print_text(
|
||||
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}")
|
||||
self.print_text(text, color="pink", end="")
|
||||
|
||||
self.print_text(event.chunk_content, color="pink", end="")
|
||||
|
||||
def on_workflow_parallel_started(
|
||||
self,
|
||||
event: ParallelBranchRunStartedEvent
|
||||
) -> None:
|
||||
"""
|
||||
Publish parallel started
|
||||
"""
|
||||
self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue')
|
||||
self.print_text(f"Parallel ID: {event.parallel_id}", color='blue')
|
||||
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue')
|
||||
if event.in_iteration_id:
|
||||
self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue')
|
||||
|
||||
def on_workflow_parallel_completed(
|
||||
self,
|
||||
event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
|
||||
) -> None:
|
||||
"""
|
||||
Publish parallel completed
|
||||
"""
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
color = 'blue'
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
color = 'red'
|
||||
|
||||
self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color)
|
||||
self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
|
||||
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
|
||||
if event.in_iteration_id:
|
||||
self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
|
||||
|
||||
if isinstance(event, ParallelBranchRunFailedEvent):
|
||||
self.print_text(f"Error: {event.error}", color=color)
|
||||
|
||||
def on_workflow_iteration_started(
|
||||
self,
|
||||
event: IterationRunStartedEvent
|
||||
) -> None:
|
||||
def on_workflow_iteration_started(self,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: dict = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self.print_text("\n[IterationRunStartedEvent]", color='blue')
|
||||
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
|
||||
self.print_text("\n[on_workflow_iteration_started]", color='blue')
|
||||
self.print_text(f"Node ID: {node_id}", color='blue')
|
||||
|
||||
def on_workflow_iteration_next(
|
||||
self,
|
||||
event: IterationRunNextEvent
|
||||
) -> None:
|
||||
def on_workflow_iteration_next(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
index: int,
|
||||
node_run_index: int,
|
||||
output: Optional[dict]) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
self.print_text("\n[IterationRunNextEvent]", color='blue')
|
||||
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
|
||||
self.print_text(f"Iteration Index: {event.index}", color='blue')
|
||||
self.print_text("\n[on_workflow_iteration_next]", color='blue')
|
||||
|
||||
def on_workflow_iteration_completed(
|
||||
self,
|
||||
event: IterationRunSucceededEvent | IterationRunFailedEvent
|
||||
) -> None:
|
||||
def on_workflow_iteration_completed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int,
|
||||
outputs: dict) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue')
|
||||
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
|
||||
self.print_text("\n[on_workflow_iteration_completed]", color='blue')
|
||||
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
"""
|
||||
self.print_text("\n[on_workflow_event]", color='blue')
|
||||
self.print_text(f"Event: {jsonable_encoder(event)}", color='blue')
|
||||
|
||||
def print_text(
|
||||
self, text: str, color: Optional[str] = None, end: str = "\n"
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -6,8 +5,7 @@ from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
|
||||
|
||||
class QueueEvent(str, Enum):
|
||||
@ -33,9 +31,6 @@ class QueueEvent(str, Enum):
|
||||
ANNOTATION_REPLY = "annotation_reply"
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
MESSAGE_FILE = "message_file"
|
||||
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
|
||||
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
|
||||
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
|
||||
ERROR = "error"
|
||||
PING = "ping"
|
||||
STOP = "stop"
|
||||
@ -43,7 +38,7 @@ class QueueEvent(str, Enum):
|
||||
|
||||
class AppQueueEvent(BaseModel):
|
||||
"""
|
||||
QueueEvent abstract entity
|
||||
QueueEvent entity
|
||||
"""
|
||||
event: QueueEvent
|
||||
|
||||
@ -51,7 +46,6 @@ class AppQueueEvent(BaseModel):
|
||||
class QueueLLMChunkEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueLLMChunkEvent entity
|
||||
Only for basic mode apps
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.LLM_CHUNK
|
||||
chunk: LLMResultChunk
|
||||
@ -61,24 +55,14 @@ class QueueIterationStartEvent(AppQueueEvent):
|
||||
QueueIterationStartEvent entity
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.ITERATION_START
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
inputs: dict = None
|
||||
predecessor_node_id: Optional[str] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
class QueueIterationNextEvent(AppQueueEvent):
|
||||
"""
|
||||
@ -87,18 +71,8 @@ class QueueIterationNextEvent(AppQueueEvent):
|
||||
event: QueueEvent = QueueEvent.ITERATION_NEXT
|
||||
|
||||
index: int
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
|
||||
node_run_index: int
|
||||
output: Optional[Any] = None # output for the current iteration
|
||||
@ -119,30 +93,13 @@ class QueueIterationCompletedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueIterationCompletedEvent entity
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.ITERATION_COMPLETED
|
||||
event:QueueEvent = QueueEvent.ITERATION_COMPLETED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
steps: int = 0
|
||||
|
||||
error: Optional[str] = None
|
||||
|
||||
outputs: dict
|
||||
|
||||
class QueueTextChunkEvent(AppQueueEvent):
|
||||
"""
|
||||
@ -150,8 +107,7 @@ class QueueTextChunkEvent(AppQueueEvent):
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.TEXT_CHUNK
|
||||
text: str
|
||||
from_variable_selector: Optional[list[str]] = None
|
||||
"""from variable selector"""
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
class QueueAgentMessageEvent(AppQueueEvent):
|
||||
@ -206,7 +162,6 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
|
||||
QueueWorkflowStartedEvent entity
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
|
||||
graph_runtime_state: GraphRuntimeState
|
||||
|
||||
|
||||
class QueueWorkflowSucceededEvent(AppQueueEvent):
|
||||
@ -214,7 +169,6 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
|
||||
QueueWorkflowSucceededEvent entity
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class QueueWorkflowFailedEvent(AppQueueEvent):
|
||||
@ -231,21 +185,11 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.NODE_STARTED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_run_index: int = 1
|
||||
predecessor_node_id: Optional[str] = None
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
start_at: datetime
|
||||
|
||||
|
||||
class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
@ -254,24 +198,14 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.NODE_SUCCEEDED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
process_data: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
|
||||
inputs: Optional[dict] = None
|
||||
process_data: Optional[dict] = None
|
||||
outputs: Optional[dict] = None
|
||||
execution_metadata: Optional[dict] = None
|
||||
|
||||
error: Optional[str] = None
|
||||
|
||||
@ -282,23 +216,13 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.NODE_FAILED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
process_data: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
inputs: Optional[dict] = None
|
||||
outputs: Optional[dict] = None
|
||||
process_data: Optional[dict] = None
|
||||
|
||||
error: str
|
||||
|
||||
@ -350,23 +274,10 @@ class QueueStopEvent(AppQueueEvent):
|
||||
event: QueueEvent = QueueEvent.STOP
|
||||
stopped_by: StopBy
|
||||
|
||||
def get_stop_reason(self) -> str:
|
||||
"""
|
||||
To stop reason
|
||||
"""
|
||||
reason_mapping = {
|
||||
QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.',
|
||||
QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.',
|
||||
QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
|
||||
QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
|
||||
}
|
||||
|
||||
return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
|
||||
|
||||
|
||||
class QueueMessage(BaseModel):
|
||||
"""
|
||||
QueueMessage abstract entity
|
||||
QueueMessage entity
|
||||
"""
|
||||
task_id: str
|
||||
app_mode: str
|
||||
@ -386,52 +297,3 @@ class WorkflowQueueMessage(QueueMessage):
|
||||
WorkflowQueueMessage entity
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunStartedEvent entity
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
|
||||
|
||||
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunSucceededEvent entity
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
|
||||
|
||||
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunFailedEvent entity
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
error: str
|
||||
|
||||
@ -3,11 +3,40 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.answer.entities import GenerateRouteChunk
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class WorkflowStreamGenerateNodes(BaseModel):
|
||||
"""
|
||||
WorkflowStreamGenerateNodes entity
|
||||
"""
|
||||
end_node_id: str
|
||||
stream_node_ids: list[str]
|
||||
|
||||
|
||||
class ChatflowStreamGenerateRoute(BaseModel):
|
||||
"""
|
||||
ChatflowStreamGenerateRoute entity
|
||||
"""
|
||||
answer_node_id: str
|
||||
generate_route: list[GenerateRouteChunk]
|
||||
current_route_position: int = 0
|
||||
|
||||
|
||||
class NodeExecutionInfo(BaseModel):
|
||||
"""
|
||||
NodeExecutionInfo entity
|
||||
"""
|
||||
workflow_node_execution_id: str
|
||||
node_type: NodeType
|
||||
start_at: float
|
||||
|
||||
|
||||
class TaskState(BaseModel):
|
||||
"""
|
||||
TaskState entity
|
||||
@ -28,6 +57,27 @@ class WorkflowTaskState(TaskState):
|
||||
"""
|
||||
answer: str = ""
|
||||
|
||||
workflow_run_id: Optional[str] = None
|
||||
start_at: Optional[float] = None
|
||||
total_tokens: int = 0
|
||||
total_steps: int = 0
|
||||
|
||||
ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
|
||||
latest_node_execution_info: Optional[NodeExecutionInfo] = None
|
||||
|
||||
current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None
|
||||
|
||||
iteration_nested_node_ids: list[str] = None
|
||||
|
||||
|
||||
class AdvancedChatTaskState(WorkflowTaskState):
|
||||
"""
|
||||
AdvancedChatTaskState entity
|
||||
"""
|
||||
usage: LLMUsage
|
||||
|
||||
current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None
|
||||
|
||||
|
||||
class StreamEvent(Enum):
|
||||
"""
|
||||
@ -47,8 +97,6 @@ class StreamEvent(Enum):
|
||||
WORKFLOW_FINISHED = "workflow_finished"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
|
||||
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
|
||||
ITERATION_STARTED = "iteration_started"
|
||||
ITERATION_NEXT = "iteration_next"
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
@ -219,10 +267,6 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
inputs: Optional[dict] = None
|
||||
created_at: int
|
||||
extras: dict = {}
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_STARTED
|
||||
workflow_run_id: str
|
||||
@ -242,11 +286,7 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
"predecessor_node_id": self.data.predecessor_node_id,
|
||||
"inputs": None,
|
||||
"created_at": self.data.created_at,
|
||||
"extras": {},
|
||||
"parallel_id": self.data.parallel_id,
|
||||
"parallel_start_node_id": self.data.parallel_start_node_id,
|
||||
"parent_parallel_id": self.data.parent_parallel_id,
|
||||
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
|
||||
"extras": {}
|
||||
}
|
||||
}
|
||||
|
||||
@ -276,10 +316,6 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[list[dict]] = []
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_FINISHED
|
||||
workflow_run_id: str
|
||||
@ -306,57 +342,9 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
"execution_metadata": None,
|
||||
"created_at": self.data.created_at,
|
||||
"finished_at": self.data.finished_at,
|
||||
"files": [],
|
||||
"parallel_id": self.data.parallel_id,
|
||||
"parallel_start_node_id": self.data.parallel_start_node_id,
|
||||
"parent_parallel_id": self.data.parent_parallel_id,
|
||||
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
|
||||
"files": []
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ParallelBranchStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
ParallelBranchStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
parallel_id: str
|
||||
parallel_branch_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
created_at: int
|
||||
|
||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class ParallelBranchFinishedStreamResponse(StreamResponse):
|
||||
"""
|
||||
ParallelBranchFinishedStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
parallel_id: str
|
||||
parallel_branch_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
status: str
|
||||
error: Optional[str] = None
|
||||
created_at: int
|
||||
|
||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class IterationNodeStartStreamResponse(StreamResponse):
|
||||
@ -376,8 +364,6 @@ class IterationNodeStartStreamResponse(StreamResponse):
|
||||
extras: dict = {}
|
||||
metadata: dict = {}
|
||||
inputs: dict = {}
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_STARTED
|
||||
workflow_run_id: str
|
||||
@ -401,8 +387,6 @@ class IterationNodeNextStreamResponse(StreamResponse):
|
||||
created_at: int
|
||||
pre_iteration_output: Optional[Any] = None
|
||||
extras: dict = {}
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_NEXT
|
||||
workflow_run_id: str
|
||||
@ -424,8 +408,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
||||
title: str
|
||||
outputs: Optional[dict] = None
|
||||
created_at: int
|
||||
extras: Optional[dict] = None
|
||||
inputs: Optional[dict] = None
|
||||
extras: dict = None
|
||||
inputs: dict = None
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: Optional[str] = None
|
||||
elapsed_time: float
|
||||
@ -433,8 +417,6 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
||||
execution_metadata: Optional[dict] = None
|
||||
finished_at: int
|
||||
steps: int
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
|
||||
workflow_run_id: str
|
||||
@ -506,7 +488,7 @@ class WorkflowAppStreamResponse(AppStreamResponse):
|
||||
"""
|
||||
WorkflowAppStreamResponse entity
|
||||
"""
|
||||
workflow_run_id: Optional[str] = None
|
||||
workflow_run_id: str
|
||||
|
||||
|
||||
class AppBlockingResponse(BaseModel):
|
||||
@ -580,3 +562,25 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
||||
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class WorkflowIterationState(BaseModel):
|
||||
"""
|
||||
WorkflowIterationState entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
parent_iteration_id: Optional[str] = None
|
||||
iteration_id: str
|
||||
current_index: int
|
||||
iteration_steps_boundary: list[int] = None
|
||||
node_execution_id: str
|
||||
started_at: float
|
||||
inputs: dict = None
|
||||
total_tokens: int = 0
|
||||
node_data: BaseNodeData
|
||||
|
||||
current_iterations: dict[str, Data] = None
|
||||
|
||||
@ -68,18 +68,16 @@ class BasedGenerateTaskPipeline:
|
||||
err = Exception(e.description if getattr(e, 'description', None) is not None else str(e))
|
||||
|
||||
if message:
|
||||
refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
|
||||
message = db.session.query(Message).filter(Message.id == message.id).first()
|
||||
err_desc = self._error_to_desc(err)
|
||||
message.status = 'error'
|
||||
message.error = err_desc
|
||||
|
||||
if refetch_message:
|
||||
err_desc = self._error_to_desc(err)
|
||||
refetch_message.status = 'error'
|
||||
refetch_message.error = err_desc
|
||||
|
||||
db.session.commit()
|
||||
db.session.commit()
|
||||
|
||||
return err
|
||||
|
||||
def _error_to_desc(self, e: Exception) -> str:
|
||||
def _error_to_desc(cls, e: Exception) -> str:
|
||||
"""
|
||||
Error to desc.
|
||||
:param e: exception
|
||||
|
||||
@ -8,6 +8,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
AgentChatAppGenerateEntity,
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAnnotationReplyEvent,
|
||||
@ -15,11 +16,11 @@ from core.app.entities.queue_entities import (
|
||||
QueueRetrieverResourcesEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatTaskState,
|
||||
EasyUITaskState,
|
||||
MessageFileStreamResponse,
|
||||
MessageReplaceStreamResponse,
|
||||
MessageStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
@ -35,7 +36,7 @@ class MessageCycleManage:
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity
|
||||
]
|
||||
_task_state: Union[EasyUITaskState, WorkflowTaskState]
|
||||
_task_state: Union[EasyUITaskState, AdvancedChatTaskState]
|
||||
|
||||
def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]:
|
||||
"""
|
||||
@ -44,9 +45,6 @@ class MessageCycleManage:
|
||||
:param query: query
|
||||
:return: thread
|
||||
"""
|
||||
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
|
||||
return None
|
||||
|
||||
is_first_message = self._application_generate_entity.conversation_id is None
|
||||
extras = self._application_generate_entity.extras
|
||||
auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True)
|
||||
@ -54,7 +52,7 @@ class MessageCycleManage:
|
||||
if auto_generate_conversation_name and is_first_message:
|
||||
# start generate thread
|
||||
thread = Thread(target=self._generate_conversation_name_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(), # type: ignore
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'conversation_id': conversation.id,
|
||||
'query': query
|
||||
})
|
||||
@ -77,9 +75,6 @@ class MessageCycleManage:
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
return
|
||||
|
||||
if conversation.mode != AppMode.COMPLETION.value:
|
||||
app_model = conversation.app
|
||||
if not app_model:
|
||||
@ -126,13 +121,34 @@ class MessageCycleManage:
|
||||
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
|
||||
self._task_state.metadata['retriever_resources'] = event.retriever_resources
|
||||
|
||||
def _get_response_metadata(self) -> dict:
|
||||
"""
|
||||
Get response metadata by invoke from.
|
||||
:return:
|
||||
"""
|
||||
metadata = {}
|
||||
|
||||
# show_retrieve_source
|
||||
if 'retriever_resources' in self._task_state.metadata:
|
||||
metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
|
||||
|
||||
# show annotation reply
|
||||
if 'annotation_reply' in self._task_state.metadata:
|
||||
metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']
|
||||
|
||||
# show usage
|
||||
if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
|
||||
metadata['usage'] = self._task_state.metadata['usage']
|
||||
|
||||
return metadata
|
||||
|
||||
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
|
||||
"""
|
||||
Message file to stream response.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
message_file = (
|
||||
message_file: MessageFile = (
|
||||
db.session.query(MessageFile)
|
||||
.filter(MessageFile.id == event.message_file_id)
|
||||
.first()
|
||||
|
||||
@ -1,41 +1,33 @@
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional, Union, cast
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueueStopEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
IterationNodeCompletedStreamResponse,
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
NodeExecutionInfo,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
ParallelBranchFinishedStreamResponse,
|
||||
ParallelBranchStartStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
|
||||
from core.file.file_obj import FileVar
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
@ -49,56 +41,54 @@ from models.workflow import (
|
||||
WorkflowRunStatus,
|
||||
WorkflowRunTriggeredFrom,
|
||||
)
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class WorkflowCycleManage:
|
||||
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: WorkflowTaskState
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
|
||||
def _handle_workflow_run_start(self) -> WorkflowRun:
|
||||
max_sequence = (
|
||||
db.session.query(db.func.max(WorkflowRun.sequence_number))
|
||||
.filter(WorkflowRun.tenant_id == self._workflow.tenant_id)
|
||||
.filter(WorkflowRun.app_id == self._workflow.app_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
class WorkflowCycleManage(WorkflowIterationCycleManage):
|
||||
def _init_workflow_run(self, workflow: Workflow,
|
||||
triggered_from: WorkflowRunTriggeredFrom,
|
||||
user: Union[Account, EndUser],
|
||||
user_inputs: dict,
|
||||
system_inputs: Optional[dict] = None) -> WorkflowRun:
|
||||
"""
|
||||
Init workflow run
|
||||
:param workflow: Workflow instance
|
||||
:param triggered_from: triggered from
|
||||
:param user: account or end user
|
||||
:param user_inputs: user variables inputs
|
||||
:param system_inputs: system inputs, like: query, files
|
||||
:return:
|
||||
"""
|
||||
max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
|
||||
.filter(WorkflowRun.tenant_id == workflow.tenant_id) \
|
||||
.filter(WorkflowRun.app_id == workflow.app_id) \
|
||||
.scalar() or 0
|
||||
new_sequence_number = max_sequence + 1
|
||||
|
||||
inputs = {**self._application_generate_entity.inputs}
|
||||
for key, value in (self._workflow_system_variables or {}).items():
|
||||
inputs = {**user_inputs}
|
||||
for key, value in (system_inputs or {}).items():
|
||||
if key.value == 'conversation':
|
||||
continue
|
||||
|
||||
inputs[f'sys.{key.value}'] = value
|
||||
|
||||
inputs = WorkflowEntry.handle_special_values(inputs)
|
||||
|
||||
triggered_from= (
|
||||
WorkflowRunTriggeredFrom.DEBUGGING
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
|
||||
else WorkflowRunTriggeredFrom.APP_RUN
|
||||
)
|
||||
inputs = WorkflowEngineManager.handle_special_values(inputs)
|
||||
|
||||
# init workflow run
|
||||
workflow_run = WorkflowRun()
|
||||
workflow_run.tenant_id = self._workflow.tenant_id
|
||||
workflow_run.app_id = self._workflow.app_id
|
||||
workflow_run.sequence_number = new_sequence_number
|
||||
workflow_run.workflow_id = self._workflow.id
|
||||
workflow_run.type = self._workflow.type
|
||||
workflow_run.triggered_from = triggered_from.value
|
||||
workflow_run.version = self._workflow.version
|
||||
workflow_run.graph = self._workflow.graph
|
||||
workflow_run.inputs = json.dumps(inputs)
|
||||
workflow_run.status = WorkflowRunStatus.RUNNING.value
|
||||
workflow_run.created_by_role = (
|
||||
CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value
|
||||
workflow_run = WorkflowRun(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
sequence_number=new_sequence_number,
|
||||
workflow_id=workflow.id,
|
||||
type=workflow.type,
|
||||
triggered_from=triggered_from.value,
|
||||
version=workflow.version,
|
||||
graph=workflow.graph,
|
||||
inputs=json.dumps(inputs),
|
||||
status=WorkflowRunStatus.RUNNING.value,
|
||||
created_by_role=(CreatedByRole.ACCOUNT.value
|
||||
if isinstance(user, Account) else CreatedByRole.END_USER.value),
|
||||
created_by=user.id
|
||||
)
|
||||
workflow_run.created_by = self._user.id
|
||||
|
||||
db.session.add(workflow_run)
|
||||
db.session.commit()
|
||||
@ -107,37 +97,33 @@ class WorkflowCycleManage:
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_workflow_run_success(
|
||||
self,
|
||||
workflow_run: WorkflowRun,
|
||||
start_at: float,
|
||||
def _workflow_run_success(
|
||||
self, workflow_run: WorkflowRun,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run success
|
||||
:param workflow_run: workflow run
|
||||
:param start_at: start time
|
||||
:param total_tokens: total tokens
|
||||
:param total_steps: total steps
|
||||
:param outputs: outputs
|
||||
:param conversation_id: conversation id
|
||||
:return:
|
||||
"""
|
||||
workflow_run = self._refetch_workflow_run(workflow_run.id)
|
||||
|
||||
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
|
||||
workflow_run.outputs = outputs
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_run)
|
||||
db.session.close()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
@ -149,58 +135,34 @@ class WorkflowCycleManage:
|
||||
)
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_workflow_run_failed(
|
||||
self,
|
||||
workflow_run: WorkflowRun,
|
||||
start_at: float,
|
||||
def _workflow_run_failed(
|
||||
self, workflow_run: WorkflowRun,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
status: WorkflowRunStatus,
|
||||
error: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run failed
|
||||
:param workflow_run: workflow run
|
||||
:param start_at: start time
|
||||
:param total_tokens: total tokens
|
||||
:param total_steps: total steps
|
||||
:param status: status
|
||||
:param error: error message
|
||||
:return:
|
||||
"""
|
||||
workflow_run = self._refetch_workflow_run(workflow_run.id)
|
||||
|
||||
workflow_run.status = status.value
|
||||
workflow_run.error = error
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
||||
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value
|
||||
).all()
|
||||
|
||||
for workflow_node_execution in running_workflow_node_executions:
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds()
|
||||
db.session.commit()
|
||||
|
||||
db.session.refresh(workflow_run)
|
||||
db.session.close()
|
||||
|
||||
@ -216,24 +178,39 @@ class WorkflowCycleManage:
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
|
||||
def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_title: str,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Init workflow node execution from workflow run
|
||||
:param workflow_run: workflow run
|
||||
:param node_id: node id
|
||||
:param node_type: node type
|
||||
:param node_title: node title
|
||||
:param node_run_index: run index
|
||||
:param predecessor_node_id: predecessor node id if exists
|
||||
:return:
|
||||
"""
|
||||
# init workflow node execution
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
workflow_node_execution.app_id = workflow_run.app_id
|
||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
workflow_node_execution.workflow_run_id = workflow_run.id
|
||||
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
|
||||
workflow_node_execution.index = event.node_run_index
|
||||
workflow_node_execution.node_execution_id = event.node_execution_id
|
||||
workflow_node_execution.node_id = event.node_id
|
||||
workflow_node_execution.node_type = event.node_type.value
|
||||
workflow_node_execution.title = event.node_data.title
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
|
||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||
workflow_node_execution.created_by = workflow_run.created_by
|
||||
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow_node_execution = WorkflowNodeExecution(
|
||||
tenant_id=workflow_run.tenant_id,
|
||||
app_id=workflow_run.app_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
workflow_run_id=workflow_run.id,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
index=node_run_index,
|
||||
node_id=node_id,
|
||||
node_type=node_type.value,
|
||||
title=node_title,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
created_by_role=workflow_run.created_by_role,
|
||||
created_by=workflow_run.created_by,
|
||||
created_at=datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
)
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
@ -242,26 +219,33 @@ class WorkflowCycleManage:
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
|
||||
def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
|
||||
start_at: float,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution success
|
||||
:param event: queue node succeeded event
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:param start_at: start time
|
||||
:param inputs: inputs
|
||||
:param process_data: process data
|
||||
:param outputs: outputs
|
||||
:param execution_metadata: execution metadata
|
||||
:return:
|
||||
"""
|
||||
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
|
||||
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
inputs = WorkflowEngineManager.handle_special_values(inputs)
|
||||
outputs = WorkflowEngineManager.handle_special_values(outputs)
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
)
|
||||
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
|
||||
if execution_metadata else None
|
||||
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
@ -269,24 +253,33 @@ class WorkflowCycleManage:
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
|
||||
def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
|
||||
start_at: float,
|
||||
error: str,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param event: queue node failed event
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:param start_at: start time
|
||||
:param error: error message
|
||||
:return:
|
||||
"""
|
||||
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
|
||||
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
inputs = WorkflowEngineManager.handle_special_values(inputs)
|
||||
outputs = WorkflowEngineManager.handle_special_values(outputs)
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = event.error
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
|
||||
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
|
||||
if execution_metadata else None
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
@ -294,13 +287,8 @@ class WorkflowCycleManage:
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
#################################################
|
||||
# to stream responses #
|
||||
#################################################
|
||||
|
||||
def _workflow_start_to_stream_response(
|
||||
self, task_id: str, workflow_run: WorkflowRun
|
||||
) -> WorkflowStartStreamResponse:
|
||||
def _workflow_start_to_stream_response(self, task_id: str,
|
||||
workflow_run: WorkflowRun) -> WorkflowStartStreamResponse:
|
||||
"""
|
||||
Workflow start to stream response.
|
||||
:param task_id: task id
|
||||
@ -314,14 +302,13 @@ class WorkflowCycleManage:
|
||||
id=workflow_run.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
sequence_number=workflow_run.sequence_number,
|
||||
inputs=workflow_run.inputs_dict or {},
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
),
|
||||
inputs=workflow_run.inputs_dict,
|
||||
created_at=int(workflow_run.created_at.timestamp())
|
||||
)
|
||||
)
|
||||
|
||||
def _workflow_finish_to_stream_response(
|
||||
self, task_id: str, workflow_run: WorkflowRun
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
def _workflow_finish_to_stream_response(self, task_id: str,
|
||||
workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse:
|
||||
"""
|
||||
Workflow finish to stream response.
|
||||
:param task_id: task id
|
||||
@ -333,16 +320,16 @@ class WorkflowCycleManage:
|
||||
created_by_account = workflow_run.created_by_account
|
||||
if created_by_account:
|
||||
created_by = {
|
||||
'id': created_by_account.id,
|
||||
'name': created_by_account.name,
|
||||
'email': created_by_account.email,
|
||||
"id": created_by_account.id,
|
||||
"name": created_by_account.name,
|
||||
"email": created_by_account.email,
|
||||
}
|
||||
else:
|
||||
created_by_end_user = workflow_run.created_by_end_user
|
||||
if created_by_end_user:
|
||||
created_by = {
|
||||
'id': created_by_end_user.id,
|
||||
'user': created_by_end_user.session_id,
|
||||
"id": created_by_end_user.id,
|
||||
"user": created_by_end_user.session_id,
|
||||
}
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
@ -361,13 +348,14 @@ class WorkflowCycleManage:
|
||||
created_by=created_by,
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(workflow_run.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}),
|
||||
),
|
||||
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict)
|
||||
)
|
||||
)
|
||||
|
||||
def _workflow_node_start_to_stream_response(
|
||||
self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution
|
||||
) -> Optional[NodeStartStreamResponse]:
|
||||
def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution) \
|
||||
-> NodeStartStreamResponse:
|
||||
"""
|
||||
Workflow node start to stream response.
|
||||
:param event: queue node started event
|
||||
@ -375,9 +363,6 @@ class WorkflowCycleManage:
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:return:
|
||||
"""
|
||||
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
|
||||
return None
|
||||
|
||||
response = NodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
@ -389,12 +374,8 @@ class WorkflowCycleManage:
|
||||
index=workflow_node_execution.index,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs_dict,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
),
|
||||
created_at=int(workflow_node_execution.created_at.timestamp())
|
||||
)
|
||||
)
|
||||
|
||||
# extras logic
|
||||
@ -403,27 +384,19 @@ class WorkflowCycleManage:
|
||||
response.data.extras['icon'] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
provider_type=node_data.provider_type,
|
||||
provider_id=node_data.provider_id,
|
||||
provider_id=node_data.provider_id
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _workflow_node_finish_to_stream_response(
|
||||
self,
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \
|
||||
-> NodeFinishStreamResponse:
|
||||
"""
|
||||
Workflow node finish to stream response.
|
||||
:param event: queue node succeeded or failed event
|
||||
:param task_id: task id
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:return:
|
||||
"""
|
||||
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
|
||||
return None
|
||||
|
||||
return NodeFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
@ -443,154 +416,181 @@ class WorkflowCycleManage:
|
||||
execution_metadata=workflow_node_execution.execution_metadata_dict,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_parallel_branch_start_to_stream_response(
|
||||
self,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
event: QueueParallelBranchRunStartedEvent
|
||||
) -> ParallelBranchStartStreamResponse:
|
||||
"""
|
||||
Workflow parallel branch start to stream response
|
||||
:param task_id: task id
|
||||
:param workflow_run: workflow run
|
||||
:param event: parallel branch run started event
|
||||
:return:
|
||||
"""
|
||||
return ParallelBranchStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=ParallelBranchStartStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
created_at=int(time.time()),
|
||||
)
|
||||
)
|
||||
|
||||
def _workflow_parallel_branch_finished_to_stream_response(
|
||||
self,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent
|
||||
) -> ParallelBranchFinishedStreamResponse:
|
||||
"""
|
||||
Workflow parallel branch finished to stream response
|
||||
:param task_id: task id
|
||||
:param workflow_run: workflow run
|
||||
:param event: parallel branch run succeeded or failed event
|
||||
:return:
|
||||
"""
|
||||
return ParallelBranchFinishedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=ParallelBranchFinishedStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed',
|
||||
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
|
||||
created_at=int(time.time()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict)
|
||||
)
|
||||
)
|
||||
|
||||
def _workflow_iteration_start_to_stream_response(
|
||||
self,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
event: QueueIterationStartEvent
|
||||
) -> IterationNodeStartStreamResponse:
|
||||
"""
|
||||
Workflow iteration start to stream response
|
||||
:param task_id: task id
|
||||
:param workflow_run: workflow run
|
||||
:param event: iteration start event
|
||||
:return:
|
||||
"""
|
||||
return IterationNodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=IterationNodeStartStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
metadata=event.metadata or {},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
)
|
||||
def _handle_workflow_start(self) -> WorkflowRun:
|
||||
self._task_state.start_at = time.perf_counter()
|
||||
|
||||
workflow_run = self._init_workflow_run(
|
||||
workflow=self._workflow,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
|
||||
else WorkflowRunTriggeredFrom.APP_RUN,
|
||||
user=self._user,
|
||||
user_inputs=self._application_generate_entity.inputs,
|
||||
system_inputs=self._workflow_system_variables
|
||||
)
|
||||
|
||||
def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse:
|
||||
"""
|
||||
Workflow iteration next to stream response
|
||||
:param task_id: task id
|
||||
:param workflow_run: workflow run
|
||||
:param event: iteration next event
|
||||
:return:
|
||||
"""
|
||||
return IterationNodeNextStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=IterationNodeNextStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.output,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
)
|
||||
self._task_state.workflow_run_id = workflow_run.id
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
|
||||
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
|
||||
workflow_node_execution = self._init_node_execution_from_workflow_run(
|
||||
workflow_run=workflow_run,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_data.title,
|
||||
node_run_index=event.node_run_index,
|
||||
predecessor_node_id=event.predecessor_node_id
|
||||
)
|
||||
|
||||
def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse:
|
||||
"""
|
||||
Workflow iteration completed to stream response
|
||||
:param task_id: task id
|
||||
:param workflow_run: workflow run
|
||||
:param event: iteration completed event
|
||||
:return:
|
||||
"""
|
||||
return IterationNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=IterationNodeCompletedStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
latest_node_execution_info = NodeExecutionInfo(
|
||||
workflow_node_execution_id=workflow_node_execution.id,
|
||||
node_type=event.node_type,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
|
||||
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
|
||||
self._task_state.latest_node_execution_info = latest_node_execution_info
|
||||
|
||||
self._task_state.total_steps += 1
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
|
||||
current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
|
||||
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
|
||||
|
||||
execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None
|
||||
|
||||
if self._iteration_state and self._iteration_state.current_iterations:
|
||||
if not execution_metadata:
|
||||
execution_metadata = {}
|
||||
current_iteration_data = None
|
||||
for iteration_node_id in self._iteration_state.current_iterations:
|
||||
data = self._iteration_state.current_iterations[iteration_node_id]
|
||||
if data.parent_iteration_id == None:
|
||||
current_iteration_data = data
|
||||
break
|
||||
|
||||
if current_iteration_data:
|
||||
execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id
|
||||
execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index
|
||||
|
||||
if isinstance(event, QueueNodeSucceededEvent):
|
||||
workflow_node_execution = self._workflow_node_execution_success(
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
start_at=current_node_execution.start_at,
|
||||
inputs=event.inputs,
|
||||
process_data=event.process_data,
|
||||
outputs=event.outputs,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
error=None,
|
||||
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
|
||||
total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0,
|
||||
execution_metadata=event.metadata,
|
||||
finished_at=int(time.time()),
|
||||
steps=event.steps,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
execution_metadata=execution_metadata
|
||||
)
|
||||
)
|
||||
|
||||
if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
self._task_state.total_tokens += (
|
||||
int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
|
||||
|
||||
if self._iteration_state:
|
||||
for iteration_node_id in self._iteration_state.current_iterations:
|
||||
data = self._iteration_state.current_iterations[iteration_node_id]
|
||||
if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))
|
||||
|
||||
if workflow_node_execution.node_type == NodeType.LLM.value:
|
||||
outputs = workflow_node_execution.outputs_dict
|
||||
usage_dict = outputs.get('usage', {})
|
||||
self._task_state.metadata['usage'] = usage_dict
|
||||
else:
|
||||
workflow_node_execution = self._workflow_node_execution_failed(
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
start_at=current_node_execution.start_at,
|
||||
error=event.error,
|
||||
inputs=event.inputs,
|
||||
process_data=event.process_data,
|
||||
outputs=event.outputs,
|
||||
execution_metadata=execution_metadata
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_finished(
|
||||
self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Optional[WorkflowRun]:
|
||||
workflow_run = db.session.query(WorkflowRun).filter(
|
||||
WorkflowRun.id == self._task_state.workflow_run_id).first()
|
||||
if not workflow_run:
|
||||
return None
|
||||
|
||||
if conversation_id is None:
|
||||
conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id')
|
||||
if isinstance(event, QueueStopEvent):
|
||||
workflow_run = self._workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
total_tokens=self._task_state.total_tokens,
|
||||
total_steps=self._task_state.total_steps,
|
||||
status=WorkflowRunStatus.STOPPED,
|
||||
error='Workflow stopped.',
|
||||
conversation_id=conversation_id,
|
||||
trace_manager=trace_manager
|
||||
)
|
||||
|
||||
latest_node_execution_info = self._task_state.latest_node_execution_info
|
||||
if latest_node_execution_info:
|
||||
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first()
|
||||
if (workflow_node_execution
|
||||
and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value):
|
||||
self._workflow_node_execution_failed(
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
start_at=latest_node_execution_info.start_at,
|
||||
error='Workflow stopped.'
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||
workflow_run = self._workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
total_tokens=self._task_state.total_tokens,
|
||||
total_steps=self._task_state.total_steps,
|
||||
status=WorkflowRunStatus.FAILED,
|
||||
error=event.error,
|
||||
conversation_id=conversation_id,
|
||||
trace_manager=trace_manager
|
||||
)
|
||||
else:
|
||||
if self._task_state.latest_node_execution_info:
|
||||
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first()
|
||||
outputs = workflow_node_execution.outputs
|
||||
else:
|
||||
outputs = None
|
||||
|
||||
workflow_run = self._workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
total_tokens=self._task_state.total_tokens,
|
||||
total_steps=self._task_state.total_steps,
|
||||
outputs=outputs,
|
||||
conversation_id=conversation_id,
|
||||
trace_manager=trace_manager
|
||||
)
|
||||
|
||||
self._task_state.workflow_run_id = workflow_run.id
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
|
||||
"""
|
||||
@ -647,40 +647,3 @@ class WorkflowCycleManage:
|
||||
return value.to_dict()
|
||||
|
||||
return None
|
||||
|
||||
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
|
||||
"""
|
||||
Refetch workflow run
|
||||
:param workflow_run_id: workflow run id
|
||||
:return:
|
||||
"""
|
||||
workflow_run = db.session.query(WorkflowRun).filter(
|
||||
WorkflowRun.id == workflow_run_id).first()
|
||||
|
||||
if not workflow_run:
|
||||
raise Exception(f'Workflow run not found: {workflow_run_id}')
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Refetch workflow node execution
|
||||
:param node_execution_id: workflow node execution id
|
||||
:return:
|
||||
"""
|
||||
workflow_node_execution = (
|
||||
db.session.query(WorkflowNodeExecution)
|
||||
.filter(
|
||||
WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id,
|
||||
WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id,
|
||||
WorkflowNodeExecution.workflow_id == self._workflow.id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.node_execution_id == node_execution_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not workflow_node_execution:
|
||||
raise Exception(f'Workflow node execution not found: {node_execution_id}')
|
||||
|
||||
return workflow_node_execution
|
||||
@ -0,0 +1,16 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowCycleStateManager:
|
||||
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
|
||||
290
api/core/app/task_pipeline/workflow_iteration_cycle_manage.py
Normal file
290
api/core/app/task_pipeline/workflow_iteration_cycle_manage.py
Normal file
@ -0,0 +1,290 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
IterationNodeCompletedStreamResponse,
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
NodeExecutionInfo,
|
||||
WorkflowIterationState,
|
||||
)
|
||||
from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowRun,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowIterationCycleManage(WorkflowCycleStateManager):
|
||||
_iteration_state: WorkflowIterationState = None
|
||||
|
||||
def _init_iteration_state(self) -> WorkflowIterationState:
|
||||
if not self._iteration_state:
|
||||
self._iteration_state = WorkflowIterationState(
|
||||
current_iterations={}
|
||||
)
|
||||
|
||||
def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \
|
||||
-> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]:
|
||||
"""
|
||||
Handle iteration to stream response
|
||||
:param task_id: task id
|
||||
:param event: iteration event
|
||||
:return:
|
||||
"""
|
||||
if isinstance(event, QueueIterationStartEvent):
|
||||
return IterationNodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
data=IterationNodeStartStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata
|
||||
)
|
||||
)
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
current_iteration = self._iteration_state.current_iterations[event.node_id]
|
||||
|
||||
return IterationNodeNextStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
data=IterationNodeNextStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=current_iteration.node_data.title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.output,
|
||||
created_at=int(time.time()),
|
||||
extras={}
|
||||
)
|
||||
)
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
current_iteration = self._iteration_state.current_iterations[event.node_id]
|
||||
|
||||
return IterationNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
data=IterationNodeCompletedStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=current_iteration.node_data.title,
|
||||
outputs=event.outputs,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=current_iteration.inputs,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
error=None,
|
||||
elapsed_time=time.perf_counter() - current_iteration.started_at,
|
||||
total_tokens=current_iteration.total_tokens,
|
||||
execution_metadata={
|
||||
'total_tokens': current_iteration.total_tokens,
|
||||
},
|
||||
finished_at=int(time.time()),
|
||||
steps=current_iteration.current_index
|
||||
)
|
||||
)
|
||||
|
||||
def _init_iteration_execution_from_workflow_run(self,
|
||||
workflow_run: WorkflowRun,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_title: str,
|
||||
node_run_index: int = 1,
|
||||
inputs: Optional[dict] = None,
|
||||
predecessor_node_id: Optional[str] = None
|
||||
) -> WorkflowNodeExecution:
|
||||
workflow_node_execution = WorkflowNodeExecution(
|
||||
tenant_id=workflow_run.tenant_id,
|
||||
app_id=workflow_run.app_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
workflow_run_id=workflow_run.id,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
index=node_run_index,
|
||||
node_id=node_id,
|
||||
node_type=node_type.value,
|
||||
inputs=json.dumps(inputs) if inputs else None,
|
||||
title=node_title,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
created_by_role=workflow_run.created_by_role,
|
||||
created_by=workflow_run.created_by,
|
||||
execution_metadata=json.dumps({
|
||||
'started_run_index': node_run_index + 1,
|
||||
'current_index': 0,
|
||||
'steps_boundary': [],
|
||||
}),
|
||||
created_at=datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
)
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution:
|
||||
if isinstance(event, QueueIterationStartEvent):
|
||||
return self._handle_iteration_started(event)
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
return self._handle_iteration_next(event)
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
return self._handle_iteration_completed(event)
|
||||
|
||||
def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution:
|
||||
self._init_iteration_state()
|
||||
|
||||
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
|
||||
workflow_node_execution = self._init_iteration_execution_from_workflow_run(
|
||||
workflow_run=workflow_run,
|
||||
node_id=event.node_id,
|
||||
node_type=NodeType.ITERATION,
|
||||
node_title=event.node_data.title,
|
||||
node_run_index=event.node_run_index,
|
||||
inputs=event.inputs,
|
||||
predecessor_node_id=event.predecessor_node_id
|
||||
)
|
||||
|
||||
latest_node_execution_info = NodeExecutionInfo(
|
||||
workflow_node_execution_id=workflow_node_execution.id,
|
||||
node_type=NodeType.ITERATION,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
|
||||
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
|
||||
self._task_state.latest_node_execution_info = latest_node_execution_info
|
||||
|
||||
self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data(
|
||||
parent_iteration_id=None,
|
||||
iteration_id=event.node_id,
|
||||
current_index=0,
|
||||
iteration_steps_boundary=[],
|
||||
node_execution_id=workflow_node_execution.id,
|
||||
started_at=time.perf_counter(),
|
||||
inputs=event.inputs,
|
||||
total_tokens=0,
|
||||
node_data=event.node_data
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution:
|
||||
if event.node_id not in self._iteration_state.current_iterations:
|
||||
return
|
||||
current_iteration = self._iteration_state.current_iterations[event.node_id]
|
||||
current_iteration.current_index = event.index
|
||||
current_iteration.iteration_steps_boundary.append(event.node_run_index)
|
||||
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == current_iteration.node_execution_id
|
||||
).first()
|
||||
|
||||
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
|
||||
if original_node_execution_metadata:
|
||||
original_node_execution_metadata['current_index'] = event.index
|
||||
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
|
||||
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
|
||||
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
db.session.close()
|
||||
|
||||
def _handle_iteration_completed(self, event: QueueIterationCompletedEvent):
|
||||
if event.node_id not in self._iteration_state.current_iterations:
|
||||
return
|
||||
|
||||
current_iteration = self._iteration_state.current_iterations[event.node_id]
|
||||
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == current_iteration.node_execution_id
|
||||
).first()
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
workflow_node_execution.outputs = json.dumps(WorkflowEngineManager.handle_special_values(event.outputs)) if event.outputs else None
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
|
||||
|
||||
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
|
||||
if original_node_execution_metadata:
|
||||
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
|
||||
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
|
||||
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# remove current iteration
|
||||
self._iteration_state.current_iterations.pop(event.node_id, None)
|
||||
|
||||
# set latest node execution info
|
||||
latest_node_execution_info = NodeExecutionInfo(
|
||||
workflow_node_execution_id=workflow_node_execution.id,
|
||||
node_type=NodeType.ITERATION,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
|
||||
self._task_state.latest_node_execution_info = latest_node_execution_info
|
||||
|
||||
db.session.close()
|
||||
|
||||
def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]:
|
||||
"""
|
||||
Handle iteration exception
|
||||
"""
|
||||
if not self._iteration_state or not self._iteration_state.current_iterations:
|
||||
return
|
||||
|
||||
for node_id, current_iteration in self._iteration_state.current_iterations.items():
|
||||
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == current_iteration.node_execution_id
|
||||
).first()
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
yield IterationNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
data=IterationNodeCompletedStreamResponse.Data(
|
||||
id=node_id,
|
||||
node_id=node_id,
|
||||
node_type=NodeType.ITERATION.value,
|
||||
title=current_iteration.node_data.title,
|
||||
outputs={},
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=current_iteration.inputs,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error,
|
||||
elapsed_time=time.perf_counter() - current_iteration.started_at,
|
||||
total_tokens=current_iteration.total_tokens,
|
||||
execution_metadata={
|
||||
'total_tokens': current_iteration.total_tokens,
|
||||
},
|
||||
finished_at=int(time.time()),
|
||||
steps=current_iteration.current_index
|
||||
)
|
||||
)
|
||||
@ -65,7 +65,7 @@ class Extensible:
|
||||
if os.path.exists(builtin_file_path):
|
||||
with open(builtin_file_path, encoding='utf-8') as f:
|
||||
position = int(f.read().strip())
|
||||
position_map[extension_name] = position
|
||||
position_map[extension_name] = position
|
||||
|
||||
if (extension_name + '.py') not in file_names:
|
||||
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
|
||||
|
||||
@ -16,7 +16,9 @@ from configs import dify_config
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.model_entities import ModelType, PriceType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
@ -253,8 +255,11 @@ class IndexingRunner:
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
)
|
||||
tokens = 0
|
||||
preview_texts = []
|
||||
total_segments = 0
|
||||
total_price = 0
|
||||
currency = 'USD'
|
||||
index_type = doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
all_text_docs = []
|
||||
@ -281,22 +286,54 @@ class IndexingRunner:
|
||||
for document in documents:
|
||||
if len(preview_texts) < 5:
|
||||
preview_texts.append(document.page_content)
|
||||
if indexing_technique == 'high_quality' or embedding_model_instance:
|
||||
tokens += embedding_model_instance.get_text_embedding_num_tokens(
|
||||
texts=[self.filter_string(document.page_content)]
|
||||
)
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
|
||||
doc_language)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
|
||||
price_info = model_type_instance.get_price(
|
||||
model=model_instance.model,
|
||||
credentials=model_instance.credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=total_segments * 2000,
|
||||
)
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(price_info.total_amount),
|
||||
"currency": price_info.currency,
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
}
|
||||
if embedding_model_instance:
|
||||
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
|
||||
embedding_price_info = embedding_model_type_instance.get_price(
|
||||
model=embedding_model_instance.model,
|
||||
credentials=embedding_model_instance.credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
)
|
||||
total_price = '{:f}'.format(embedding_price_info.total_amount)
|
||||
currency = embedding_price_info.currency
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": total_price,
|
||||
"currency": currency,
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
|
||||
@ -63,39 +63,6 @@ class LLMUsage(ModelUsage):
|
||||
latency=0.0
|
||||
)
|
||||
|
||||
def plus(self, other: 'LLMUsage') -> 'LLMUsage':
|
||||
"""
|
||||
Add two LLMUsage instances together.
|
||||
|
||||
:param other: Another LLMUsage instance to add
|
||||
:return: A new LLMUsage instance with summed values
|
||||
"""
|
||||
if self.total_tokens == 0:
|
||||
return other
|
||||
else:
|
||||
return LLMUsage(
|
||||
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
|
||||
prompt_unit_price=other.prompt_unit_price,
|
||||
prompt_price_unit=other.prompt_price_unit,
|
||||
prompt_price=self.prompt_price + other.prompt_price,
|
||||
completion_tokens=self.completion_tokens + other.completion_tokens,
|
||||
completion_unit_price=other.completion_unit_price,
|
||||
completion_price_unit=other.completion_price_unit,
|
||||
completion_price=self.completion_price + other.completion_price,
|
||||
total_tokens=self.total_tokens + other.total_tokens,
|
||||
total_price=self.total_price + other.total_price,
|
||||
currency=other.currency,
|
||||
latency=self.latency + other.latency
|
||||
)
|
||||
|
||||
def __add__(self, other: 'LLMUsage') -> 'LLMUsage':
|
||||
"""
|
||||
Overload the + operator to add two LLMUsage instances.
|
||||
|
||||
:param other: Another LLMUsage instance to add
|
||||
:return: A new LLMUsage instance with summed values
|
||||
"""
|
||||
return self.plus(other)
|
||||
|
||||
class LLMResult(BaseModel):
|
||||
"""
|
||||
|
||||
@ -150,9 +150,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
except json.JSONDecodeError as e:
|
||||
raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error')
|
||||
|
||||
if (completion_type is LLMMode.CHAT and json_result.get('object','') == ''):
|
||||
if (completion_type is LLMMode.CHAT and json_result['object'] == ''):
|
||||
json_result['object'] = 'chat.completion'
|
||||
elif (completion_type is LLMMode.COMPLETION and json_result.get('object','') == ''):
|
||||
elif (completion_type is LLMMode.COMPLETION and json_result['object'] == ''):
|
||||
json_result['object'] = 'text_completion'
|
||||
|
||||
if (completion_type is LLMMode.CHAT
|
||||
|
||||
@ -71,24 +71,11 @@ class ArkClientV3:
|
||||
args = {
|
||||
"base_url": credentials['api_endpoint_host'],
|
||||
"region": credentials['volc_region'],
|
||||
"ak": credentials['volc_access_key_id'],
|
||||
"sk": credentials['volc_secret_access_key'],
|
||||
}
|
||||
if credentials.get("auth_method") == "api_key":
|
||||
args = {
|
||||
**args,
|
||||
"api_key": credentials['volc_api_key'],
|
||||
}
|
||||
else:
|
||||
args = {
|
||||
**args,
|
||||
"ak": credentials['volc_access_key_id'],
|
||||
"sk": credentials['volc_secret_access_key'],
|
||||
}
|
||||
|
||||
if cls.is_compatible_with_legacy(credentials):
|
||||
args = {
|
||||
**args,
|
||||
"base_url": DEFAULT_V3_ENDPOINT
|
||||
}
|
||||
args["base_url"] = DEFAULT_V3_ENDPOINT
|
||||
|
||||
client = ArkClientV3(
|
||||
**args
|
||||
|
||||
@ -30,28 +30,8 @@ model_credential_schema:
|
||||
en_US: Enter your Model Name
|
||||
zh_Hans: 输入模型名称
|
||||
credential_form_schemas:
|
||||
- variable: auth_method
|
||||
required: true
|
||||
label:
|
||||
en_US: Authentication Method
|
||||
zh_Hans: 鉴权方式
|
||||
type: select
|
||||
default: aksk
|
||||
options:
|
||||
- label:
|
||||
en_US: API Key
|
||||
value: api_key
|
||||
- label:
|
||||
en_US: Access Key / Secret Access Key
|
||||
value: aksk
|
||||
placeholder:
|
||||
en_US: Enter your Authentication Method
|
||||
zh_Hans: 选择鉴权方式
|
||||
- variable: volc_access_key_id
|
||||
required: true
|
||||
show_on:
|
||||
- variable: auth_method
|
||||
value: aksk
|
||||
label:
|
||||
en_US: Access Key
|
||||
zh_Hans: Access Key
|
||||
@ -61,9 +41,6 @@ model_credential_schema:
|
||||
zh_Hans: 输入您的 Access Key
|
||||
- variable: volc_secret_access_key
|
||||
required: true
|
||||
show_on:
|
||||
- variable: auth_method
|
||||
value: aksk
|
||||
label:
|
||||
en_US: Secret Access Key
|
||||
zh_Hans: Secret Access Key
|
||||
@ -71,17 +48,6 @@ model_credential_schema:
|
||||
placeholder:
|
||||
en_US: Enter your Secret Access Key
|
||||
zh_Hans: 输入您的 Secret Access Key
|
||||
- variable: volc_api_key
|
||||
required: true
|
||||
show_on:
|
||||
- variable: auth_method
|
||||
value: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
placeholder:
|
||||
en_US: Enter your API Key
|
||||
zh_Hans: 输入您的 API Key
|
||||
- variable: volc_region
|
||||
required: true
|
||||
label:
|
||||
|
||||
@ -174,11 +174,6 @@ class XinferenceText2SpeechModel(TTSModel):
|
||||
return voices[language]
|
||||
elif 'all' in voices:
|
||||
return voices['all']
|
||||
else:
|
||||
all_voices = []
|
||||
for lang, lang_voices in voices.items():
|
||||
all_voices.extend(lang_voices)
|
||||
return all_voices
|
||||
|
||||
return self.model_voices['__default']['all']
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ parameter_rules:
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0'
|
||||
output: '0'
|
||||
input: '0.0001'
|
||||
output: '0.0001'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
|
||||
@ -37,8 +37,3 @@ parameter_rules:
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.001'
|
||||
output: '0.001'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
|
||||
@ -37,8 +37,3 @@ parameter_rules:
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.1'
|
||||
output: '0.1'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
|
||||
@ -30,9 +30,4 @@ parameter_rules:
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.001'
|
||||
output: '0.001'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
max: 4096
|
||||
|
||||
@ -1,44 +0,0 @@
|
||||
model: glm-4-plus
|
||||
label:
|
||||
en_US: glm-4-plus
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.95
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
|
||||
en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 0.7
|
||||
help:
|
||||
zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
|
||||
en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
|
||||
- name: incremental
|
||||
label:
|
||||
zh_Hans: 增量返回
|
||||
en_US: Incremental
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。
|
||||
en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.05'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@ -34,9 +34,4 @@ parameter_rules:
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 1024
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.05'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
max: 8192
|
||||
|
||||
@ -1,42 +0,0 @@
|
||||
model: glm-4v-plus
|
||||
label:
|
||||
en_US: glm-4v-plus
|
||||
model_type: llm
|
||||
model_properties:
|
||||
mode: chat
|
||||
features:
|
||||
- vision
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.95
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
|
||||
en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 0.7
|
||||
help:
|
||||
zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
|
||||
en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
|
||||
- name: incremental
|
||||
label:
|
||||
zh_Hans: 增量返回
|
||||
en_US: Incremental
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。
|
||||
en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 1024
|
||||
pricing:
|
||||
input: '0.01'
|
||||
output: '0.01'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@ -153,8 +153,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
extra_model_kwargs = {}
|
||||
# request to glm-4v-plus with stop words will always response "finish_reason":"network_error"
|
||||
if stop and model!= 'glm-4v-plus':
|
||||
if stop:
|
||||
extra_model_kwargs['stop'] = stop
|
||||
|
||||
client = ZhipuAI(
|
||||
@ -175,7 +174,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]:
|
||||
if isinstance(copy_prompt_message.content, list):
|
||||
# check if model is 'glm-4v'
|
||||
if model not in ('glm-4v', 'glm-4v-plus'):
|
||||
if model != 'glm-4v':
|
||||
# not support list message
|
||||
continue
|
||||
# get image and
|
||||
@ -208,7 +207,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
else:
|
||||
new_prompt_messages.append(copy_prompt_message)
|
||||
|
||||
if model == 'glm-4v' or model == 'glm-4v-plus':
|
||||
if model == 'glm-4v':
|
||||
params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
|
||||
else:
|
||||
params = {
|
||||
@ -305,7 +304,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
|
||||
return params
|
||||
|
||||
def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]:
|
||||
def _construct_glm_4v_messages(self, prompt_message: Union[str | list[PromptMessageContent]]) -> list[dict]:
|
||||
if isinstance(prompt_message, str):
|
||||
return [{'type': 'text', 'text': prompt_message}]
|
||||
|
||||
|
||||
@ -34,13 +34,13 @@ class OutputModeration(BaseModel):
|
||||
final_output: Optional[str] = None
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def should_direct_output(self) -> bool:
|
||||
def should_direct_output(self):
|
||||
return self.final_output is not None
|
||||
|
||||
def get_final_output(self) -> str:
|
||||
return self.final_output or ""
|
||||
def get_final_output(self):
|
||||
return self.final_output
|
||||
|
||||
def append_new_token(self, token: str) -> None:
|
||||
def append_new_token(self, token: str):
|
||||
self.buffer += token
|
||||
|
||||
if not self.thread:
|
||||
|
||||
@ -204,7 +204,6 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
node_generation_data = LangfuseGeneration(
|
||||
name="llm",
|
||||
trace_id=trace_id,
|
||||
model=process_data.get("model_name"),
|
||||
parent_observation_id=node_execution_id,
|
||||
start_time=created_at,
|
||||
end_time=finished_at,
|
||||
|
||||
@ -139,7 +139,8 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
|
||||
)
|
||||
node_total_tokens = execution_metadata.get("total_tokens", 0)
|
||||
metadata = execution_metadata.copy()
|
||||
|
||||
metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
|
||||
metadata.update(
|
||||
{
|
||||
"workflow_run_id": trace_info.workflow_run_id,
|
||||
@ -155,12 +156,6 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
run_type = LangSmithRunType.llm
|
||||
metadata.update(
|
||||
{
|
||||
'ls_provider': process_data.get('model_provider', ''),
|
||||
'ls_model_name': process_data.get('model_name', ''),
|
||||
}
|
||||
)
|
||||
elif node_type == "knowledge-retrieval":
|
||||
run_type = LangSmithRunType.retriever
|
||||
else:
|
||||
|
||||
@ -146,7 +146,7 @@ class RetrievalService:
|
||||
)
|
||||
|
||||
if documents:
|
||||
if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value:
|
||||
if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value:
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
|
||||
RerankMode.RERANKING_MODEL.value,
|
||||
reranking_model, None, False)
|
||||
@ -180,7 +180,7 @@ class RetrievalService:
|
||||
top_k=top_k
|
||||
)
|
||||
if documents:
|
||||
if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value:
|
||||
if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value:
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
|
||||
RerankMode.RERANKING_MODEL.value,
|
||||
reranking_model, None, False)
|
||||
|
||||
@ -281,25 +281,20 @@ class NotionExtractor(BaseExtractor):
|
||||
for table_header_cell_text in tabel_header_cell:
|
||||
text = table_header_cell_text["text"]["content"]
|
||||
table_header_cell_texts.append(text)
|
||||
else:
|
||||
table_header_cell_texts.append('')
|
||||
# Initialize Markdown table with headers
|
||||
markdown_table = "| " + " | ".join(table_header_cell_texts) + " |\n"
|
||||
markdown_table += "| " + " | ".join(['---'] * len(table_header_cell_texts)) + " |\n"
|
||||
|
||||
# Process data to format each row in Markdown table format
|
||||
# get table columns text and format
|
||||
results = data["results"]
|
||||
for i in range(len(results) - 1):
|
||||
column_texts = []
|
||||
table_column_cells = data["results"][i + 1]['table_row']['cells']
|
||||
for j in range(len(table_column_cells)):
|
||||
if table_column_cells[j]:
|
||||
for table_column_cell_text in table_column_cells[j]:
|
||||
tabel_column_cells = data["results"][i + 1]['table_row']['cells']
|
||||
for j in range(len(tabel_column_cells)):
|
||||
if tabel_column_cells[j]:
|
||||
for table_column_cell_text in tabel_column_cells[j]:
|
||||
column_text = table_column_cell_text["text"]["content"]
|
||||
column_texts.append(column_text)
|
||||
# Add row to Markdown table
|
||||
markdown_table += "| " + " | ".join(column_texts) + " |\n"
|
||||
result_lines_arr.append(markdown_table)
|
||||
column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
|
||||
|
||||
cur_result_text = "\n".join(column_texts)
|
||||
result_lines_arr.append(cur_result_text)
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
done = True
|
||||
break
|
||||
|
||||
@ -170,8 +170,6 @@ class WordExtractor(BaseExtractor):
|
||||
if run.element.xpath('.//a:blip'):
|
||||
for blip in run.element.xpath('.//a:blip'):
|
||||
image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
|
||||
if not image_id:
|
||||
continue
|
||||
image_part = paragraph.part.rels[image_id].target_part
|
||||
|
||||
if image_part in image_map:
|
||||
@ -258,6 +256,6 @@ class WordExtractor(BaseExtractor):
|
||||
content.append(parsed_paragraph)
|
||||
elif isinstance(element.tag, str) and element.tag.endswith('tbl'): # table
|
||||
table = tables.pop(0)
|
||||
content.append(self._table_to_markdown(table, image_map))
|
||||
content.append(self._table_to_markdown(table,image_map))
|
||||
return '\n'.join(content)
|
||||
|
||||
|
||||
@ -30,14 +30,15 @@ def _split_text_with_regex(
|
||||
if keep_separator:
|
||||
# The parentheses in the pattern keep the delimiters in the result.
|
||||
_splits = re.split(f"({re.escape(separator)})", text)
|
||||
splits = [_splits[i - 1] + _splits[i] for i in range(1, len(_splits), 2)]
|
||||
if len(_splits) % 2 != 0:
|
||||
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
|
||||
if len(_splits) % 2 == 0:
|
||||
splits += _splits[-1:]
|
||||
splits = [_splits[0]] + splits
|
||||
else:
|
||||
splits = re.split(separator, text)
|
||||
else:
|
||||
splits = list(text)
|
||||
return [s for s in splits if (s != "" and s != '\n')]
|
||||
return [s for s in splits if s != ""]
|
||||
|
||||
|
||||
class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
@ -108,7 +109,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
else:
|
||||
return text
|
||||
|
||||
def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
|
||||
def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
|
||||
# We now want to combine these smaller pieces into medium size
|
||||
# chunks to send to the LLM.
|
||||
separator_len = self._length_function(separator)
|
||||
@ -116,9 +117,8 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
docs = []
|
||||
current_doc: list[str] = []
|
||||
total = 0
|
||||
index = 0
|
||||
for d in splits:
|
||||
_len = lengths[index]
|
||||
_len = self._length_function(d)
|
||||
if (
|
||||
total + _len + (separator_len if len(current_doc) > 0 else 0)
|
||||
> self._chunk_size
|
||||
@ -146,7 +146,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
current_doc = current_doc[1:]
|
||||
current_doc.append(d)
|
||||
total += _len + (separator_len if len(current_doc) > 1 else 0)
|
||||
index += 1
|
||||
doc = self._join_docs(current_doc, separator)
|
||||
if doc is not None:
|
||||
docs.append(doc)
|
||||
@ -495,10 +494,11 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
self._separators = separators or ["\n\n", "\n", " ", ""]
|
||||
|
||||
def _split_text(self, text: str, separators: list[str]) -> list[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
final_chunks = []
|
||||
# Get appropriate separator to use
|
||||
separator = separators[-1]
|
||||
new_separators = []
|
||||
|
||||
for i, _s in enumerate(separators):
|
||||
if _s == "":
|
||||
separator = _s
|
||||
@ -509,31 +509,25 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
break
|
||||
|
||||
splits = _split_text_with_regex(text, separator, self._keep_separator)
|
||||
# Now go merging things, recursively splitting longer texts.
|
||||
_good_splits = []
|
||||
_good_splits_lengths = [] # cache the lengths of the splits
|
||||
_separator = "" if self._keep_separator else separator
|
||||
|
||||
for s in splits:
|
||||
s_len = self._length_function(s)
|
||||
if s_len < self._chunk_size:
|
||||
if self._length_function(s) < self._chunk_size:
|
||||
_good_splits.append(s)
|
||||
_good_splits_lengths.append(s_len)
|
||||
else:
|
||||
if _good_splits:
|
||||
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
|
||||
merged_text = self._merge_splits(_good_splits, _separator)
|
||||
final_chunks.extend(merged_text)
|
||||
_good_splits = []
|
||||
_good_splits_lengths = []
|
||||
if not new_separators:
|
||||
final_chunks.append(s)
|
||||
else:
|
||||
other_info = self._split_text(s, new_separators)
|
||||
final_chunks.extend(other_info)
|
||||
|
||||
if _good_splits:
|
||||
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
|
||||
merged_text = self._merge_splits(_good_splits, _separator)
|
||||
final_chunks.extend(merged_text)
|
||||
|
||||
return final_chunks
|
||||
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
- google
|
||||
- bing
|
||||
- perplexity
|
||||
- duckduckgo
|
||||
- searchapi
|
||||
- serper
|
||||
@ -11,7 +10,6 @@
|
||||
- wikipedia
|
||||
- nominatim
|
||||
- yahoo
|
||||
- alphavantage
|
||||
- arxiv
|
||||
- pubmed
|
||||
- stablediffusion
|
||||
|
||||
@ -1,7 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg width="56px" height="56px" viewBox="0 0 56 56" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<title>形状结合</title>
|
||||
<g id="设计规范" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||
<path d="M56,0 L56,56 L0,56 L0,0 L56,0 Z M31.6063018,12 L24.3936982,12 L24.1061064,12.7425499 L12.6071308,42.4324141 L12,44 L19.7849972,44 L20.0648488,43.2391815 L22.5196173,36.5567427 L33.4780427,36.5567427 L35.9351512,43.2391815 L36.2150028,44 L44,44 L43.3928692,42.4324141 L31.8938936,12.7425499 L31.6063018,12 Z M28.0163803,21.5755126 L31.1613993,30.2523823 L24.8432808,30.2523823 L28.0163803,21.5755126 Z" id="形状结合" fill="#2F4F4F"></path>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 780 B |
@ -1,22 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.alphavantage.tools.query_stock import QueryStockTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AlphaVantageProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
QueryStockTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
"code": "AAPL", # Apple Inc.
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,31 +0,0 @@
|
||||
identity:
|
||||
author: zhuhao
|
||||
name: alphavantage
|
||||
label:
|
||||
en_US: AlphaVantage
|
||||
zh_Hans: AlphaVantage
|
||||
pt_BR: AlphaVantage
|
||||
description:
|
||||
en_US: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis.
|
||||
zh_Hans: AlphaVantage是一个在线平台,它提供金融市场数据和API,便于个人投资者和开发者获取股票报价、技术指标和股票分析。
|
||||
pt_BR: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis.
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- finance
|
||||
credentials_for_provider:
|
||||
api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: AlphaVantage API key
|
||||
zh_Hans: AlphaVantage API key
|
||||
pt_BR: AlphaVantage API key
|
||||
placeholder:
|
||||
en_US: Please input your AlphaVantage API key
|
||||
zh_Hans: 请输入你的 AlphaVantage API key
|
||||
pt_BR: Please input your AlphaVantage API key
|
||||
help:
|
||||
en_US: Get your AlphaVantage API key from AlphaVantage
|
||||
zh_Hans: 从 AlphaVantage 获取您的 AlphaVantage API key
|
||||
pt_BR: Get your AlphaVantage API key from AlphaVantage
|
||||
url: https://www.alphavantage.co/support/#api-key
|
||||
@ -1,49 +0,0 @@
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
ALPHAVANTAGE_API_URL = "https://www.alphavantage.co/query"
|
||||
|
||||
|
||||
class QueryStockTool(BuiltinTool):
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
stock_code = tool_parameters.get('code', '')
|
||||
if not stock_code:
|
||||
return self.create_text_message('Please tell me your stock code')
|
||||
|
||||
if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
|
||||
return self.create_text_message("Alpha Vantage API key is required.")
|
||||
|
||||
params = {
|
||||
"function": "TIME_SERIES_DAILY",
|
||||
"symbol": stock_code,
|
||||
"outputsize": "compact",
|
||||
"datatype": "json",
|
||||
"apikey": self.runtime.credentials['api_key']
|
||||
}
|
||||
response = requests.get(url=ALPHAVANTAGE_API_URL, params=params)
|
||||
response.raise_for_status()
|
||||
result = self._handle_response(response.json())
|
||||
return self.create_json_message(result)
|
||||
|
||||
def _handle_response(self, response: dict[str, Any]) -> dict[str, Any]:
|
||||
result = response.get('Time Series (Daily)', {})
|
||||
if not result:
|
||||
return {}
|
||||
stock_result = {}
|
||||
for k, v in result.items():
|
||||
stock_result[k] = {}
|
||||
stock_result[k]['open'] = v.get('1. open')
|
||||
stock_result[k]['high'] = v.get('2. high')
|
||||
stock_result[k]['low'] = v.get('3. low')
|
||||
stock_result[k]['close'] = v.get('4. close')
|
||||
stock_result[k]['volume'] = v.get('5. volume')
|
||||
return stock_result
|
||||
@ -1,27 +0,0 @@
|
||||
identity:
|
||||
name: query_stock
|
||||
author: zhuhao
|
||||
label:
|
||||
en_US: query_stock
|
||||
zh_Hans: query_stock
|
||||
pt_BR: query_stock
|
||||
description:
|
||||
human:
|
||||
en_US: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol.
|
||||
zh_Hans: 获取指定股票代码的每日开盘价、每日最高价、每日最低价、每日收盘价和每日交易量等信息。
|
||||
pt_BR: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol
|
||||
llm: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol
|
||||
parameters:
|
||||
- name: code
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: stock code
|
||||
zh_Hans: 股票代码
|
||||
pt_BR: stock code
|
||||
human_description:
|
||||
en_US: stock code
|
||||
zh_Hans: 股票代码
|
||||
pt_BR: stock code
|
||||
llm_description: stock code for query from alphavantage
|
||||
form: llm
|
||||
@ -1,3 +0,0 @@
|
||||
<svg width="400" height="400" viewBox="0 0 400 400" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M101.008 42L190.99 124.905L190.99 124.886L190.99 42.1913H208.506L208.506 125.276L298.891 42V136.524L336 136.524V272.866H299.005V357.035L208.506 277.525L208.506 357.948H190.99L190.99 278.836L101.11 358V272.866H64V136.524H101.008V42ZM177.785 153.826H81.5159V255.564H101.088V223.472L177.785 153.826ZM118.625 231.149V319.392L190.99 255.655L190.99 165.421L118.625 231.149ZM209.01 254.812V165.336L281.396 231.068V272.866H281.489V318.491L209.01 254.812ZM299.005 255.564H318.484V153.826L222.932 153.826L299.005 222.751V255.564ZM281.375 136.524V81.7983L221.977 136.524L281.375 136.524ZM177.921 136.524H118.524V81.7983L177.921 136.524Z" fill="black"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 798 B |
@ -1,46 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.perplexity.tools.perplexity_search import PERPLEXITY_API_URL
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class PerplexityProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {credentials.get('perplexity_api_key')}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": "llama-3.1-sonar-small-128k-online",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
],
|
||||
"max_tokens": 5,
|
||||
"temperature": 0.1,
|
||||
"top_p": 0.9,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(PERPLEXITY_API_URL, json=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
except requests.RequestException as e:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"Failed to validate Perplexity API key: {str(e)}"
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"Perplexity API key is invalid. Status code: {response.status_code}"
|
||||
)
|
||||
@ -1,26 +0,0 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: perplexity
|
||||
label:
|
||||
en_US: Perplexity
|
||||
zh_Hans: Perplexity
|
||||
description:
|
||||
en_US: Perplexity.AI
|
||||
zh_Hans: Perplexity.AI
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- search
|
||||
credentials_for_provider:
|
||||
perplexity_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Perplexity API key
|
||||
zh_Hans: Perplexity API key
|
||||
placeholder:
|
||||
en_US: Please input your Perplexity API key
|
||||
zh_Hans: 请输入你的 Perplexity API key
|
||||
help:
|
||||
en_US: Get your Perplexity API key from Perplexity
|
||||
zh_Hans: 从 Perplexity 获取您的 Perplexity API key
|
||||
url: https://www.perplexity.ai/settings/api
|
||||
@ -1,72 +0,0 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
PERPLEXITY_API_URL = "https://api.perplexity.ai/chat/completions"
|
||||
|
||||
class PerplexityAITool(BuiltinTool):
|
||||
def _parse_response(self, response: dict) -> dict:
|
||||
"""Parse the response from Perplexity AI API"""
|
||||
if 'choices' in response and len(response['choices']) > 0:
|
||||
message = response['choices'][0]['message']
|
||||
return {
|
||||
'content': message.get('content', ''),
|
||||
'role': message.get('role', ''),
|
||||
'citations': response.get('citations', [])
|
||||
}
|
||||
else:
|
||||
return {'content': 'Unable to get a valid response', 'role': 'assistant', 'citations': []}
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.runtime.credentials['perplexity_api_key']}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": tool_parameters.get('model', 'llama-3.1-sonar-small-128k-online'),
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Be precise and concise."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": tool_parameters['query']
|
||||
}
|
||||
],
|
||||
"max_tokens": tool_parameters.get('max_tokens', 4096),
|
||||
"temperature": tool_parameters.get('temperature', 0.7),
|
||||
"top_p": tool_parameters.get('top_p', 1),
|
||||
"top_k": tool_parameters.get('top_k', 5),
|
||||
"presence_penalty": tool_parameters.get('presence_penalty', 0),
|
||||
"frequency_penalty": tool_parameters.get('frequency_penalty', 1),
|
||||
"stream": False
|
||||
}
|
||||
|
||||
if 'search_recency_filter' in tool_parameters:
|
||||
payload['search_recency_filter'] = tool_parameters['search_recency_filter']
|
||||
if 'return_citations' in tool_parameters:
|
||||
payload['return_citations'] = tool_parameters['return_citations']
|
||||
if 'search_domain_filter' in tool_parameters:
|
||||
if isinstance(tool_parameters['search_domain_filter'], str):
|
||||
payload['search_domain_filter'] = [tool_parameters['search_domain_filter']]
|
||||
elif isinstance(tool_parameters['search_domain_filter'], list):
|
||||
payload['search_domain_filter'] = tool_parameters['search_domain_filter']
|
||||
|
||||
|
||||
response = requests.post(url=PERPLEXITY_API_URL, json=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
valuable_res = self._parse_response(response.json())
|
||||
|
||||
return [
|
||||
self.create_json_message(valuable_res),
|
||||
self.create_text_message(json.dumps(valuable_res, ensure_ascii=False, indent=2))
|
||||
]
|
||||
@ -1,178 +0,0 @@
|
||||
identity:
|
||||
name: perplexity
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Perplexity Search
|
||||
description:
|
||||
human:
|
||||
en_US: Search information using Perplexity AI's language models.
|
||||
llm: This tool is used to search information using Perplexity AI's language models.
|
||||
parameters:
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query
|
||||
zh_Hans: 查询
|
||||
human_description:
|
||||
en_US: The text query to be processed by the AI model.
|
||||
zh_Hans: 要由 AI 模型处理的文本查询。
|
||||
form: llm
|
||||
- name: model
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
human_description:
|
||||
en_US: The Perplexity AI model to use for generating the response.
|
||||
zh_Hans: 用于生成响应的 Perplexity AI 模型。
|
||||
form: form
|
||||
default: "llama-3.1-sonar-small-128k-online"
|
||||
options:
|
||||
- value: llama-3.1-sonar-small-128k-online
|
||||
label:
|
||||
en_US: llama-3.1-sonar-small-128k-online
|
||||
zh_Hans: llama-3.1-sonar-small-128k-online
|
||||
- value: llama-3.1-sonar-large-128k-online
|
||||
label:
|
||||
en_US: llama-3.1-sonar-large-128k-online
|
||||
zh_Hans: llama-3.1-sonar-large-128k-online
|
||||
- value: llama-3.1-sonar-huge-128k-online
|
||||
label:
|
||||
en_US: llama-3.1-sonar-huge-128k-online
|
||||
zh_Hans: llama-3.1-sonar-huge-128k-online
|
||||
- name: max_tokens
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Max Tokens
|
||||
zh_Hans: 最大令牌数
|
||||
pt_BR: Máximo de Tokens
|
||||
human_description:
|
||||
en_US: The maximum number of tokens to generate in the response.
|
||||
zh_Hans: 在响应中生成的最大令牌数。
|
||||
pt_BR: O número máximo de tokens a serem gerados na resposta.
|
||||
form: form
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
- name: temperature
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Temperature
|
||||
zh_Hans: 温度
|
||||
pt_BR: Temperatura
|
||||
human_description:
|
||||
en_US: Controls randomness in the output. Lower values make the output more focused and deterministic.
|
||||
zh_Hans: 控制输出的随机性。较低的值使输出更加集中和确定。
|
||||
form: form
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
- name: top_k
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Top K
|
||||
zh_Hans: 取样数量
|
||||
human_description:
|
||||
en_US: The number of top results to consider for response generation.
|
||||
zh_Hans: 用于生成响应的顶部结果数量。
|
||||
form: form
|
||||
default: 5
|
||||
min: 1
|
||||
max: 100
|
||||
- name: top_p
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Top P
|
||||
zh_Hans: Top P
|
||||
human_description:
|
||||
en_US: Controls diversity via nucleus sampling.
|
||||
zh_Hans: 通过核心采样控制多样性。
|
||||
form: form
|
||||
default: 1
|
||||
min: 0.1
|
||||
max: 1
|
||||
step: 0.1
|
||||
- name: presence_penalty
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Presence Penalty
|
||||
zh_Hans: 存在惩罚
|
||||
human_description:
|
||||
en_US: Positive values penalize new tokens based on whether they appear in the text so far.
|
||||
zh_Hans: 正值会根据新词元是否已经出现在文本中来对其进行惩罚。
|
||||
form: form
|
||||
default: 0
|
||||
min: -1.0
|
||||
max: 1.0
|
||||
step: 0.1
|
||||
- name: frequency_penalty
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Frequency Penalty
|
||||
zh_Hans: 频率惩罚
|
||||
human_description:
|
||||
en_US: Positive values penalize new tokens based on their existing frequency in the text so far.
|
||||
zh_Hans: 正值会根据新词元在文本中已经出现的频率来对其进行惩罚。
|
||||
form: form
|
||||
default: 1
|
||||
min: 0.1
|
||||
max: 1.0
|
||||
step: 0.1
|
||||
- name: return_citations
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Return Citations
|
||||
zh_Hans: 返回引用
|
||||
human_description:
|
||||
en_US: Whether to return citations in the response.
|
||||
zh_Hans: 是否在响应中返回引用。
|
||||
form: form
|
||||
default: true
|
||||
- name: search_domain_filter
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Search Domain Filter
|
||||
zh_Hans: 搜索域过滤器
|
||||
human_description:
|
||||
en_US: Domain to filter the search results.
|
||||
zh_Hans: 用于过滤搜索结果的域名。
|
||||
form: form
|
||||
default: ""
|
||||
- name: search_recency_filter
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: Search Recency Filter
|
||||
zh_Hans: 搜索时间过滤器
|
||||
human_description:
|
||||
en_US: Filter for search results based on recency.
|
||||
zh_Hans: 基于时间筛选搜索结果。
|
||||
form: form
|
||||
default: "month"
|
||||
options:
|
||||
- value: day
|
||||
label:
|
||||
en_US: Day
|
||||
zh_Hans: 天
|
||||
- value: week
|
||||
label:
|
||||
en_US: Week
|
||||
zh_Hans: 周
|
||||
- value: month
|
||||
label:
|
||||
en_US: Month
|
||||
zh_Hans: 月
|
||||
- value: year
|
||||
label:
|
||||
en_US: Year
|
||||
zh_Hans: 年
|
||||
@ -17,8 +17,11 @@ class StepfunTool(BuiltinTool):
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
base_url = self.runtime.credentials.get('stepfun_base_url', 'https://api.stepfun.com')
|
||||
base_url = str(URL(base_url) / 'v1')
|
||||
base_url = self.runtime.credentials.get('stepfun_base_url', None)
|
||||
if not base_url:
|
||||
base_url = None
|
||||
else:
|
||||
base_url = str(URL(base_url) / 'v1')
|
||||
|
||||
client = OpenAI(
|
||||
api_key=self.runtime.credentials['stepfun_api_key'],
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
@ -18,7 +18,6 @@ class WorkflowTool(Tool):
|
||||
version: str
|
||||
workflow_entities: dict[str, Any]
|
||||
workflow_call_depth: int
|
||||
thread_pool_id: Optional[str] = None
|
||||
|
||||
label: str
|
||||
|
||||
@ -58,7 +57,6 @@ class WorkflowTool(Tool):
|
||||
invoke_from=self.runtime.invoke_from,
|
||||
stream=False,
|
||||
call_depth=self.workflow_call_depth + 1,
|
||||
workflow_thread_pool_id=self.thread_pool_id
|
||||
)
|
||||
|
||||
data = result.get('data', {})
|
||||
|
||||
@ -128,7 +128,6 @@ class ToolEngine:
|
||||
user_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
workflow_call_depth: int,
|
||||
thread_pool_id: Optional[str] = None
|
||||
) -> list[ToolInvokeMessage]:
|
||||
"""
|
||||
Workflow invokes the tool with the given arguments.
|
||||
@ -142,7 +141,6 @@ class ToolEngine:
|
||||
|
||||
if isinstance(tool, WorkflowTool):
|
||||
tool.workflow_call_depth = workflow_call_depth + 1
|
||||
tool.thread_pool_id = thread_pool_id
|
||||
|
||||
if tool.runtime and tool.runtime.runtime_parameters:
|
||||
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
|
||||
|
||||
@ -25,6 +25,7 @@ from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
@ -248,7 +249,7 @@ class ToolManager:
|
||||
return tool_entity
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
|
||||
def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
|
||||
"""
|
||||
get the workflow tool runtime
|
||||
"""
|
||||
|
||||
@ -7,7 +7,6 @@ from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolFileMessageTransformer:
|
||||
@classmethod
|
||||
def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],
|
||||
|
||||
@ -1,15 +1,116 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent
|
||||
from core.app.entities.queue_entities import AppQueueEvent
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
|
||||
|
||||
class WorkflowCallback(ABC):
|
||||
@abstractmethod
|
||||
def on_event(
|
||||
self,
|
||||
event: GraphEngineEvent
|
||||
) -> None:
|
||||
def on_workflow_run_started(self) -> None:
|
||||
"""
|
||||
Published event
|
||||
Workflow run started
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_run_succeeded(self) -> None:
|
||||
"""
|
||||
Workflow run succeeded
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_run_failed(self, error: str) -> None:
|
||||
"""
|
||||
Workflow run failed
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_node_execute_started(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_node_execute_succeeded(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_node_execute_failed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
error: str,
|
||||
inputs: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_iteration_started(self,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: Optional[dict] = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_iteration_next(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
index: int,
|
||||
node_run_index: int,
|
||||
output: Optional[Any],
|
||||
) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_iteration_completed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int,
|
||||
outputs: dict) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -9,7 +9,7 @@ class BaseNodeData(ABC, BaseModel):
|
||||
desc: Optional[str] = None
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
start_node_id: Optional[str] = None
|
||||
start_node_id: str
|
||||
|
||||
class BaseIterationState(BaseModel):
|
||||
iteration_node_id: str
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from models import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@ -28,7 +28,6 @@ class NodeType(Enum):
|
||||
VARIABLE_ASSIGNER = 'variable-assigner'
|
||||
LOOP = 'loop'
|
||||
ITERATION = 'iteration'
|
||||
ITERATION_START = 'iteration-start' # fake start node for iteration
|
||||
PARAMETER_EXTRACTOR = 'parameter-extractor'
|
||||
CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
|
||||
|
||||
@ -57,10 +56,6 @@ class NodeRunMetadataKey(Enum):
|
||||
TOOL_INFO = 'tool_info'
|
||||
ITERATION_ID = 'iteration_id'
|
||||
ITERATION_INDEX = 'iteration_index'
|
||||
PARALLEL_ID = 'parallel_id'
|
||||
PARALLEL_START_NODE_ID = 'parallel_start_node_id'
|
||||
PARENT_PARALLEL_ID = 'parent_parallel_id'
|
||||
PARENT_PARALLEL_START_NODE_ID = 'parent_parallel_start_node_id'
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
@ -70,32 +65,11 @@ class NodeRunResult(BaseModel):
|
||||
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
||||
|
||||
inputs: Optional[dict[str, Any]] = None # node inputs
|
||||
process_data: Optional[dict[str, Any]] = None # process data
|
||||
outputs: Optional[dict[str, Any]] = None # node outputs
|
||||
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
||||
process_data: Optional[dict] = None # process data
|
||||
outputs: Optional[Mapping[str, Any]] = None # node outputs
|
||||
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
|
||||
llm_usage: Optional[LLMUsage] = None # llm usage
|
||||
|
||||
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||
|
||||
error: Optional[str] = None # error message if status is failed
|
||||
|
||||
|
||||
class UserFrom(Enum):
|
||||
"""
|
||||
User from
|
||||
"""
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end-user"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "UserFrom":
|
||||
"""
|
||||
Value of
|
||||
:param value: value
|
||||
:return:
|
||||
"""
|
||||
for item in cls:
|
||||
if item.value == value:
|
||||
return item
|
||||
raise ValueError(f"Invalid value: {value}")
|
||||
|
||||
@ -2,7 +2,6 @@ from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.app.segments import Segment, Variable, factory
|
||||
@ -17,52 +16,43 @@ ENVIRONMENT_VARIABLE_NODE_ID = "env"
|
||||
CONVERSATION_VARIABLE_NODE_ID = "conversation"
|
||||
|
||||
|
||||
class VariablePool(BaseModel):
|
||||
# Variable dictionary is a dictionary for looking up variables by their selector.
|
||||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||
# elements of the selector except the first one.
|
||||
variable_dictionary: dict[str, dict[int, Segment]] = Field(
|
||||
description='Variables mapping',
|
||||
default=defaultdict(dict)
|
||||
)
|
||||
class VariablePool:
|
||||
def __init__(
|
||||
self,
|
||||
system_variables: Mapping[SystemVariableKey, Any],
|
||||
user_inputs: Mapping[str, Any],
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable] | None = None,
|
||||
) -> None:
|
||||
# system variables
|
||||
# for example:
|
||||
# {
|
||||
# 'query': 'abc',
|
||||
# 'files': []
|
||||
# }
|
||||
|
||||
# TODO: This user inputs is not used for pool.
|
||||
user_inputs: Mapping[str, Any] = Field(
|
||||
description='User inputs',
|
||||
)
|
||||
# Varaible dictionary is a dictionary for looking up variables by their selector.
|
||||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||
# elements of the selector except the first one.
|
||||
self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict)
|
||||
|
||||
system_variables: Mapping[SystemVariableKey, Any] = Field(
|
||||
description='System variables',
|
||||
)
|
||||
# TODO: This user inputs is not used for pool.
|
||||
self.user_inputs = user_inputs
|
||||
|
||||
environment_variables: Sequence[Variable] = Field(
|
||||
description="Environment variables.",
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
conversation_variables: Sequence[Variable] | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def val_model_after(self):
|
||||
"""
|
||||
Append system variables
|
||||
:return:
|
||||
"""
|
||||
# Add system variables to the variable pool
|
||||
for key, value in self.system_variables.items():
|
||||
self.system_variables = system_variables
|
||||
for key, value in system_variables.items():
|
||||
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
|
||||
|
||||
# Add environment variables to the variable pool
|
||||
for var in self.environment_variables or []:
|
||||
for var in environment_variables:
|
||||
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
||||
|
||||
# Add conversation variables to the variable pool
|
||||
for var in self.conversation_variables or []:
|
||||
for var in conversation_variables or []:
|
||||
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
|
||||
|
||||
return self
|
||||
|
||||
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
||||
"""
|
||||
Adds a variable to the variable pool.
|
||||
@ -89,7 +79,7 @@ class VariablePool(BaseModel):
|
||||
v = factory.build_segment(value)
|
||||
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self.variable_dictionary[selector[0]][hash_key] = v
|
||||
self._variable_dictionary[selector[0]][hash_key] = v
|
||||
|
||||
def get(self, selector: Sequence[str], /) -> Segment | None:
|
||||
"""
|
||||
@ -107,7 +97,7 @@ class VariablePool(BaseModel):
|
||||
if len(selector) < 2:
|
||||
raise ValueError("Invalid selector")
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
value = self.variable_dictionary[selector[0]].get(hash_key)
|
||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
||||
|
||||
return value
|
||||
|
||||
@ -128,7 +118,7 @@ class VariablePool(BaseModel):
|
||||
if len(selector) < 2:
|
||||
raise ValueError("Invalid selector")
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
value = self.variable_dictionary[selector[0]].get(hash_key)
|
||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
||||
return value.to_object() if value else None
|
||||
|
||||
def remove(self, selector: Sequence[str], /):
|
||||
@ -144,19 +134,7 @@ class VariablePool(BaseModel):
|
||||
if not selector:
|
||||
return
|
||||
if len(selector) == 1:
|
||||
self.variable_dictionary[selector[0]] = {}
|
||||
self._variable_dictionary[selector[0]] = {}
|
||||
return
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self.variable_dictionary[selector[0]].pop(hash_key, None)
|
||||
|
||||
def remove_node(self, node_id: str, /):
|
||||
"""
|
||||
Remove all variables associated with a given node id.
|
||||
|
||||
Args:
|
||||
node_id (str): The node id to remove.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.variable_dictionary.pop(node_id, None)
|
||||
self._variable_dictionary[selector[0]].pop(hash_key, None)
|
||||
|
||||
@ -66,7 +66,8 @@ class WorkflowRunState:
|
||||
self.variable_pool = variable_pool
|
||||
|
||||
self.total_tokens = 0
|
||||
self.workflow_nodes_and_results = []
|
||||
|
||||
self.workflow_node_steps = 1
|
||||
self.workflow_node_runs = []
|
||||
self.current_iteration_state = None
|
||||
self.workflow_node_steps = 1
|
||||
self.workflow_node_runs = []
|
||||
@ -1,8 +1,10 @@
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
|
||||
|
||||
class WorkflowNodeRunFailedError(Exception):
|
||||
def __init__(self, node_instance: BaseNode, error: str):
|
||||
self.node_instance = node_instance
|
||||
def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str):
|
||||
self.node_id = node_id
|
||||
self.node_type = node_type
|
||||
self.node_title = node_title
|
||||
self.error = error
|
||||
super().__init__(f"Node {node_instance.node_data.title} run failed: {error}")
|
||||
super().__init__(f"Node {node_title} run failed: {error}")
|
||||
|
||||
@ -1,31 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
|
||||
|
||||
class RunConditionHandler(ABC):
|
||||
def __init__(self,
|
||||
init_params: GraphInitParams,
|
||||
graph: Graph,
|
||||
condition: RunCondition):
|
||||
self.init_params = init_params
|
||||
self.graph = graph
|
||||
self.condition = condition
|
||||
|
||||
@abstractmethod
|
||||
def check(self,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
previous_route_node_state: RouteNodeState
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param previous_route_node_state: previous route node state
|
||||
:return: bool
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@ -1,28 +0,0 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
|
||||
|
||||
class BranchIdentifyRunConditionHandler(RunConditionHandler):
|
||||
|
||||
def check(self,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
previous_route_node_state: RouteNodeState) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param previous_route_node_state: previous route node state
|
||||
:return: bool
|
||||
"""
|
||||
if not self.condition.branch_identify:
|
||||
raise Exception("Branch identify is required")
|
||||
|
||||
run_result = previous_route_node_state.node_run_result
|
||||
if not run_result:
|
||||
return False
|
||||
|
||||
if not run_result.edge_source_handle:
|
||||
return False
|
||||
|
||||
return self.condition.branch_identify == run_result.edge_source_handle
|
||||
@ -1,32 +0,0 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
|
||||
|
||||
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||
def check(self,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
previous_route_node_state: RouteNodeState
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param previous_route_node_state: previous route node state
|
||||
:return: bool
|
||||
"""
|
||||
if not self.condition.conditions:
|
||||
return True
|
||||
|
||||
# process condition
|
||||
condition_processor = ConditionProcessor()
|
||||
input_conditions, group_result = condition_processor.process_conditions(
|
||||
variable_pool=graph_runtime_state.variable_pool,
|
||||
conditions=self.condition.conditions
|
||||
)
|
||||
|
||||
# Apply the logical operator for the current case
|
||||
compare_result = all(group_result)
|
||||
|
||||
return compare_result
|
||||
@ -1,35 +0,0 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
|
||||
|
||||
class ConditionManager:
|
||||
@staticmethod
|
||||
def get_condition_handler(
|
||||
init_params: GraphInitParams,
|
||||
graph: Graph,
|
||||
run_condition: RunCondition
|
||||
) -> RunConditionHandler:
|
||||
"""
|
||||
Get condition handler
|
||||
|
||||
:param init_params: init params
|
||||
:param graph: graph
|
||||
:param run_condition: run condition
|
||||
:return: condition handler
|
||||
"""
|
||||
if run_condition.type == "branch_identify":
|
||||
return BranchIdentifyRunConditionHandler(
|
||||
init_params=init_params,
|
||||
graph=graph,
|
||||
condition=run_condition
|
||||
)
|
||||
else:
|
||||
return ConditionRunConditionHandlerHandler(
|
||||
init_params=init_params,
|
||||
graph=graph,
|
||||
condition=run_condition
|
||||
)
|
||||
@ -1,163 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
|
||||
|
||||
class GraphEngineEvent(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
###########################################
|
||||
# Graph Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseGraphEvent(GraphEngineEvent):
|
||||
pass
|
||||
|
||||
|
||||
class GraphRunStartedEvent(BaseGraphEvent):
|
||||
pass
|
||||
|
||||
|
||||
class GraphRunSucceededEvent(BaseGraphEvent):
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
"""outputs"""
|
||||
|
||||
|
||||
class GraphRunFailedEvent(BaseGraphEvent):
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Node Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseNodeEvent(GraphEngineEvent):
|
||||
id: str = Field(..., description="node execution id")
|
||||
node_id: str = Field(..., description="node id")
|
||||
node_type: NodeType = Field(..., description="node type")
|
||||
node_data: BaseNodeData = Field(..., description="node data")
|
||||
route_node_state: RouteNodeState = Field(..., description="route node state")
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
|
||||
|
||||
class NodeRunStartedEvent(BaseNodeEvent):
|
||||
predecessor_node_id: Optional[str] = None
|
||||
"""predecessor node id"""
|
||||
|
||||
|
||||
class NodeRunStreamChunkEvent(BaseNodeEvent):
|
||||
chunk_content: str = Field(..., description="chunk content")
|
||||
from_variable_selector: Optional[list[str]] = None
|
||||
"""from variable selector"""
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
|
||||
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
class NodeRunSucceededEvent(BaseNodeEvent):
|
||||
pass
|
||||
|
||||
|
||||
class NodeRunFailedEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
###########################################
|
||||
# Parallel Branch Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseParallelBranchEvent(GraphEngineEvent):
|
||||
parallel_id: str = Field(..., description="parallel id")
|
||||
"""parallel id"""
|
||||
parallel_start_node_id: str = Field(..., description="parallel start node id")
|
||||
"""parallel start node id"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
|
||||
|
||||
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
|
||||
pass
|
||||
|
||||
|
||||
class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
|
||||
pass
|
||||
|
||||
|
||||
class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Iteration Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseIterationEvent(GraphEngineEvent):
|
||||
iteration_id: str = Field(..., description="iteration node execution id")
|
||||
iteration_node_id: str = Field(..., description="iteration node id")
|
||||
iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
|
||||
iteration_node_data: BaseNodeData = Field(..., description="node data")
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
|
||||
|
||||
class IterationRunStartedEvent(BaseIterationEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
predecessor_node_id: Optional[str] = None
|
||||
|
||||
|
||||
class IterationRunNextEvent(BaseIterationEvent):
|
||||
index: int = Field(..., description="index")
|
||||
pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output")
|
||||
|
||||
|
||||
class IterationRunSucceededEvent(BaseIterationEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
steps: int = 0
|
||||
|
||||
|
||||
class IterationRunFailedEvent(BaseIterationEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent
|
||||
@ -1,672 +0,0 @@
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
|
||||
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
|
||||
from core.workflow.nodes.end.entities import EndStreamParam
|
||||
|
||||
|
||||
class GraphEdge(BaseModel):
|
||||
source_node_id: str = Field(..., description="source node id")
|
||||
target_node_id: str = Field(..., description="target node id")
|
||||
run_condition: Optional[RunCondition] = None
|
||||
"""run condition"""
|
||||
|
||||
|
||||
class GraphParallel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
|
||||
start_from_node_id: str = Field(..., description="start from node id")
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id"""
|
||||
end_to_node_id: Optional[str] = None
|
||||
"""end to node id"""
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
root_node_id: str = Field(..., description="root node id of the graph")
|
||||
node_ids: list[str] = Field(default_factory=list, description="graph node ids")
|
||||
node_id_config_mapping: dict[str, dict] = Field(
|
||||
default_factory=list,
|
||||
description="node configs mapping (node id: node config)"
|
||||
)
|
||||
edge_mapping: dict[str, list[GraphEdge]] = Field(
|
||||
default_factory=dict,
|
||||
description="graph edge mapping (source node id: edges)"
|
||||
)
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
|
||||
default_factory=dict,
|
||||
description="reverse graph edge mapping (target node id: edges)"
|
||||
)
|
||||
parallel_mapping: dict[str, GraphParallel] = Field(
|
||||
default_factory=dict,
|
||||
description="graph parallel mapping (parallel id: parallel)"
|
||||
)
|
||||
node_parallel_mapping: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="graph node parallel mapping (node id: parallel id)"
|
||||
)
|
||||
answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(
|
||||
...,
|
||||
description="answer stream generate routes"
|
||||
)
|
||||
end_stream_param: EndStreamParam = Field(
|
||||
...,
|
||||
description="end stream param"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def init(cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
root_node_id: Optional[str] = None) -> "Graph":
|
||||
"""
|
||||
Init graph
|
||||
|
||||
:param graph_config: graph config
|
||||
:param root_node_id: root node id
|
||||
:return: graph
|
||||
"""
|
||||
# edge configs
|
||||
edge_configs = graph_config.get('edges')
|
||||
if edge_configs is None:
|
||||
edge_configs = []
|
||||
|
||||
edge_configs = cast(list, edge_configs)
|
||||
|
||||
# reorganize edges mapping
|
||||
edge_mapping: dict[str, list[GraphEdge]] = {}
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
|
||||
target_edge_ids = set()
|
||||
for edge_config in edge_configs:
|
||||
source_node_id = edge_config.get('source')
|
||||
if not source_node_id:
|
||||
continue
|
||||
|
||||
if source_node_id not in edge_mapping:
|
||||
edge_mapping[source_node_id] = []
|
||||
|
||||
target_node_id = edge_config.get('target')
|
||||
if not target_node_id:
|
||||
continue
|
||||
|
||||
if target_node_id not in reverse_edge_mapping:
|
||||
reverse_edge_mapping[target_node_id] = []
|
||||
|
||||
# is target node id in source node id edge mapping
|
||||
if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]):
|
||||
continue
|
||||
|
||||
target_edge_ids.add(target_node_id)
|
||||
|
||||
# parse run condition
|
||||
run_condition = None
|
||||
if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source':
|
||||
run_condition = RunCondition(
|
||||
type='branch_identify',
|
||||
branch_identify=edge_config.get('sourceHandle')
|
||||
)
|
||||
|
||||
graph_edge = GraphEdge(
|
||||
source_node_id=source_node_id,
|
||||
target_node_id=target_node_id,
|
||||
run_condition=run_condition
|
||||
)
|
||||
|
||||
edge_mapping[source_node_id].append(graph_edge)
|
||||
reverse_edge_mapping[target_node_id].append(graph_edge)
|
||||
|
||||
# node configs
|
||||
node_configs = graph_config.get('nodes')
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
node_configs = cast(list, node_configs)
|
||||
|
||||
# fetch nodes that have no predecessor node
|
||||
root_node_configs = []
|
||||
all_node_id_config_mapping: dict[str, dict] = {}
|
||||
for node_config in node_configs:
|
||||
node_id = node_config.get('id')
|
||||
if not node_id:
|
||||
continue
|
||||
|
||||
if node_id not in target_edge_ids:
|
||||
root_node_configs.append(node_config)
|
||||
|
||||
all_node_id_config_mapping[node_id] = node_config
|
||||
|
||||
root_node_ids = [node_config.get('id') for node_config in root_node_configs]
|
||||
|
||||
# fetch root node
|
||||
if not root_node_id:
|
||||
# if no root node id, use the START type node as root node
|
||||
root_node_id = next((node_config.get("id") for node_config in root_node_configs
|
||||
if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)
|
||||
|
||||
if not root_node_id or root_node_id not in root_node_ids:
|
||||
raise ValueError(f"Root node id {root_node_id} not found in the graph")
|
||||
|
||||
# Check whether it is connected to the previous node
|
||||
cls._check_connected_to_previous_node(
|
||||
route=[root_node_id],
|
||||
edge_mapping=edge_mapping
|
||||
)
|
||||
|
||||
# fetch all node ids from root node
|
||||
node_ids = [root_node_id]
|
||||
cls._recursively_add_node_ids(
|
||||
node_ids=node_ids,
|
||||
edge_mapping=edge_mapping,
|
||||
node_id=root_node_id
|
||||
)
|
||||
|
||||
node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}
|
||||
|
||||
# init parallel mapping
|
||||
parallel_mapping: dict[str, GraphParallel] = {}
|
||||
node_parallel_mapping: dict[str, str] = {}
|
||||
cls._recursively_add_parallels(
|
||||
edge_mapping=edge_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
start_node_id=root_node_id,
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping
|
||||
)
|
||||
|
||||
# Check if it exceeds N layers of parallel
|
||||
for parallel in parallel_mapping.values():
|
||||
if parallel.parent_parallel_id:
|
||||
cls._check_exceed_parallel_limit(
|
||||
parallel_mapping=parallel_mapping,
|
||||
level_limit=3,
|
||||
parent_parallel_id=parallel.parent_parallel_id
|
||||
)
|
||||
|
||||
# init answer stream generate routes
|
||||
answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping
|
||||
)
|
||||
|
||||
# init end stream param
|
||||
end_stream_param = EndStreamGeneratorRouter.init(
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = cls(
|
||||
root_node_id=root_node_id,
|
||||
node_ids=node_ids,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
edge_mapping=edge_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping,
|
||||
answer_stream_generate_routes=answer_stream_generate_routes,
|
||||
end_stream_param=end_stream_param
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
def add_extra_edge(self, source_node_id: str,
|
||||
target_node_id: str,
|
||||
run_condition: Optional[RunCondition] = None) -> None:
|
||||
"""
|
||||
Add extra edge to the graph
|
||||
|
||||
:param source_node_id: source node id
|
||||
:param target_node_id: target node id
|
||||
:param run_condition: run condition
|
||||
"""
|
||||
if source_node_id not in self.node_ids or target_node_id not in self.node_ids:
|
||||
return
|
||||
|
||||
if source_node_id not in self.edge_mapping:
|
||||
self.edge_mapping[source_node_id] = []
|
||||
|
||||
if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]:
|
||||
return
|
||||
|
||||
graph_edge = GraphEdge(
|
||||
source_node_id=source_node_id,
|
||||
target_node_id=target_node_id,
|
||||
run_condition=run_condition
|
||||
)
|
||||
|
||||
self.edge_mapping[source_node_id].append(graph_edge)
|
||||
|
||||
def get_leaf_node_ids(self) -> list[str]:
|
||||
"""
|
||||
Get leaf node ids of the graph
|
||||
|
||||
:return: leaf node ids
|
||||
"""
|
||||
leaf_node_ids = []
|
||||
for node_id in self.node_ids:
|
||||
if node_id not in self.edge_mapping:
|
||||
leaf_node_ids.append(node_id)
|
||||
elif (len(self.edge_mapping[node_id]) == 1
|
||||
and self.edge_mapping[node_id][0].target_node_id == self.root_node_id):
|
||||
leaf_node_ids.append(node_id)
|
||||
|
||||
return leaf_node_ids
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_node_ids(cls,
|
||||
node_ids: list[str],
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
node_id: str) -> None:
|
||||
"""
|
||||
Recursively add node ids
|
||||
|
||||
:param node_ids: node ids
|
||||
:param edge_mapping: edge mapping
|
||||
:param node_id: node id
|
||||
"""
|
||||
for graph_edge in edge_mapping.get(node_id, []):
|
||||
if graph_edge.target_node_id in node_ids:
|
||||
continue
|
||||
|
||||
node_ids.append(graph_edge.target_node_id)
|
||||
cls._recursively_add_node_ids(
|
||||
node_ids=node_ids,
|
||||
edge_mapping=edge_mapping,
|
||||
node_id=graph_edge.target_node_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _check_connected_to_previous_node(
|
||||
cls,
|
||||
route: list[str],
|
||||
edge_mapping: dict[str, list[GraphEdge]]
|
||||
) -> None:
|
||||
"""
|
||||
Check whether it is connected to the previous node
|
||||
"""
|
||||
last_node_id = route[-1]
|
||||
|
||||
for graph_edge in edge_mapping.get(last_node_id, []):
|
||||
if not graph_edge.target_node_id:
|
||||
continue
|
||||
|
||||
if graph_edge.target_node_id in route:
|
||||
raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.")
|
||||
|
||||
new_route = route[:]
|
||||
new_route.append(graph_edge.target_node_id)
|
||||
cls._check_connected_to_previous_node(
|
||||
route=new_route,
|
||||
edge_mapping=edge_mapping,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_parallels(
|
||||
cls,
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]],
|
||||
start_node_id: str,
|
||||
parallel_mapping: dict[str, GraphParallel],
|
||||
node_parallel_mapping: dict[str, str],
|
||||
parent_parallel: Optional[GraphParallel] = None
|
||||
) -> None:
|
||||
"""
|
||||
Recursively add parallel ids
|
||||
|
||||
:param edge_mapping: edge mapping
|
||||
:param start_node_id: start from node id
|
||||
:param parallel_mapping: parallel mapping
|
||||
:param node_parallel_mapping: node parallel mapping
|
||||
:param parent_parallel: parent parallel
|
||||
"""
|
||||
target_node_edges = edge_mapping.get(start_node_id, [])
|
||||
parallel = None
|
||||
if len(target_node_edges) > 1:
|
||||
# fetch all node ids in current parallels
|
||||
parallel_branch_node_ids = []
|
||||
condition_edge_mappings = {}
|
||||
for graph_edge in target_node_edges:
|
||||
if graph_edge.run_condition is None:
|
||||
parallel_branch_node_ids.append(graph_edge.target_node_id)
|
||||
else:
|
||||
condition_hash = graph_edge.run_condition.hash
|
||||
if not condition_hash in condition_edge_mappings:
|
||||
condition_edge_mappings[condition_hash] = []
|
||||
|
||||
condition_edge_mappings[condition_hash].append(graph_edge)
|
||||
|
||||
for _, graph_edges in condition_edge_mappings.items():
|
||||
if len(graph_edges) > 1:
|
||||
for graph_edge in graph_edges:
|
||||
parallel_branch_node_ids.append(graph_edge.target_node_id)
|
||||
|
||||
# any target node id in node_parallel_mapping
|
||||
if parallel_branch_node_ids:
|
||||
parent_parallel_id = parent_parallel.id if parent_parallel else None
|
||||
|
||||
parallel = GraphParallel(
|
||||
start_from_node_id=start_node_id,
|
||||
parent_parallel_id=parent_parallel.id if parent_parallel else None,
|
||||
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None
|
||||
)
|
||||
parallel_mapping[parallel.id] = parallel
|
||||
|
||||
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
|
||||
edge_mapping=edge_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
parallel_branch_node_ids=parallel_branch_node_ids
|
||||
)
|
||||
|
||||
# collect all branches node ids
|
||||
parallel_node_ids = []
|
||||
for _, node_ids in in_branch_node_ids.items():
|
||||
for node_id in node_ids:
|
||||
in_parent_parallel = True
|
||||
if parent_parallel_id:
|
||||
in_parent_parallel = False
|
||||
for parallel_node_id, parallel_id in node_parallel_mapping.items():
|
||||
if parallel_id == parent_parallel_id and parallel_node_id == node_id:
|
||||
in_parent_parallel = True
|
||||
break
|
||||
|
||||
if in_parent_parallel:
|
||||
parallel_node_ids.append(node_id)
|
||||
node_parallel_mapping[node_id] = parallel.id
|
||||
|
||||
outside_parallel_target_node_ids = set()
|
||||
for node_id in parallel_node_ids:
|
||||
if node_id == parallel.start_from_node_id:
|
||||
continue
|
||||
|
||||
node_edges = edge_mapping.get(node_id)
|
||||
if not node_edges:
|
||||
continue
|
||||
|
||||
if len(node_edges) > 1:
|
||||
continue
|
||||
|
||||
target_node_id = node_edges[0].target_node_id
|
||||
if target_node_id in parallel_node_ids:
|
||||
continue
|
||||
|
||||
if parent_parallel_id:
|
||||
parent_parallel = parallel_mapping.get(parent_parallel_id)
|
||||
if not parent_parallel:
|
||||
continue
|
||||
|
||||
if (
|
||||
(node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id)
|
||||
or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id)
|
||||
or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
|
||||
):
|
||||
outside_parallel_target_node_ids.add(target_node_id)
|
||||
|
||||
if len(outside_parallel_target_node_ids) == 1:
|
||||
if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id:
|
||||
parallel.end_to_node_id = None
|
||||
else:
|
||||
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
|
||||
|
||||
for graph_edge in target_node_edges:
|
||||
cls._recursively_add_parallels(
|
||||
edge_mapping=edge_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
start_node_id=graph_edge.target_node_id,
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping,
|
||||
parent_parallel=parallel if parallel else parent_parallel
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _check_exceed_parallel_limit(
|
||||
cls,
|
||||
parallel_mapping: dict[str, GraphParallel],
|
||||
level_limit: int,
|
||||
parent_parallel_id: str,
|
||||
current_level: int = 1
|
||||
) -> None:
|
||||
"""
|
||||
Check if it exceeds N layers of parallel
|
||||
"""
|
||||
parent_parallel = parallel_mapping.get(parent_parallel_id)
|
||||
if not parent_parallel:
|
||||
return
|
||||
|
||||
current_level += 1
|
||||
if current_level > level_limit:
|
||||
raise ValueError(f"Exceeds {level_limit} layers of parallel")
|
||||
|
||||
if parent_parallel.parent_parallel_id:
|
||||
cls._check_exceed_parallel_limit(
|
||||
parallel_mapping=parallel_mapping,
|
||||
level_limit=level_limit,
|
||||
parent_parallel_id=parent_parallel.parent_parallel_id,
|
||||
current_level=current_level
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_parallel_node_ids(cls,
|
||||
branch_node_ids: list[str],
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
merge_node_id: str,
|
||||
start_node_id: str) -> None:
|
||||
"""
|
||||
Recursively add node ids
|
||||
|
||||
:param branch_node_ids: in branch node ids
|
||||
:param edge_mapping: edge mapping
|
||||
:param merge_node_id: merge node id
|
||||
:param start_node_id: start node id
|
||||
"""
|
||||
for graph_edge in edge_mapping.get(start_node_id, []):
|
||||
if (graph_edge.target_node_id != merge_node_id
|
||||
and graph_edge.target_node_id not in branch_node_ids):
|
||||
branch_node_ids.append(graph_edge.target_node_id)
|
||||
cls._recursively_add_parallel_node_ids(
|
||||
branch_node_ids=branch_node_ids,
|
||||
edge_mapping=edge_mapping,
|
||||
merge_node_id=merge_node_id,
|
||||
start_node_id=graph_edge.target_node_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _fetch_all_node_ids_in_parallels(cls,
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]],
|
||||
parallel_branch_node_ids: list[str]) -> dict[str, list[str]]:
|
||||
"""
|
||||
Fetch all node ids in parallels
|
||||
"""
|
||||
routes_node_ids: dict[str, list[str]] = {}
|
||||
for parallel_branch_node_id in parallel_branch_node_ids:
|
||||
routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id]
|
||||
|
||||
# fetch routes node ids
|
||||
cls._recursively_fetch_routes(
|
||||
edge_mapping=edge_mapping,
|
||||
start_node_id=parallel_branch_node_id,
|
||||
routes_node_ids=routes_node_ids[parallel_branch_node_id]
|
||||
)
|
||||
|
||||
# fetch leaf node ids from routes node ids
|
||||
leaf_node_ids: dict[str, list[str]] = {}
|
||||
merge_branch_node_ids: dict[str, list[str]] = {}
|
||||
for branch_node_id, node_ids in routes_node_ids.items():
|
||||
for node_id in node_ids:
|
||||
if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0:
|
||||
if branch_node_id not in leaf_node_ids:
|
||||
leaf_node_ids[branch_node_id] = []
|
||||
|
||||
leaf_node_ids[branch_node_id].append(node_id)
|
||||
|
||||
for branch_node_id2, inner_route2 in routes_node_ids.items():
|
||||
if (
|
||||
branch_node_id != branch_node_id2
|
||||
and node_id in inner_route2
|
||||
and len(reverse_edge_mapping.get(node_id, [])) > 1
|
||||
and cls._is_node_in_routes(
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
start_node_id=node_id,
|
||||
routes_node_ids=routes_node_ids
|
||||
)
|
||||
):
|
||||
if node_id not in merge_branch_node_ids:
|
||||
merge_branch_node_ids[node_id] = []
|
||||
|
||||
if branch_node_id2 not in merge_branch_node_ids[node_id]:
|
||||
merge_branch_node_ids[node_id].append(branch_node_id2)
|
||||
|
||||
# sorted merge_branch_node_ids by branch_node_ids length desc
|
||||
merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))
|
||||
|
||||
duplicate_end_node_ids = {}
|
||||
for node_id, branch_node_ids in merge_branch_node_ids.items():
|
||||
for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
|
||||
if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
|
||||
if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids:
|
||||
duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids
|
||||
|
||||
for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
|
||||
# check which node is after
|
||||
if cls._is_node2_after_node1(
|
||||
node1_id=node_id,
|
||||
node2_id=node_id2,
|
||||
edge_mapping=edge_mapping
|
||||
):
|
||||
if node_id in merge_branch_node_ids:
|
||||
del merge_branch_node_ids[node_id2]
|
||||
elif cls._is_node2_after_node1(
|
||||
node1_id=node_id2,
|
||||
node2_id=node_id,
|
||||
edge_mapping=edge_mapping
|
||||
):
|
||||
if node_id2 in merge_branch_node_ids:
|
||||
del merge_branch_node_ids[node_id]
|
||||
|
||||
branches_merge_node_ids: dict[str, str] = {}
|
||||
for node_id, branch_node_ids in merge_branch_node_ids.items():
|
||||
if len(branch_node_ids) <= 1:
|
||||
continue
|
||||
|
||||
for branch_node_id in branch_node_ids:
|
||||
if branch_node_id in branches_merge_node_ids:
|
||||
continue
|
||||
|
||||
branches_merge_node_ids[branch_node_id] = node_id
|
||||
|
||||
in_branch_node_ids: dict[str, list[str]] = {}
|
||||
for branch_node_id, node_ids in routes_node_ids.items():
|
||||
in_branch_node_ids[branch_node_id] = []
|
||||
if branch_node_id not in branches_merge_node_ids:
|
||||
# all node ids in current branch is in this thread
|
||||
in_branch_node_ids[branch_node_id].append(branch_node_id)
|
||||
in_branch_node_ids[branch_node_id].extend(node_ids)
|
||||
else:
|
||||
merge_node_id = branches_merge_node_ids[branch_node_id]
|
||||
if merge_node_id != branch_node_id:
|
||||
in_branch_node_ids[branch_node_id].append(branch_node_id)
|
||||
|
||||
# fetch all node ids from branch_node_id and merge_node_id
|
||||
cls._recursively_add_parallel_node_ids(
|
||||
branch_node_ids=in_branch_node_ids[branch_node_id],
|
||||
edge_mapping=edge_mapping,
|
||||
merge_node_id=merge_node_id,
|
||||
start_node_id=branch_node_id
|
||||
)
|
||||
|
||||
return in_branch_node_ids
|
||||
|
||||
@classmethod
|
||||
def _recursively_fetch_routes(cls,
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
start_node_id: str,
|
||||
routes_node_ids: list[str]) -> None:
|
||||
"""
|
||||
Recursively fetch route
|
||||
"""
|
||||
if start_node_id not in edge_mapping:
|
||||
return
|
||||
|
||||
for graph_edge in edge_mapping[start_node_id]:
|
||||
# find next node ids
|
||||
if graph_edge.target_node_id not in routes_node_ids:
|
||||
routes_node_ids.append(graph_edge.target_node_id)
|
||||
|
||||
cls._recursively_fetch_routes(
|
||||
edge_mapping=edge_mapping,
|
||||
start_node_id=graph_edge.target_node_id,
|
||||
routes_node_ids=routes_node_ids
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _is_node_in_routes(cls,
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]],
|
||||
start_node_id: str,
|
||||
routes_node_ids: dict[str, list[str]]) -> bool:
|
||||
"""
|
||||
Recursively check if the node is in the routes
|
||||
"""
|
||||
if start_node_id not in reverse_edge_mapping:
|
||||
return False
|
||||
|
||||
all_routes_node_ids = set()
|
||||
parallel_start_node_ids: dict[str, list[str]] = {}
|
||||
for branch_node_id, node_ids in routes_node_ids.items():
|
||||
for node_id in node_ids:
|
||||
all_routes_node_ids.add(node_id)
|
||||
|
||||
if branch_node_id in reverse_edge_mapping:
|
||||
for graph_edge in reverse_edge_mapping[branch_node_id]:
|
||||
if graph_edge.source_node_id not in parallel_start_node_ids:
|
||||
parallel_start_node_ids[graph_edge.source_node_id] = []
|
||||
|
||||
parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
|
||||
|
||||
parallel_start_node_id = None
|
||||
for p_start_node_id, branch_node_ids in parallel_start_node_ids.items():
|
||||
if set(branch_node_ids) == set(routes_node_ids.keys()):
|
||||
parallel_start_node_id = p_start_node_id
|
||||
return True
|
||||
|
||||
if not parallel_start_node_id:
|
||||
raise Exception("Parallel start node id not found")
|
||||
|
||||
for graph_edge in reverse_edge_mapping[start_node_id]:
|
||||
if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _is_node2_after_node1(
|
||||
cls,
|
||||
node1_id: str,
|
||||
node2_id: str,
|
||||
edge_mapping: dict[str, list[GraphEdge]]
|
||||
) -> bool:
|
||||
"""
|
||||
is node2 after node1
|
||||
"""
|
||||
if node1_id not in edge_mapping:
|
||||
return False
|
||||
|
||||
for graph_edge in edge_mapping[node1_id]:
|
||||
if graph_edge.target_node_id == node2_id:
|
||||
return True
|
||||
|
||||
if cls._is_node2_after_node1(
|
||||
node1_id=graph_edge.target_node_id,
|
||||
node2_id=node2_id,
|
||||
edge_mapping=edge_mapping
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
@ -1,21 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
class GraphInitParams(BaseModel):
|
||||
# init params
|
||||
tenant_id: str = Field(..., description="tenant / workspace id")
|
||||
app_id: str = Field(..., description="app id")
|
||||
workflow_type: WorkflowType = Field(..., description="workflow type")
|
||||
workflow_id: str = Field(..., description="workflow id")
|
||||
graph_config: Mapping[str, Any] = Field(..., description="graph config")
|
||||
user_id: str = Field(..., description="user id")
|
||||
user_from: UserFrom = Field(..., description="user from, account or end-user")
|
||||
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
|
||||
call_depth: int = Field(..., description="call depth")
|
||||
@ -1,27 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
|
||||
|
||||
|
||||
class GraphRuntimeState(BaseModel):
|
||||
variable_pool: VariablePool = Field(..., description="variable pool")
|
||||
"""variable pool"""
|
||||
|
||||
start_at: float = Field(..., description="start time")
|
||||
"""start time"""
|
||||
total_tokens: int = 0
|
||||
"""total tokens"""
|
||||
llm_usage: LLMUsage = LLMUsage.empty_usage()
|
||||
"""llm usage info"""
|
||||
outputs: dict[str, Any] = {}
|
||||
"""outputs"""
|
||||
|
||||
node_run_steps: int = 0
|
||||
"""node run steps"""
|
||||
|
||||
node_run_state: RuntimeRouteState = RuntimeRouteState()
|
||||
"""node run state"""
|
||||
@ -1,13 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.graph_engine.entities.graph import GraphParallel
|
||||
|
||||
|
||||
class NextGraphNode(BaseModel):
|
||||
node_id: str
|
||||
"""next node id"""
|
||||
|
||||
parallel: Optional[GraphParallel] = None
|
||||
"""parallel"""
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user