Compare commits

1 Commit

Author  SHA1        Message                                                                  Date
Yi      2474dbdff0  fix the tooltip for the knowledge base's firecrawl max depth attribute  2024-08-28 17:09:30 +08:00
330 changed files with 5931 additions and 16246 deletions

View File

@ -125,6 +125,7 @@ jobs:
with:
images: ${{ env[matrix.image_name_env] }}
tags: |
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}

View File

@ -1,54 +0,0 @@
name: Check i18n Files and Create PR
on:
pull_request:
types: [closed]
branches: [main]
jobs:
check-and-update:
if: github.event.pull_request.merged == true
runs-on: ubuntu-latest
defaults:
run:
working-directory: web
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 2 # last 2 commits
- name: Check for file changes in i18n/en-US
id: check_files
run: |
recent_commit_sha=$(git rev-parse HEAD)
second_recent_commit_sha=$(git rev-parse HEAD~1)
changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
fi
- name: Set up Node.js
if: env.FILES_CHANGED == 'true'
uses: actions/setup-node@v2
with:
node-version: 'lts/*'
- name: Install dependencies
if: env.FILES_CHANGED == 'true'
run: yarn install --frozen-lockfile
- name: Run npm script
if: env.FILES_CHANGED == 'true'
run: npm run auto-gen-i18n
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'
uses: peter-evans/create-pull-request@v6
with:
commit-message: Update i18n files based on en-US changes
title: 'chore: translate i18n files'
body: This PR was automatically created to update i18n files based on changes in en-US locale.
branch: chore/automated-i18n-updates

View File

@ -8,7 +8,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
## Before you jump in
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
### Feature requests:

View File

@ -8,7 +8,7 @@
## Before you start
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [create](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [create](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
### Feature requests:

View File

@ -10,7 +10,7 @@ So you are thinking of contributing to Dify? That is
## Before you jump in
[Find an existing Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:open), or [create a new Issue](https://github.com/langgenius/dify/issues/new/choose). We categorize Issues into 2 types.
[Find an existing Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:closed), or [create a new Issue](https://github.com/langgenius/dify/issues/new/choose). We categorize Issues into 2 types.
### Feature requests

View File

@ -8,7 +8,7 @@ Regarding licensing, please take a minute to read through our short [
## Before you begin
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [create](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [create](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
### Feature requests:

View File

@ -60,8 +60,7 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key
ALIYUN_OSS_ENDPOINT=your-endpoint
ALIYUN_OSS_AUTH_VERSION=v1
ALIYUN_OSS_REGION=your-region
# Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path
# Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
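
The Google Storage credential above is a base64-encoded service-account JSON string. A minimal sketch of how such a value can be decoded before handing it to a storage client, assuming only the standard library (the helper name is illustrative, not Dify's actual loader):

    import base64
    import json
    import os

    def load_service_account_info() -> dict:
        # Hypothetical helper: decode the base64 blob from the environment
        # and parse it into the service-account dict a GCS client expects.
        raw = os.environ["GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64"]
        return json.loads(base64.b64decode(raw))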

View File

@ -55,7 +55,7 @@ RUN apt-get update \
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
&& apt-get update \
# For Security
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-2 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/*

View File

@ -559,9 +559,8 @@ def add_qdrant_doc_id_index(field: str):
@click.command("create-tenant", help="Create account and tenant.")
@click.option("--email", prompt=True, help="The email address of the tenant account.")
@click.option("--name", prompt=True, help="The workspace name of the tenant account.")
@click.option("--language", prompt=True, help="Account language, default: en-US.")
def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None):
def create_tenant(email: str, language: Optional[str] = None):
"""
Create tenant account
"""
@ -581,15 +580,13 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
if language not in languages:
language = "en-US"
name = name.strip()
# generate random password
new_password = secrets.token_urlsafe(16)
# register account
account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
TenantService.create_owner_tenant_if_not_exist(account, name)
TenantService.create_owner_tenant_if_not_exist(account)
click.echo(
click.style(
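
The hunk above drops the workspace name parameter from create_tenant's signature. For reference, a minimal sketch of click's prompting pattern as used by this command, with an illustrative body (only the decorators mirror the diff; the rest is assumption):

    import secrets

    import click

    @click.command("create-tenant", help="Create account and tenant.")
    @click.option("--email", prompt=True, help="The email address of the tenant account.")
    @click.option("--name", prompt=True, help="The workspace name of the tenant account.")
    @click.option("--language", prompt=True, help="Account language, default: en-US.")
    def create_tenant(email: str, language: str, name: str):
        # click prompts interactively for any option not given on the command line.
        new_password = secrets.token_urlsafe(16)  # random initial password, as in the diff
        click.echo(f"Would register {email} in workspace {name.strip()!r} ({language})")

    if __name__ == "__main__":
        create_tenant()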

View File

@ -1,4 +1,4 @@
from typing import Annotated, Optional
from typing import Optional
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings
@ -217,17 +217,20 @@ class HttpConfig(BaseSettings):
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request")
] = 10
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: NonNegativeInt = Field(
description="",
default=300,
)
HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request")
] = 60
HTTP_REQUEST_MAX_READ_TIMEOUT: NonNegativeInt = Field(
description="",
default=600,
)
HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="read timeout in seconds for HTTP request")
] = 20
HTTP_REQUEST_MAX_WRITE_TIMEOUT: NonNegativeInt = Field(
description="",
default=600,
)
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
description="",

View File

@ -38,8 +38,3 @@ class AliyunOSSStorageConfig(BaseSettings):
description="Aliyun OSS authentication version",
default=None,
)
ALIYUN_OSS_PATH: Optional[str] = Field(
description="Aliyun OSS path",
default=None,
)

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.8.0-beta1",
default="0.7.2",
)
COMMIT_SHA: str = Field(

View File

@ -174,7 +174,6 @@ class AppApi(Resource):
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("max_active_requests", type=int, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
args = parser.parse_args()
app_service = AppService()

View File

@ -34,7 +34,6 @@ def parse_app_site_args():
)
parser.add_argument("prompt_public", type=bool, required=False, location="json")
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
return parser.parse_args()
@ -69,7 +68,6 @@ class AppSite(Resource):
"customize_token_strategy",
"prompt_public",
"show_workflow_steps",
"use_icon_as_answer_icon",
]:
value = args.get(attr_name)
if value is not None:

View File

@ -122,7 +122,6 @@ class DatasetListApi(Resource):
name=args["name"],
indexing_technique=args["indexing_technique"],
account=current_user,
permission=DatasetPermissionEnum.ONLY_ME,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()

View File

@ -39,7 +39,7 @@ class FileApi(Resource):
@login_required
@account_initialization_required
@marshal_with(file_fields)
@cloud_edition_billing_resource_check("documents")
@cloud_edition_billing_resource_check(resource="documents")
def post(self):
# get file from request
file = request.files["file"]

View File

@ -35,7 +35,6 @@ class InstalledAppsListApi(Resource):
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
}
for installed_app in installed_apps
if installed_app.app is not None
]
installed_apps.sort(
key=lambda app: (

View File

@ -46,7 +46,9 @@ def only_edition_self_hosted(view):
return decorated
def cloud_edition_billing_resource_check(resource: str):
def cloud_edition_billing_resource_check(
resource: str, error_msg: str = "You have reached the limit of your subscription."
):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
@ -58,22 +60,22 @@ def cloud_edition_billing_resource_check(resource: str):
documents_upload_quota = features.documents_upload_quota
annotation_quota_limit = features.annotation_quota_limit
if resource == "members" and 0 < members.limit <= members.size:
abort(403, "The number of members has reached the limit of your subscription.")
abort(403, error_msg)
elif resource == "apps" and 0 < apps.limit <= apps.size:
abort(403, "The number of apps has reached the limit of your subscription.")
abort(403, error_msg)
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
abort(403, "The capacity of the vector space has reached the limit of your subscription.")
abort(403, error_msg)
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
# The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
source = request.args.get("source")
if source == "datasets":
abort(403, "The number of documents has reached the limit of your subscription.")
abort(403, error_msg)
else:
return view(*args, **kwargs)
elif resource == "workspace_custom" and not features.can_replace_logo:
abort(403, "The workspace custom feature has reached the limit of your subscription.")
abort(403, error_msg)
elif resource == "annotation" and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
abort(403, "The annotation quota has reached the limit of your subscription.")
abort(403, error_msg)
else:
return view(*args, **kwargs)
@ -84,7 +86,10 @@ def cloud_edition_billing_resource_check(resource: str):
return interceptor
def cloud_edition_billing_knowledge_limit_check(resource: str):
def cloud_edition_billing_knowledge_limit_check(
resource: str,
error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
@ -92,10 +97,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == "sandbox":
abort(
403,
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
)
abort(403, error_msg)
else:
return view(*args, **kwargs)
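
The refactor above threads an error_msg parameter through each billing check instead of hard-coding a message at every abort site. A stripped-down sketch of the decorator-factory pattern, with the Flask abort and quota lookup stubbed out (everything inside decorated() is illustrative):

    from functools import wraps

    def billing_resource_check(
        resource: str, error_msg: str = "You have reached the limit of your subscription."
    ):
        def interceptor(view):
            @wraps(view)
            def decorated(*args, **kwargs):
                over_limit = False  # stand-in for the real quota comparison
                if over_limit:
                    raise PermissionError(error_msg)  # abort(403, error_msg) in Flask
                return view(*args, **kwargs)
            return decorated
        return interceptor

    @billing_resource_check(resource="documents", error_msg="Document quota exceeded.")
    def upload_document():
        return "uploaded"

    print(upload_document())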

View File

@ -36,10 +36,6 @@ class SegmentApi(DatasetApiResource):
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
if document.indexing_status != "completed":
raise NotFound("Document is already completed.")
if not document.enabled:
raise NotFound("Document is disabled.")
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:

View File

@ -83,7 +83,9 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
return decorator(view)
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
def cloud_edition_billing_resource_check(
resource: str, api_token_type: str, error_msg: str = "You have reached the limit of your subscription."
):
def interceptor(view):
def decorated(*args, **kwargs):
api_token = validate_and_get_api_token(api_token_type)
@ -96,13 +98,13 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
documents_upload_quota = features.documents_upload_quota
if resource == "members" and 0 < members.limit <= members.size:
raise Forbidden("The number of members has reached the limit of your subscription.")
raise Forbidden(error_msg)
elif resource == "apps" and 0 < apps.limit <= apps.size:
raise Forbidden("The number of apps has reached the limit of your subscription.")
raise Forbidden(error_msg)
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
raise Forbidden("The capacity of the vector space has reached the limit of your subscription.")
raise Forbidden(error_msg)
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
raise Forbidden("The number of documents has reached the limit of your subscription.")
raise Forbidden(error_msg)
else:
return view(*args, **kwargs)
@ -113,7 +115,11 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
return interceptor
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
def cloud_edition_billing_knowledge_limit_check(
resource: str,
api_token_type: str,
error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
@ -122,9 +128,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == "sandbox":
raise Forbidden(
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
)
raise Forbidden(error_msg)
else:
return view(*args, **kwargs)

View File

@ -39,7 +39,6 @@ class AppSiteApi(WebApiResource):
"default_language": fields.String,
"prompt_public": fields.Boolean,
"show_workflow_steps": fields.Boolean,
"use_icon_as_answer_icon": fields.Boolean,
}
app_fields = {

View File

@ -93,7 +93,7 @@ class DatasetConfigManager:
reranking_model=dataset_configs.get('reranking_model'),
weights=dataset_configs.get('weights'),
reranking_enabled=dataset_configs.get('reranking_enabled', True),
rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'),
rerank_mode=dataset_configs.get('rerank_mode', 'reranking_model'),
)
)
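
The one-character change above swaps which key is read from dataset_configs. Because dict.get with a default never raises, a lookup against the key the payload does not use falls back silently instead of failing loudly; a two-line illustration (the payload is hypothetical):

    dataset_configs = {'rerank_mode': 'weighted_score'}

    # Reading the absent key silently yields the fallback:
    print(dataset_configs.get('reranking_mode', 'reranking_model'))  # -> 'reranking_model'
    print(dataset_configs.get('rerank_mode', 'reranking_model'))     # -> 'weighted_score'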

View File

@ -4,10 +4,12 @@ import os
import threading
import uuid
from collections.abc import Generator
from typing import Any, Literal, Optional, Union, overload
from typing import Union
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@ -18,49 +20,33 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
from models.workflow import ConversationVariable, Workflow
logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[True] = True,
) -> Generator[str, None, None]: ...
@overload
def generate(
self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
stream: bool = True,
):
"""
Generate App response.
@ -148,8 +134,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
node_id: str,
user: Account,
args: dict,
stream: bool = True) \
-> dict[str, Any] | Generator[str, Any, None]:
stream: bool = True):
"""
Generate App response.
@ -166,6 +151,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}
# get conversation
conversation = None
conversation_id = args.get('conversation_id')
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
@ -176,16 +171,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
conversation_id=None,
conversation_id=conversation.id if conversation else None,
inputs={},
query='',
files=[],
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras={
"auto_generate_conversation_name": False
},
extras=extras,
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
@ -198,28 +191,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
conversation=None,
conversation=conversation,
stream=stream
)
def _generate(self, *,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None,
stream: bool = True) \
-> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
:param workflow: Workflow
:param user: account or end user
:param invoke_from: invoke from source
:param application_generate_entity: application generate entity
:param conversation: conversation
:param stream: is stream
"""
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation | None = None,
stream: bool = True):
is_first_conversation = False
if not conversation:
is_first_conversation = True
@ -234,7 +216,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# update conversation features
conversation.override_model_configs = workflow.features
db.session.commit()
db.session.refresh(conversation)
# db.session.refresh(conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@ -246,12 +228,67 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id
)
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
session.commit()
# Increment dialogue count.
conversation.dialogue_count += 1
conversation_id = conversation.id
conversation_dialogue_count = conversation.dialogue_count
db.session.commit()
db.session.refresh(conversation)
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: conversation_id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
contexts.workflow_variable_pool.set(variable_pool)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
'context': contextvars.copy_context(),
})
@ -277,7 +314,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
def _generate_worker(self, flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
context: contextvars.Context) -> None:
"""
@ -293,19 +329,28 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
var.set(val)
with flask_app.app_context():
try:
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
runner = AdvancedChatAppRunner()
if application_generate_entity.single_iteration_run:
single_iteration_run = application_generate_entity.single_iteration_run
runner.single_iteration_run(
app_id=application_generate_entity.app_config.app_id,
workflow_id=application_generate_entity.app_config.workflow_id,
queue_manager=queue_manager,
inputs=single_iteration_run.inputs,
node_id=single_iteration_run.node_id,
user_id=application_generate_entity.user_id
)
else:
# get message
message = self._get_message(message_id)
# chatbot app
runner = AdvancedChatAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
runner.run()
# chatbot app
runner = AdvancedChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message=message
)
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
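
The hunk above drops the @overload stubs that tied generate()'s return type to the stream flag. A self-contained sketch of that typing pattern, independent of the Dify classes (the function body is a placeholder):

    from collections.abc import Generator
    from typing import Literal, Union, overload

    @overload
    def generate(args: dict, stream: Literal[True] = True) -> Generator[str, None, None]: ...

    @overload
    def generate(args: dict, stream: Literal[False]) -> dict: ...

    def generate(args: dict, stream: bool = True) -> Union[dict, Generator[str, None, None]]:
        if stream:
            return (chunk for chunk in ("partial ", "answer"))  # streamed chunks
        return {"answer": "partial answer"}  # blocking response

    # A type checker now infers the precise return type at each call site:
    chunks = generate({}, stream=True)     # Generator[str, None, None]
    blocking = generate({}, stream=False)  # dict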

View File

@ -1,67 +1,49 @@
import logging
import os
import time
from collections.abc import Mapping
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing import Any, Optional, cast
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
QueueStopEvent,
QueueTextChunkEvent,
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, WorkflowType
from models import App, Message, Workflow
logger = logging.getLogger(__name__)
class AdvancedChatAppRunner(WorkflowBasedAppRunner):
class AdvancedChatAppRunner(AppRunner):
"""
AdvancedChat Application Runner
"""
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message
def run(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
message: Message,
) -> None:
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
"""
super().__init__(queue_manager)
self.application_generate_entity = application_generate_entity
self.conversation = conversation
self.message = message
def run(self) -> None:
"""
Run application
:return:
"""
app_config = self.application_generate_entity.app_config
app_config = application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
@ -72,133 +54,101 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not workflow:
raise ValueError('Workflow not initialized')
user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
inputs = application_generate_entity.inputs
query = application_generate_entity.query
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# moderation
if self.handle_input_moderation(
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id,
):
return
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs
)
else:
inputs = self.application_generate_entity.inputs
query = self.application_generate_entity.query
files = self.application_generate_entity.files
# moderation
if self.handle_input_moderation(
app_record=app_record,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
message_id=self.message.id
):
return
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity
):
return
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
session.commit()
# Increment dialogue count.
self.conversation.dialogue_count += 1
conversation_dialogue_count = self.conversation.dialogue_count
db.session.commit()
# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=message,
query=query,
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
):
return
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
)
generator = workflow_entry.run(
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
workflow=workflow,
user_id=application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth,
)
for event in generator:
self._handle_event(workflow_entry, event)
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
"""
Single iteration run
"""
app_record = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError('App not found')
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
def handle_input_moderation(
self,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str
self,
queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
) -> bool:
"""
Handle input moderation
:param queue_manager: application queue manager
:param app_record: app record
:param app_generate_entity: application generate entity
:param inputs: inputs
@ -217,23 +167,30 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
message_id=message_id,
)
except ModerationException as e:
self._complete_with_stream_output(
self._stream_output(
queue_manager=queue_manager,
text=str(e),
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
)
return True
return False
def handle_annotation_reply(self, app_record: App,
message: Message,
query: str,
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
def handle_annotation_reply(
self,
app_record: App,
message: Message,
query: str,
queue_manager: AppQueueManager,
app_generate_entity: AdvancedChatAppGenerateEntity,
) -> bool:
"""
Handle annotation reply
:param app_record: app record
:param message: message
:param query: query
:param queue_manager: application queue manager
:param app_generate_entity: application generate entity
"""
# annotation reply
@ -246,32 +203,37 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
if annotation_reply:
self._publish_event(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
)
self._complete_with_stream_output(
self._stream_output(
queue_manager=queue_manager,
text=annotation_reply.content,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
)
return True
return False
def _complete_with_stream_output(self,
text: str,
stopped_by: QueueStopEvent.StopBy) -> None:
def _stream_output(
self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
) -> None:
"""
Direct output
:param queue_manager: application queue manager
:param text: text
:param stream: stream
:return:
"""
self._publish_event(
QueueTextChunkEvent(
text=text
)
)
if stream:
index = 0
for token in text:
queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
index += 1
time.sleep(0.01)
else:
queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)
self._publish_event(
QueueStopEvent(stopped_by=stopped_by)
)
queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)
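
The reinstated _stream_output above emits one character per queue message with a 10 ms pause to simulate token streaming. A stripped-down sketch of the same idea using a plain queue and a sentinel in place of QueueStopEvent (names are illustrative, not Dify's):

    import queue
    import time

    def stream_output(q: queue.Queue, text: str, stream: bool) -> None:
        if stream:
            for token in text:
                q.put(token)      # one chunk per character
                time.sleep(0.01)  # pace the stream
        else:
            q.put(text)           # single blocking chunk
        q.put(None)               # sentinel, standing in for the stop event

    q = queue.Queue()
    stream_output(q, "hi", stream=True)
    while (chunk := q.get()) is not None:
        print(chunk, end="")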

View File

@ -2,8 +2,9 @@ import json
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast
import contexts
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -21,9 +22,6 @@ from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
@ -33,28 +31,34 @@ from core.app.entities.queue_entities import (
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
AdvancedChatTaskState,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ChatflowStreamGenerateRoute,
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
StreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from events.message_event import message_was_created
from extensions.ext_database import db
from models.account import Account
from models.model import Conversation, EndUser, Message
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRunStatus,
)
@ -65,15 +69,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: WorkflowTaskState
_task_state: AdvancedChatTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
_user: Union[Account, EndUser]
# Deprecated
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
@ -101,6 +106,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._workflow = workflow
self._conversation = conversation
self._message = message
# Deprecated
self._workflow_system_variables = {
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
@ -108,8 +114,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
SystemVariableKey.USER_ID: user_id,
}
self._task_state = WorkflowTaskState()
self._task_state = AdvancedChatTaskState(
usage=LLMUsage.empty_usage()
)
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
self._stream_generate_routes = self._get_stream_generate_routes()
self._conversation_name_generate_thread = None
def process(self):
@ -130,7 +140,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream:
return self._to_stream_response(generator)
else:
@ -190,18 +199,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
tts_publisher = None
publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@ -212,9 +220,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not tts_publisher:
if not publisher:
break
audio_trunk = tts_publisher.checkAndGetAudio()
audio_trunk = publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@ -232,34 +240,34 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
def _process_stream_response(
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
# init fake graph runtime state
graph_runtime_state = None
workflow_run = None
for message in self._queue_manager.listen():
if (message.event
and getattr(message.event, 'metadata', None)
and message.event.metadata.get('is_answer_previous_node', False)
and publisher):
publisher.publish(message=message)
elif (hasattr(message.event, 'execution_metadata')
and message.event.execution_metadata
and message.event.execution_metadata.get('is_answer_previous_node', False)
and publisher):
publisher.publish(message=message)
event = message.event
for queue_message in self._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
if isinstance(event, QueueErrorEvent):
err = self._handle_error(event, self._message)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
workflow_run = self._handle_workflow_start()
# init workflow run
workflow_run = self._handle_workflow_run_start()
self._refetch_message()
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._message.workflow_run_id = workflow_run.id
db.session.commit()
@ -271,242 +279,133 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
workflow_run=workflow_run
)
elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
workflow_node_execution = self._handle_node_start(event)
workflow_node_execution = self._handle_node_execution_start(
workflow_run=workflow_run,
event=event
)
# search stream_generate_routes if node id is answer start at node
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
# reset current route position to 0
self._task_state.current_stream_generate_state.current_route_position = 0
response = self._workflow_node_start_to_stream_response(
# generate stream outputs when node started
yield from self._generate_stream_outputs_when_node_started()
yield self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._handle_node_finished(event)
if response:
yield response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
# stream outputs when node finished
generator = self._generate_stream_outputs_when_node_finished()
if generator:
yield from generator
response = self._workflow_node_finish_to_stream_response(
event=event,
yield self._workflow_node_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) if event.outputs else None,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED,
error=event.error,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
elif isinstance(event, QueueStopEvent):
if workflow_run and graph_runtime_state:
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.STOPPED,
error=event.get_stop_reason(),
conversation_id=self._conversation.id,
trace_manager=trace_manager,
if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(
event, conversation_id=self._conversation.id, trace_manager=trace_manager
)
if workflow_run:
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
if workflow_run.status == WorkflowRunStatus.FAILED.value:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
if isinstance(event, QueueStopEvent):
# Save message
self._save_message()
yield self._message_end_to_stream_response()
break
else:
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message
self._save_message(graph_runtime_state=graph_runtime_state)
self._save_message()
yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
self._refetch_message()
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
db.session.commit()
db.session.refresh(self._message)
db.session.close()
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)
self._refetch_message()
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
db.session.commit()
db.session.refresh(self._message)
db.session.close()
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
continue
if not self._is_stream_out_support(
event=event
):
continue
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
if should_direct_answer:
continue
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
self._task_state.answer += delta_text
yield self._message_to_stream_response(delta_text, self._message.id)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message
self._save_message(graph_runtime_state=graph_runtime_state)
yield self._message_end_to_stream_response()
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
else:
continue
# publish None when task finished
if tts_publisher:
tts_publisher.publish(None)
if publisher:
publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
def _save_message(self) -> None:
"""
Save message.
:return:
"""
self._refetch_message()
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
if self._task_state.metadata and self._task_state.metadata.get('usage'):
usage = LLMUsage(**self._task_state.metadata['usage'])
self._message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
@ -533,10 +432,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
extras = {}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata.copy()
if 'annotation_reply' in extras['metadata']:
del extras['metadata']['annotation_reply']
extras['metadata'] = self._task_state.metadata
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
@ -544,6 +440,323 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
**extras
)
def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
"""
Get stream generate routes.
:return:
"""
# find all answer nodes
graph = self._workflow.graph_dict
answer_node_configs = [
node for node in graph['nodes']
if node.get('data', {}).get('type') == NodeType.ANSWER.value
]
# parse stream output node value selectors of answer nodes
stream_generate_routes = {}
for node_config in answer_node_configs:
# get generate route for stream output
answer_node_id = node_config['id']
generate_route = AnswerNode.extract_generate_route_selectors(node_config)
start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id)
if not start_node_ids:
continue
for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
answer_node_id=answer_node_id,
generate_route=generate_route
)
return stream_generate_routes
def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
Get answer start at node id.
:param graph: graph
:param target_node_id: target node ID
:return:
"""
nodes = graph.get('nodes')
edges = graph.get('edges')
# fetch all ingoing edges from source node
ingoing_edges = []
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edges.append(edge)
if not ingoing_edges:
# check if it's the first node in the iteration
target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
if not target_node:
return []
node_iteration_id = target_node.get('data', {}).get('iteration_id')
# get iteration start node id
for node in nodes:
if node.get('id') == node_iteration_id:
if node.get('data', {}).get('start_node_id') == target_node_id:
return [target_node_id]
return []
start_node_ids = []
for ingoing_edge in ingoing_edges:
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
continue
node_type = source_node.get('data', {}).get('type')
node_iteration_id = source_node.get('data', {}).get('iteration_id')
iteration_start_node_id = None
if node_iteration_id:
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
if node_type in [
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value,
NodeType.ITERATION.value,
NodeType.LOOP.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id)
if sub_start_node_ids:
start_node_ids.extend(sub_start_node_ids)
return start_node_ids
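Taken together, the two methods above walk the graph backwards from each Answer node to find where streaming may begin: branching sources (answer, if/else, classifier, iteration, loop) stop the walk at the target itself, a Start node (or an iteration's start node) is itself a valid start, and anything else keeps the recursion going. A minimal self-contained sketch of that backward walk on a toy graph, using plain string node types instead of NodeType values and a hypothetical helper name:
def find_start_node_ids(graph: dict, target_node_id: str) -> list[str]:
    branching_types = {'answer', 'if-else', 'question-classifier', 'iteration', 'loop'}
    ingoing = [e for e in graph['edges'] if e['target'] == target_node_id]
    start_node_ids: list[str] = []
    for edge in ingoing:
        source = next(n for n in graph['nodes'] if n['id'] == edge['source'])
        node_type = source['data']['type']
        if node_type in branching_types:
            # streaming can only begin once the branch is decided, i.e. at the target itself
            start_node_ids.append(target_node_id)
        elif node_type == 'start':
            start_node_ids.append(source['id'])
        else:
            # pass-through node: keep walking backwards
            start_node_ids.extend(find_start_node_ids(graph, source['id']))
    return start_node_ids
graph = {
    'nodes': [
        {'id': 'start', 'data': {'type': 'start'}},
        {'id': 'llm', 'data': {'type': 'llm'}},
        {'id': 'answer', 'data': {'type': 'answer'}},
    ],
    'edges': [
        {'source': 'start', 'target': 'llm'},
        {'source': 'llm', 'target': 'answer'},
    ],
}
assert find_start_node_ids(graph, 'answer') == ['start']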
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]
return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}
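For illustration, given one iteration node with two children tagged with its id, the mapping comes out parent-to-children; the type strings below assume NodeType.ITERATION.value is 'iteration':
graph = {
    'nodes': [
        {'id': 'iter-1', 'data': {'type': 'iteration'}},
        {'id': 'a', 'data': {'type': 'llm', 'iteration_id': 'iter-1'}},
        {'id': 'b', 'data': {'type': 'code', 'iteration_id': 'iter-1'}},
        {'id': 'answer', 'data': {'type': 'answer'}},
    ]
}
# self._get_iteration_nested_relations(graph)
# -> {'iter-1': ['a', 'b']}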
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
:return:
"""
if self._task_state.current_stream_generate_state:
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:
]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
if should_direct_answer:
continue
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
break
self._task_state.current_stream_generate_state.current_route_position += 1
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None
def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]:
"""
Generate stream outputs.
:return:
"""
if not self._task_state.current_stream_generate_state:
return
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
value = None
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
self._task_state.current_stream_generate_state.current_route_position += 1
continue
route_chunk_node_id = value_selector[0]
if route_chunk_node_id == 'sys':
# system variable
value = contexts.workflow_variable_pool.get().get(value_selector)
if value:
value = value.text
elif route_chunk_node_id in self._iteration_nested_relations:
# it's an iteration variable
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
continue
iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
iterator = iteration_state.inputs
if not iterator:
continue
iterator_selector = iterator.get('iterator_selector', [])
if value_selector[1] == 'index':
value = iteration_state.current_index
elif value_selector[1] == 'item':
value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
iterator_selector
) else None
else:
# check whether the chunk node id comes before or equals the current node id
if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
break
latest_node_execution_info = self._task_state.latest_node_execution_info
# get route chunk node execution info
route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
if (route_chunk_node_execution_info.node_type == NodeType.LLM
and latest_node_execution_info.node_type == NodeType.LLM):
# only LLM supports chunk stream output
self._task_state.current_stream_generate_state.current_route_position += 1
continue
# get route chunk node execution
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id
).first()
outputs = route_chunk_node_execution.outputs_dict
# get value from outputs
value = None
for key in value_selector[1:]:
if not value:
value = outputs.get(key) if outputs else None
else:
value = value.get(key)
if value is not None:
text = ''
if isinstance(value, str | int | float):
text = str(value)
elif isinstance(value, FileVar):
# convert file to markdown
text = value.to_markdown()
elif isinstance(value, dict):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
if file_vars:
file_var = file_vars[0]
try:
file_var_obj = FileVar(**file_var)
# convert file to markdown
text = file_var_obj.to_markdown()
except Exception as e:
logger.error(f'Error creating file var: {e}')
if not text:
# other types
text = json.dumps(value, ensure_ascii=False)
elif isinstance(value, list):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
for file_var in file_vars:
try:
file_var_obj = FileVar(**file_var)
except Exception as e:
logger.error(f'Error creating file var: {e}')
continue
# convert file to markdown
text = file_var_obj.to_markdown() + ' '
text = text.strip()
if not text and value:
# other types
text = json.dumps(value, ensure_ascii=False)
if text:
self._task_state.answer += text
yield self._message_to_stream_response(text, self._message.id)
self._task_state.current_stream_generate_state.current_route_position += 1
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.metadata:
return True
if 'node_id' not in event.metadata:
return True
node_type = event.metadata.get('node_type')
stream_output_value_selector = event.metadata.get('value_selector')
if not stream_output_value_selector:
return False
if not self._task_state.current_stream_generate_state:
return False
route_chunk = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position]
if route_chunk.type != 'var':
return False
if node_type != NodeType.LLM:
# only LLM supports chunk stream output
return False
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
# check whether the chunk node id comes before or equals the current node id
if value_selector != stream_output_value_selector:
return False
return True
def _handle_output_moderation_chunk(self, text: str) -> bool:
"""
Handle output moderation chunk.
@@ -569,12 +782,3 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._output_moderation_handler.append_new_token(text)
return False
def _refetch_message(self) -> None:
"""
Refetch message.
:return:
"""
message = db.session.query(Message).filter(Message.id == self._message.id).first()
if message:
self._message = message

View File

@@ -0,0 +1,203 @@
from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow
class WorkflowEventTriggerCallback(WorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self._queue_manager.publish(
QueueWorkflowStartedEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self._queue_manager.publish(
QueueWorkflowSucceededEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self._queue_manager.publish(
QueueWorkflowFailedEvent(
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
self._queue_manager.publish(
QueueNodeStartedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
node_run_index=node_run_index,
predecessor_node_id=predecessor_node_id
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
self._queue_manager.publish(
QueueNodeSucceededEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
"""
Workflow node execute failed
"""
self._queue_manager.publish(
QueueNodeFailedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
outputs=outputs,
process_data=process_data,
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
self._queue_manager.publish(
QueueTextChunkEvent(
text=text,
metadata={
"node_id": node_id,
**metadata
}
), PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self._queue_manager.publish(
QueueIterationStartEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
node_data=node_data,
inputs=inputs,
predecessor_node_id=predecessor_node_id,
metadata=metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any]) -> None:
"""
Publish iteration next
"""
self._queue_manager._publish(
QueueIterationNextEvent(
node_id=node_id,
node_type=node_type,
index=index,
node_run_index=node_run_index,
output=output
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self._queue_manager._publish(
QueueIterationCompletedEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
outputs=outputs
),
PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
self._queue_manager.publish(
event,
PublishFrom.APPLICATION_MANAGER
)
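The callback is a thin adapter: every engine hook turns into one publish onto the app queue, and the task pipeline's listen() loop on the other side turns those events back into stream responses. A self-contained toy of that producer/consumer shape (queue.Queue standing in for AppQueueManager, which in reality blocks rather than draining):
import queue

class ToyQueueManager:
    def __init__(self) -> None:
        self._q: queue.Queue = queue.Queue()

    def publish(self, event: dict, source: str) -> None:
        # hooks are producers: they only enqueue
        self._q.put((source, event))

    def listen(self):
        # the pipeline is the consumer: it drains and reacts
        while not self._q.empty():
            yield self._q.get()

qm = ToyQueueManager()
qm.publish({'event': 'workflow_started'}, 'APPLICATION_MANAGER')
qm.publish({'event': 'text_chunk', 'text': 'hi'}, 'APPLICATION_MANAGER')
for source, event in qm.listen():
    print(source, event)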

View File

@@ -3,7 +3,7 @@ import os
import threading
import uuid
from collections.abc import Generator
from typing import Any, Literal, Union, overload
from typing import Any, Union
from flask import Flask, current_app
from pydantic import ValidationError
@@ -28,24 +28,6 @@ logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[True] = True,
) -> Generator[dict, None, None]: ...
@overload
def generate(
self, app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
def generate(self, app_model: App,
user: Union[Account, EndUser],
args: Any,
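The @overload stubs in this hunk (and the matching ones in the chat, completion and workflow generators below) pair typing.overload with Literal defaults so a type checker can narrow generate()'s return type from the stream flag, while a single runtime implementation serves both. A standalone sketch of the pattern:
from collections.abc import Generator
from typing import Literal, Union, overload

@overload
def generate(stream: Literal[True] = True) -> Generator[str, None, None]: ...
@overload
def generate(stream: Literal[False] = False) -> dict: ...
def generate(stream: bool = True) -> Union[dict, Generator[str, None, None]]:
    # one runtime body behind both typed signatures
    if stream:
        return (chunk for chunk in ('Hello', ', world'))
    return {'answer': 'Hello, world'}

chunks = generate(stream=True)     # checkers infer Generator[str, None, None]
blocking = generate(stream=False)  # checkers infer dict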

View File

@@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC):
def convert(cls, response: Union[
AppBlockingResponse,
Generator[AppStreamResponse, Any, None]
], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]:
], invoke_from: InvokeFrom):
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)

View File

@@ -1,6 +1,6 @@
import time
from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any, Optional, Union
from collections.abc import Generator
from typing import TYPE_CHECKING, Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -347,7 +347,7 @@ class AppRunner:
self, app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
inputs: dict,
query: str,
message_id: str,
) -> tuple[bool, dict, str]:

View File

@@ -3,7 +3,7 @@ import os
import threading
import uuid
from collections.abc import Generator
from typing import Any, Literal, Union, overload
from typing import Any, Union
from flask import Flask, current_app
from pydantic import ValidationError
@@ -28,31 +28,13 @@ logger = logging.getLogger(__name__)
class ChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: Literal[True] = True,
) -> Generator[str, None, None]: ...
@overload
def generate(
self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
def generate(
self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
) -> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.

View File

@@ -3,7 +3,7 @@ import os
import threading
import uuid
from collections.abc import Generator
from typing import Any, Literal, Union, overload
from typing import Any, Union
from flask import Flask, current_app
from pydantic import ValidationError
@@ -30,30 +30,12 @@ logger = logging.getLogger(__name__)
class CompletionAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[True] = True,
) -> Generator[str, None, None]: ...
@overload
def generate(
self, app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
def generate(self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator[str, None, None]]:
-> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
@@ -221,7 +203,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator[str, None, None]]:
-> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.

View File

@@ -4,7 +4,7 @@ import os
import threading
import uuid
from collections.abc import Generator
from typing import Any, Literal, Optional, Union, overload
from typing import Union
from flask import Flask, current_app
from pydantic import ValidationError
@@ -32,40 +32,14 @@ logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
@overload
def generate(
self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[True] = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
) -> Generator[str, None, None]: ...
@overload
def generate(
self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
) -> dict: ...
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
):
"""
Generate App response.
@@ -77,7 +51,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
:param call_depth: call depth
:param workflow_thread_pool_id: workflow thread pool id
"""
inputs = args['inputs']
@@ -125,19 +98,16 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
stream=stream,
workflow_thread_pool_id=workflow_thread_pool_id
)
def _generate(
self, *,
app_model: App,
self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
stream: bool = True,
workflow_thread_pool_id: Optional[str] = None
) -> dict[str, Any] | Generator[str, None, None]:
) -> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
@@ -147,7 +117,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param stream: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager
queue_manager = WorkflowAppQueueManager(
@@ -159,11 +128,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'context': contextvars.copy_context(),
'workflow_thread_pool_id': workflow_thread_pool_id
'context': contextvars.copy_context()
})
worker_thread.start()
@@ -187,7 +155,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
node_id: str,
user: Account,
args: dict,
stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]:
stream: bool = True):
"""
Generate App response.
@@ -204,6 +172,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
@@ -219,9 +191,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras={
"auto_generate_conversation_name": False
},
extras=extras,
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
@@ -241,14 +211,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
def _generate_worker(self, flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
workflow_thread_pool_id: Optional[str] = None) -> None:
context: contextvars.Context) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
for var, val in context.items():
@@ -256,13 +224,22 @@ class WorkflowAppGenerator(BaseAppGenerator):
with flask_app.app_context():
try:
# workflow app
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id
)
runner.run()
runner = WorkflowAppRunner()
if application_generate_entity.single_iteration_run:
single_iteration_run = application_generate_entity.single_iteration_run
runner.single_iteration_run(
app_id=application_generate_entity.app_config.app_id,
workflow_id=application_generate_entity.app_config.workflow_id,
queue_manager=queue_manager,
inputs=single_iteration_run.inputs,
node_id=single_iteration_run.node_id,
user_id=application_generate_entity.user_id
)
else:
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager
)
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
@@ -274,14 +251,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true':
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()
db.session.remove()
def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
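One detail worth noting in _generate_worker: threads do not inherit contextvars automatically, so the caller passes contextvars.copy_context() in and the worker replays it with `for var, val in context.items(): var.set(val)` before touching Flask or the database. A self-contained sketch of that hand-off:
import contextvars
import threading

request_id: contextvars.ContextVar[str] = contextvars.ContextVar('request_id')

def worker(context: contextvars.Context) -> None:
    # replay the caller's context variables inside the new thread
    for var, val in context.items():
        var.set(val)
    print(request_id.get())  # -> 'req-42'

request_id.set('req-42')
t = threading.Thread(target=worker, kwargs={'context': contextvars.copy_context()})
t.start()
t.join()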

View File

@@ -4,61 +4,46 @@ from typing import Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.model import App, EndUser
from models.workflow import WorkflowType
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class WorkflowAppRunner(WorkflowBasedAppRunner):
class WorkflowAppRunner:
"""
Workflow Application Runner
"""
def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
workflow_thread_pool_id: Optional[str] = None
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
self.application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id
def run(self) -> None:
def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None:
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:return:
"""
app_config = self.application_generate_entity.app_config
app_config = application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
user_id = application_generate_entity.user_id
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
@@ -68,64 +53,80 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
if not workflow:
raise ValueError('Workflow not initialized')
inputs = application_generate_entity.inputs
files = application_generate_entity.files
db.session.close()
workflow_callbacks: list[WorkflowCallback] = []
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = {
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
# Create a variable pool.
system_inputs = {
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
workflow=workflow,
user_id=application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id
)
generator = workflow_entry.run(
callbacks=workflow_callbacks
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
"""
Single iteration run
"""
app_record = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError('App not found')
if not app_record.workflow_id:
raise ValueError('Workflow not initialized')
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
)
for event in generator:
self._handle_event(workflow_entry, event)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
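get_workflow filters on tenant_id and app_id as well as the workflow id, so a workflow can only be fetched through its owning tenant and app. A minimal stand-alone sketch of that scoped lookup (the declarative model is a stand-in, not Dify's models.workflow.Workflow):
from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class ToyWorkflow(Base):
    __tablename__ = 'workflows'
    id = Column(String, primary_key=True)
    tenant_id = Column(String)
    app_id = Column(String)

engine = create_engine('sqlite://')
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(ToyWorkflow(id='wf-1', tenant_id='t-1', app_id='app-1'))
    session.commit()
    found = (
        session.query(ToyWorkflow)
        .filter(ToyWorkflow.tenant_id == 't-1', ToyWorkflow.app_id == 'app-1', ToyWorkflow.id == 'wf-1')
        .first()
    )
    assert found is not None
    # the same id under the wrong tenant yields None instead of leaking data
    assert session.query(ToyWorkflow).filter(ToyWorkflow.tenant_id == 't-2', ToyWorkflow.id == 'wf-1').first() is None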

View File

@@ -1,4 +1,3 @@
import json
import logging
import time
from collections.abc import Generator
@@ -16,12 +15,10 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
@@ -35,16 +32,19 @@ from core.app.entities.task_entities import (
MessageAudioStreamResponse,
StreamResponse,
TextChunkStreamResponse,
TextReplaceStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowStreamGenerateNodes,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
@@ -52,8 +52,8 @@ from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunStatus,
)
logger = logging.getLogger(__name__)
@@ -68,6 +68,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
@@ -95,7 +96,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
SystemVariableKey.USER_ID: user_id
}
self._task_state = WorkflowTaskState()
self._task_state = WorkflowTaskState(
iteration_nested_node_ids=[]
)
self._stream_generate_nodes = self._get_stream_generate_nodes()
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@@ -124,20 +129,23 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, WorkflowFinishStreamResponse):
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == self._task_state.workflow_run_id).first()
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.id,
workflow_run_id=workflow_run.id,
data=WorkflowAppBlockingResponse.Data(
id=stream_response.data.id,
workflow_id=stream_response.data.workflow_id,
status=stream_response.data.status,
outputs=stream_response.data.outputs,
error=stream_response.data.error,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at)
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
status=workflow_run.status,
outputs=workflow_run.outputs_dict,
error=workflow_run.error,
elapsed_time=workflow_run.elapsed_time,
total_tokens=workflow_run.total_tokens,
total_steps=workflow_run.total_steps,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp())
)
)
@@ -153,13 +161,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
To stream response.
:return:
"""
workflow_run_id = None
for stream_response in generator:
if isinstance(stream_response, WorkflowStartStreamResponse):
workflow_run_id = stream_response.workflow_run_id
yield WorkflowAppStreamResponse(
workflow_run_id=workflow_run_id,
workflow_run_id=self._task_state.workflow_run_id,
stream_response=stream_response
)
@@ -174,18 +178,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
tts_publisher = None
publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@@ -195,9 +198,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
start_listener_time = time.time()
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not tts_publisher:
if not publisher:
break
audio_trunk = tts_publisher.checkAndGetAudio()
audio_trunk = publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms (40 ms => 1280-byte audio file, 20 ms => 640-byte audio file)
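The loop above keeps polling the TTS publisher until TTS_AUTO_PLAY_TIMEOUT elapses, sleeping briefly whenever no audio is ready. A toy version of the same poll-sleep-flush shape, with queue.Queue standing in for the publisher and an assumed timeout in seconds:
import queue
import time

TTS_AUTO_PLAY_TIMEOUT = 2.0  # assumption: seconds; the real constant is imported by the pipeline
audio_q: queue.Queue = queue.Queue()
audio_q.put(b'\x00' * 640)  # one ~20 ms chunk of silence

start = time.time()
while time.time() - start < TTS_AUTO_PLAY_TIMEOUT:
    try:
        chunk = audio_q.get_nowait()
    except queue.Empty:
        time.sleep(0.02)  # release the CPU between empty polls
        continue
    print(f'flush {len(chunk)} bytes')
    if audio_q.empty():
        break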
@@ -215,159 +218,69 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
def _process_stream_response(
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
graph_runtime_state = None
workflow_run = None
for message in self._queue_manager.listen():
if publisher:
publisher.publish(message=message)
event = message.event
for queue_message in self._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
if isinstance(event, QueueErrorEvent):
err = self._handle_error(event)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
# init workflow run
workflow_run = self._handle_workflow_run_start()
workflow_run = self._handle_workflow_start()
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
workflow_node_execution = self._handle_node_start(event)
workflow_node_execution = self._handle_node_execution_start(
workflow_run=workflow_run,
event=event
)
# search stream_generate_routes if node id is answer start at node
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes:
self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id]
response = self._workflow_node_start_to_stream_response(
# generate stream outputs when node started
yield from self._generate_stream_outputs_when_node_started()
yield self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._handle_node_finished(event)
if response:
yield response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
response = self._workflow_node_finish_to_stream_response(
event=event,
yield self._workflow_node_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(
event, trace_manager=trace_manager
)
# save workflow app log
@@ -382,17 +295,22 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if delta_text is None:
continue
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
if not self._is_stream_out_support(
event=event
):
continue
self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(delta_text)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._text_replace_to_stream_response(event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
else:
continue
if tts_publisher:
tts_publisher.publish(None)
if publisher:
publisher.publish(None)
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
@@ -411,15 +329,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# not save log for debugging
return
workflow_app_log = WorkflowAppLog()
workflow_app_log.tenant_id = workflow_run.tenant_id
workflow_app_log.app_id = workflow_run.app_id
workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user'
workflow_app_log.created_by = self._user.id
workflow_app_log = WorkflowAppLog(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
workflow_run_id=workflow_run.id,
created_from=created_from.value,
created_by_role=('account' if isinstance(self._user, Account) else 'end_user'),
created_by=self._user.id,
)
db.session.add(workflow_app_log)
db.session.commit()
db.session.close()
@@ -436,3 +354,180 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
)
return response
def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse:
"""
Text replace to stream response.
:param text: text
:return:
"""
return TextReplaceStreamResponse(
task_id=self._application_generate_entity.task_id,
text=TextReplaceStreamResponse.Data(text=text)
)
def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]:
"""
Get stream generate nodes.
:return:
"""
# find all answer nodes
graph = self._workflow.graph_dict
end_node_configs = [
node for node in graph['nodes']
if node.get('data', {}).get('type') == NodeType.END.value
]
# parse stream output node value selectors of end nodes
stream_generate_routes = {}
for node_config in end_node_configs:
# get generate route for stream output
end_node_id = node_config['id']
generate_nodes = EndNode.extract_generate_nodes(graph, node_config)
start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id)
if not start_node_ids:
continue
for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes(
end_node_id=end_node_id,
stream_node_ids=generate_nodes
)
return stream_generate_routes
def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
Get end start at node id.
:param graph: graph
:param target_node_id: target node ID
:return:
"""
nodes = graph.get('nodes')
edges = graph.get('edges')
# fetch all ingoing edges from source node
ingoing_edges = []
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edges.append(edge)
if not ingoing_edges:
return []
start_node_ids = []
for ingoing_edge in ingoing_edges:
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
continue
node_type = source_node.get('data', {}).get('type')
node_iteration_id = source_node.get('data', {}).get('iteration_id')
iteration_start_node_id = None
if node_iteration_id:
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
if node_type in [
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id)
if sub_start_node_ids:
start_node_ids.extend(sub_start_node_ids)
return start_node_ids
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
:return:
"""
if self._task_state.current_stream_generate_state:
stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids
for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items():
if node_id not in stream_node_ids:
continue
node_execution_info = self._task_state.ran_node_execution_infos[node_id]
# get chunk node execution
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first()
if not route_chunk_node_execution:
continue
outputs = route_chunk_node_execution.outputs_dict
if not outputs:
continue
# get value from outputs
text = outputs.get('text')
if text:
self._task_state.answer += text
yield self._text_chunk_to_stream_response(text)
db.session.close()
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.metadata:
return False
if 'node_id' not in event.metadata:
return False
node_id = event.metadata.get('node_id')
node_type = event.metadata.get('node_type')
stream_output_value_selector = event.metadata.get('value_selector')
if not stream_output_value_selector:
return False
if not self._task_state.current_stream_generate_state:
return False
if node_id not in self._task_state.current_stream_generate_state.stream_node_ids:
return False
if node_type != NodeType.LLM:
# only LLM supports chunk stream output
return False
return True
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]
return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}

View File

@@ -0,0 +1,200 @@
from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow
class WorkflowEventTriggerCallback(WorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self._queue_manager.publish(
QueueWorkflowStartedEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self._queue_manager.publish(
QueueWorkflowSucceededEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self._queue_manager.publish(
QueueWorkflowFailedEvent(
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
self._queue_manager.publish(
QueueNodeStartedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
node_run_index=node_run_index,
predecessor_node_id=predecessor_node_id
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
self._queue_manager.publish(
QueueNodeSucceededEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
"""
Workflow node execute failed
"""
self._queue_manager.publish(
QueueNodeFailedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
outputs=outputs,
process_data=process_data,
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
self._queue_manager.publish(
QueueTextChunkEvent(
text=text,
metadata={
"node_id": node_id,
**metadata
}
), PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self._queue_manager.publish(
QueueIterationStartEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
node_data=node_data,
inputs=inputs,
predecessor_node_id=predecessor_node_id,
metadata=metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any]) -> None:
"""
Publish iteration next
"""
self._queue_manager.publish(
QueueIterationNextEvent(
node_id=node_id,
node_type=node_type,
index=index,
node_run_index=node_run_index,
output=output
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self._queue_manager.publish(
QueueIterationCompletedEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
outputs=outputs
),
PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
pass

View File

@@ -1,374 +0,0 @@
from collections.abc import Mapping
from typing import Any, Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.nodes.node_mapping import node_classes
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
from models.workflow import Workflow
class WorkflowBasedAppRunner(AppRunner):
def __init__(self, queue_manager: AppQueueManager):
self.queue_manager = queue_manager
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
"""
Init graph
"""
if 'nodes' not in graph_config or 'edges' not in graph_config:
raise ValueError('nodes or edges not found in workflow graph')
if not isinstance(graph_config.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')
if not isinstance(graph_config.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
# init graph
graph = Graph.init(
graph_config=graph_config
)
if not graph:
raise ValueError('graph not found in workflow')
return graph
def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
"""
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError('workflow graph not found')
graph_config = cast(dict[str, Any], graph_config)
if 'nodes' not in graph_config or 'edges' not in graph_config:
raise ValueError('nodes or edges not found in workflow graph')
if not isinstance(graph_config.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')
if not isinstance(graph_config.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
# filter nodes only in iteration
node_configs = [
node for node in graph_config.get('nodes', [])
if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
]
graph_config['nodes'] = node_configs
node_ids = [node.get('id') for node in node_configs]
# filter edges only in iteration
edge_configs = [
edge for edge in graph_config.get('edges', [])
if (edge.get('source') is None or edge.get('source') in node_ids)
and (edge.get('target') is None or edge.get('target') in node_ids)
]
graph_config['edges'] = edge_configs
# init graph
graph = Graph.init(
graph_config=graph_config,
root_node_id=node_id
)
if not graph:
raise ValueError('graph not found in workflow')
# fetch node config from node id
iteration_node_config = None
for node in node_configs:
if node.get('id') == node_id:
iteration_node_config = node
break
if not iteration_node_config:
raise ValueError('iteration node id not found in workflow graph')
# Get node class
node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
node_cls = cast(type[BaseNode], node_cls)
# init variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
environment_variables=workflow.environment_variables,
)
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict,
config=iteration_node_config
)
except NotImplementedError:
variable_mapping = {}
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
node_type=node_type,
node_data=IterationNodeData(**iteration_node_config.get('data', {}))
)
return graph, variable_pool
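The filtering above cuts the workflow graph down to the iteration node plus its children and keeps only edges whose endpoints survive, so the sub-graph can run in isolation. A self-contained illustration of that filter on a toy graph (plain dicts, no Dify types):
graph_config = {
    'nodes': [
        {'id': 'start', 'data': {'type': 'start'}},
        {'id': 'iter', 'data': {'type': 'iteration'}},
        {'id': 'llm', 'data': {'type': 'llm', 'iteration_id': 'iter'}},
    ],
    'edges': [
        {'source': 'start', 'target': 'iter'},
        {'source': 'iter', 'target': 'llm'},
    ],
}
node_id = 'iter'
node_configs = [
    n for n in graph_config['nodes']
    if n['id'] == node_id or n['data'].get('iteration_id') == node_id
]
node_ids = [n['id'] for n in node_configs]
edge_configs = [
    e for e in graph_config['edges']
    if (e.get('source') is None or e.get('source') in node_ids)
    and (e.get('target') is None or e.get('target') in node_ids)
]
assert node_ids == ['iter', 'llm']
assert edge_configs == [{'source': 'iter', 'target': 'llm'}]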
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
"""
Handle event
:param workflow_entry: workflow entry
:param event: event
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(
QueueWorkflowStartedEvent(
graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state
)
)
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(
QueueWorkflowSucceededEvent(outputs=event.outputs)
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(
QueueWorkflowFailedEvent(error=event.error)
)
elif isinstance(event, NodeRunStartedEvent):
self._publish_event(
QueueNodeStartedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id
)
)
elif isinstance(event, NodeRunSucceededEvent):
self._publish_event(
QueueNodeSucceededEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result else {},
)
)
elif isinstance(event, NodeRunFailedEvent):
self._publish_event(
QueueNodeFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result
and event.route_node_state.node_run_result.error
else "Unknown error"
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
text=event.chunk_content,
from_variable_selector=event.from_variable_selector
)
)
elif isinstance(event, NodeRunRetrieverResourceEvent):
self._publish_event(
QueueRetrieverResourcesEvent(
retriever_resources=event.retriever_resources
)
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self._publish_event(
QueueParallelBranchRunStartedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, ParallelBranchRunSucceededEvent):
self._publish_event(
QueueParallelBranchRunSucceededEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, ParallelBranchRunFailedEvent):
self._publish_event(
QueueParallelBranchRunFailedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
error=event.error
)
)
elif isinstance(event, IterationRunStartedEvent):
self._publish_event(
QueueIterationStartEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata
)
)
elif isinstance(event, IterationRunNextEvent):
self._publish_event(
QueueIterationNextEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
)
)
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
self._publish_event(
QueueIterationCompletedEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, IterationRunFailedEvent) else None
)
)
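For reference, the long elif chain above is a one-to-one translation layer; this hand-written summary of the mapping (not generated from the code, and assuming the module's own imports) may help when scanning it:

# Hand-written summary of the _handle_event dispatch above; both iteration
# outcome events collapse into QueueIterationCompletedEvent.
EVENT_TO_QUEUE_EVENT = {
    GraphRunStartedEvent: QueueWorkflowStartedEvent,
    GraphRunSucceededEvent: QueueWorkflowSucceededEvent,
    GraphRunFailedEvent: QueueWorkflowFailedEvent,
    NodeRunStartedEvent: QueueNodeStartedEvent,
    NodeRunSucceededEvent: QueueNodeSucceededEvent,
    NodeRunFailedEvent: QueueNodeFailedEvent,
    NodeRunStreamChunkEvent: QueueTextChunkEvent,
    NodeRunRetrieverResourceEvent: QueueRetrieverResourcesEvent,
    ParallelBranchRunStartedEvent: QueueParallelBranchRunStartedEvent,
    ParallelBranchRunSucceededEvent: QueueParallelBranchRunSucceededEvent,
    ParallelBranchRunFailedEvent: QueueParallelBranchRunFailedEvent,
    IterationRunStartedEvent: QueueIterationStartEvent,
    IterationRunNextEvent: QueueIterationNextEvent,
    IterationRunSucceededEvent: QueueIterationCompletedEvent,
    IterationRunFailedEvent: QueueIterationCompletedEvent,
}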
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish(
event,
PublishFrom.APPLICATION_MANAGER
)
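As a rough illustration of what _init_graph expects, here is a minimal sketch of a graph_config that passes its shape checks; the node 'data' payloads are assumptions rather than the full schema, and whether Graph.init ultimately accepts them depends on per-node validation not shown here:

# Illustrative only: satisfies the 'nodes'/'edges' list checks in _init_graph.
example_graph_config = {
    'nodes': [
        {'id': 'start', 'data': {'type': 'start', 'title': 'Start'}},
        {'id': 'llm-1', 'data': {'type': 'llm', 'title': 'LLM'}},
    ],
    'edges': [
        {'source': 'start', 'target': 'llm-1'},
    ],
}
# runner._init_graph(example_graph_config) then delegates to
# Graph.init(graph_config=example_graph_config).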

View File

@ -1,24 +1,10 @@
from typing import Optional
from core.app.entities.queue_entities import AppQueueEvent
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
@ -34,203 +20,127 @@ class WorkflowLoggingCallback(WorkflowCallback):
def __init__(self) -> None:
self.current_node_id = None
def on_event(
self,
event: GraphEngineEvent
) -> None:
if isinstance(event, GraphRunStartedEvent):
self.print_text("\n[GraphRunStartedEvent]", color='pink')
elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[GraphRunSucceededEvent]", color='green')
elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red')
elif isinstance(event, NodeRunStartedEvent):
self.on_workflow_node_execute_started(
event=event
)
elif isinstance(event, NodeRunSucceededEvent):
self.on_workflow_node_execute_succeeded(
event=event
)
elif isinstance(event, NodeRunFailedEvent):
self.on_workflow_node_execute_failed(
event=event
)
elif isinstance(event, NodeRunStreamChunkEvent):
self.on_node_text_chunk(
event=event
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self.on_workflow_parallel_started(
event=event
)
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
self.on_workflow_parallel_completed(
event=event
)
elif isinstance(event, IterationRunStartedEvent):
self.on_workflow_iteration_started(
event=event
)
elif isinstance(event, IterationRunNextEvent):
self.on_workflow_iteration_next(
event=event
)
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
self.on_workflow_iteration_completed(
event=event
)
else:
self.print_text(f"\n[{event.__class__.__name__}]", color='blue')
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self.print_text("\n[on_workflow_run_started]", color='pink')
def on_workflow_node_execute_started(
self,
event: NodeRunStartedEvent
) -> None:
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self.print_text("\n[on_workflow_run_succeeded]", color='green')
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self.print_text("\n[on_workflow_run_failed]", color='red')
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
self.print_text("\n[NodeRunStartedEvent]", color='yellow')
self.print_text(f"Node ID: {event.node_id}", color='yellow')
self.print_text(f"Node Title: {event.node_data.title}", color='yellow')
self.print_text(f"Type: {event.node_type.value}", color='yellow')
self.print_text("\n[on_workflow_node_execute_started]", color='yellow')
self.print_text(f"Node ID: {node_id}", color='yellow')
self.print_text(f"Type: {node_type.value}", color='yellow')
self.print_text(f"Index: {node_run_index}", color='yellow')
if predecessor_node_id:
self.print_text(f"Predecessor Node ID: {predecessor_node_id}", color='yellow')
def on_workflow_node_execute_succeeded(
self,
event: NodeRunSucceededEvent
) -> None:
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
route_node_state = event.route_node_state
self.print_text("\n[on_workflow_node_execute_succeeded]", color='green')
self.print_text(f"Node ID: {node_id}", color='green')
self.print_text(f"Type: {node_type.value}", color='green')
self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='green')
self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='green')
self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='green')
self.print_text(f"Metadata: {jsonable_encoder(execution_metadata) if execution_metadata else ''}",
color='green')
self.print_text("\n[NodeRunSucceededEvent]", color='green')
self.print_text(f"Node ID: {event.node_id}", color='green')
self.print_text(f"Node Title: {event.node_data.title}", color='green')
self.print_text(f"Type: {event.node_type.value}", color='green')
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color='green')
self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color='green')
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color='green')
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
color='green')
def on_workflow_node_execute_failed(
self,
event: NodeRunFailedEvent
) -> None:
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
"""
Workflow node execute failed
"""
route_node_state = event.route_node_state
self.print_text("\n[on_workflow_node_execute_failed]", color='red')
self.print_text(f"Node ID: {node_id}", color='red')
self.print_text(f"Type: {node_type.value}", color='red')
self.print_text(f"Error: {error}", color='red')
self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='red')
self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='red')
self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='red')
self.print_text("\n[NodeRunFailedEvent]", color='red')
self.print_text(f"Node ID: {event.node_id}", color='red')
self.print_text(f"Node Title: {event.node_data.title}", color='red')
self.print_text(f"Type: {event.node_type.value}", color='red')
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(f"Error: {node_run_result.error}", color='red')
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color='red')
self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color='red')
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color='red')
def on_node_text_chunk(
self,
event: NodeRunStreamChunkEvent
) -> None:
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
route_node_state = event.route_node_state
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
self.current_node_id = route_node_state.node_id
self.print_text('\n[NodeRunStreamChunkEvent]')
self.print_text(f"Node ID: {route_node_state.node_id}")
if not self.current_node_id or self.current_node_id != node_id:
self.current_node_id = node_id
self.print_text('\n[on_node_text_chunk]')
self.print_text(f"Node ID: {node_id}")
self.print_text(f"Metadata: {jsonable_encoder(metadata) if metadata else ''}")
node_run_result = route_node_state.node_run_result
if node_run_result:
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}")
self.print_text(text, color="pink", end="")
self.print_text(event.chunk_content, color="pink", end="")
def on_workflow_parallel_started(
self,
event: ParallelBranchRunStartedEvent
) -> None:
"""
Publish parallel started
"""
self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue')
self.print_text(f"Parallel ID: {event.parallel_id}", color='blue')
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue')
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue')
def on_workflow_parallel_completed(
self,
event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
) -> None:
"""
Publish parallel completed
"""
if isinstance(event, ParallelBranchRunSucceededEvent):
color = 'blue'
elif isinstance(event, ParallelBranchRunFailedEvent):
color = 'red'
self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color)
self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
if isinstance(event, ParallelBranchRunFailedEvent):
self.print_text(f"Error: {event.error}", color=color)
def on_workflow_iteration_started(
self,
event: IterationRunStartedEvent
) -> None:
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self.print_text("\n[IterationRunStartedEvent]", color='blue')
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
self.print_text("\n[on_workflow_iteration_started]", color='blue')
self.print_text(f"Node ID: {node_id}", color='blue')
def on_workflow_iteration_next(
self,
event: IterationRunNextEvent
) -> None:
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[dict]) -> None:
"""
Publish iteration next
"""
self.print_text("\n[IterationRunNextEvent]", color='blue')
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
self.print_text(f"Iteration Index: {event.index}", color='blue')
self.print_text("\n[on_workflow_iteration_next]", color='blue')
def on_workflow_iteration_completed(
self,
event: IterationRunSucceededEvent | IterationRunFailedEvent
) -> None:
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue')
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
self.print_text("\n[on_workflow_iteration_completed]", color='blue')
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
self.print_text("\n[on_workflow_event]", color='blue')
self.print_text(f"Event: {jsonable_encoder(event)}", color='blue')
def print_text(
self, text: str, color: Optional[str] = None, end: str = "\n"
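The print_text helper is cut off by the hunk boundary above; a minimal sketch of the assumed behaviour, where _TEXT_COLOR_MAPPING entries such as "blue": "36;1" are bare ANSI SGR parameter strings:

from typing import Optional

def print_text_sketch(text: str, color: Optional[str] = None, end: str = "\n") -> None:
    # Assumed behaviour only: wrap the text in the mapped ANSI escape
    # sequence and reset afterwards; plain print when no color is given.
    if color:
        print(f"\033[{_TEXT_COLOR_MAPPING[color]}m{text}\033[0m", end=end)
    else:
        print(text, end=end)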

View File

@ -1,4 +1,3 @@
from datetime import datetime
from enum import Enum
from typing import Any, Optional
@ -6,8 +5,7 @@ from pydantic import BaseModel, field_validator
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.node_entities import NodeType
class QueueEvent(str, Enum):
@ -33,9 +31,6 @@ class QueueEvent(str, Enum):
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
MESSAGE_FILE = "message_file"
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
ERROR = "error"
PING = "ping"
STOP = "stop"
@ -43,7 +38,7 @@ class QueueEvent(str, Enum):
class AppQueueEvent(BaseModel):
"""
QueueEvent abstract entity
QueueEvent entity
"""
event: QueueEvent
@ -51,7 +46,6 @@ class AppQueueEvent(BaseModel):
class QueueLLMChunkEvent(AppQueueEvent):
"""
QueueLLMChunkEvent entity
Only for basic mode apps
"""
event: QueueEvent = QueueEvent.LLM_CHUNK
chunk: LLMResultChunk
@ -61,24 +55,14 @@ class QueueIterationStartEvent(AppQueueEvent):
QueueIterationStartEvent entity
"""
event: QueueEvent = QueueEvent.ITERATION_START
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
inputs: Optional[dict[str, Any]] = None
inputs: dict = None
predecessor_node_id: Optional[str] = None
metadata: Optional[dict[str, Any]] = None
metadata: Optional[dict] = None
class QueueIterationNextEvent(AppQueueEvent):
"""
@ -87,18 +71,8 @@ class QueueIterationNextEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.ITERATION_NEXT
index: int
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
node_run_index: int
output: Optional[Any] = None # output for the current iteration
@ -119,30 +93,13 @@ class QueueIterationCompletedEvent(AppQueueEvent):
"""
QueueIterationCompletedEvent entity
"""
event: QueueEvent = QueueEvent.ITERATION_COMPLETED
event:QueueEvent = QueueEvent.ITERATION_COMPLETED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
steps: int = 0
error: Optional[str] = None
outputs: dict
class QueueTextChunkEvent(AppQueueEvent):
"""
@ -150,8 +107,7 @@ class QueueTextChunkEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.TEXT_CHUNK
text: str
from_variable_selector: Optional[list[str]] = None
"""from variable selector"""
metadata: Optional[dict] = None
class QueueAgentMessageEvent(AppQueueEvent):
@ -206,7 +162,6 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
QueueWorkflowStartedEvent entity
"""
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
graph_runtime_state: GraphRuntimeState
class QueueWorkflowSucceededEvent(AppQueueEvent):
@ -214,7 +169,6 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
QueueWorkflowSucceededEvent entity
"""
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
outputs: Optional[dict[str, Any]] = None
class QueueWorkflowFailedEvent(AppQueueEvent):
@ -231,21 +185,11 @@ class QueueNodeStartedEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.NODE_STARTED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
node_run_index: int = 1
predecessor_node_id: Optional[str] = None
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
class QueueNodeSucceededEvent(AppQueueEvent):
@ -254,24 +198,14 @@ class QueueNodeSucceededEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.NODE_SUCCEEDED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
inputs: Optional[dict] = None
process_data: Optional[dict] = None
outputs: Optional[dict] = None
execution_metadata: Optional[dict] = None
error: Optional[str] = None
@ -282,23 +216,13 @@ class QueueNodeFailedEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
inputs: Optional[dict] = None
outputs: Optional[dict] = None
process_data: Optional[dict] = None
error: str
@ -350,23 +274,10 @@ class QueueStopEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.STOP
stopped_by: StopBy
def get_stop_reason(self) -> str:
"""
To stop reason
"""
reason_mapping = {
QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.',
QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.',
QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
}
return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
class QueueMessage(BaseModel):
"""
QueueMessage abstract entity
QueueMessage entity
"""
task_id: str
app_mode: str
@ -386,52 +297,3 @@ class WorkflowQueueMessage(QueueMessage):
WorkflowQueueMessage entity
"""
pass
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
"""
QueueParallelBranchRunStartedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
"""
QueueParallelBranchRunSucceededEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
"""
QueueParallelBranchRunFailedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
error: str
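A quick construction sketch for the parallel-branch queue events shown above (the IDs and error text are placeholders); only the two IDs and the error are required, since the parent/iteration fields default to None and the event type is preset:

failed_event = QueueParallelBranchRunFailedEvent(
    parallel_id='parallel-1',
    parallel_start_node_id='node-a',
    error='branch timed out',
)
assert failed_event.event == QueueEvent.PARALLEL_BRANCH_RUN_FAILED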

View File

@ -3,11 +3,40 @@ from typing import Any, Optional
from pydantic import BaseModel, ConfigDict
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import GenerateRouteChunk
from models.workflow import WorkflowNodeExecutionStatus
class WorkflowStreamGenerateNodes(BaseModel):
"""
WorkflowStreamGenerateNodes entity
"""
end_node_id: str
stream_node_ids: list[str]
class ChatflowStreamGenerateRoute(BaseModel):
"""
ChatflowStreamGenerateRoute entity
"""
answer_node_id: str
generate_route: list[GenerateRouteChunk]
current_route_position: int = 0
class NodeExecutionInfo(BaseModel):
"""
NodeExecutionInfo entity
"""
workflow_node_execution_id: str
node_type: NodeType
start_at: float
class TaskState(BaseModel):
"""
TaskState entity
@ -28,6 +57,27 @@ class WorkflowTaskState(TaskState):
"""
answer: str = ""
workflow_run_id: Optional[str] = None
start_at: Optional[float] = None
total_tokens: int = 0
total_steps: int = 0
ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
latest_node_execution_info: Optional[NodeExecutionInfo] = None
current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None
iteration_nested_node_ids: list[str] = None
class AdvancedChatTaskState(WorkflowTaskState):
"""
AdvancedChatTaskState entity
"""
usage: LLMUsage
current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None
class StreamEvent(Enum):
"""
@ -47,8 +97,6 @@ class StreamEvent(Enum):
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
@ -219,10 +267,6 @@ class NodeStartStreamResponse(StreamResponse):
inputs: Optional[dict] = None
created_at: int
extras: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str
@ -242,11 +286,7 @@ class NodeStartStreamResponse(StreamResponse):
"predecessor_node_id": self.data.predecessor_node_id,
"inputs": None,
"created_at": self.data.created_at,
"extras": {},
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"extras": {}
}
}
@ -276,10 +316,6 @@ class NodeFinishStreamResponse(StreamResponse):
created_at: int
finished_at: int
files: Optional[list[dict]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.NODE_FINISHED
workflow_run_id: str
@ -306,57 +342,9 @@ class NodeFinishStreamResponse(StreamResponse):
"execution_metadata": None,
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"files": []
}
}
class ParallelBranchStartStreamResponse(StreamResponse):
"""
ParallelBranchStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
created_at: int
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
workflow_run_id: str
data: Data
class ParallelBranchFinishedStreamResponse(StreamResponse):
"""
ParallelBranchFinishedStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
status: str
error: Optional[str] = None
created_at: int
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
workflow_run_id: str
data: Data
class IterationNodeStartStreamResponse(StreamResponse):
@ -376,8 +364,6 @@ class IterationNodeStartStreamResponse(StreamResponse):
extras: dict = {}
metadata: dict = {}
inputs: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_STARTED
workflow_run_id: str
@ -401,8 +387,6 @@ class IterationNodeNextStreamResponse(StreamResponse):
created_at: int
pre_iteration_output: Optional[Any] = None
extras: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
@ -424,8 +408,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
title: str
outputs: Optional[dict] = None
created_at: int
extras: Optional[dict] = None
inputs: Optional[dict] = None
extras: dict = None
inputs: dict = None
status: WorkflowNodeExecutionStatus
error: Optional[str] = None
elapsed_time: float
@ -433,8 +417,6 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
execution_metadata: Optional[dict] = None
finished_at: int
steps: int
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
workflow_run_id: str
@ -506,7 +488,7 @@ class WorkflowAppStreamResponse(AppStreamResponse):
"""
WorkflowAppStreamResponse entity
"""
workflow_run_id: Optional[str] = None
workflow_run_id: str
class AppBlockingResponse(BaseModel):
@ -580,3 +562,25 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
workflow_run_id: str
data: Data
class WorkflowIterationState(BaseModel):
"""
WorkflowIterationState entity
"""
class Data(BaseModel):
"""
Data entity
"""
parent_iteration_id: Optional[str] = None
iteration_id: str
current_index: int
iteration_steps_boundary: list[int] = None
node_execution_id: str
started_at: float
inputs: dict = None
total_tokens: int = 0
node_data: BaseNodeData
current_iterations: dict[str, Data] = None
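A brief construction sketch for the tracking entities shown above (placeholder IDs; NodeType.LLM is assumed to be a valid enum member):

import time

stream_nodes = WorkflowStreamGenerateNodes(
    end_node_id='end-1',
    stream_node_ids=['llm-1', 'answer-1'],
)
execution_info = NodeExecutionInfo(
    workflow_node_execution_id='exec-123',
    node_type=NodeType.LLM,  # assumed member; any NodeType value fits the field
    start_at=time.perf_counter(),
)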

View File

@ -68,18 +68,16 @@ class BasedGenerateTaskPipeline:
err = Exception(e.description if getattr(e, 'description', None) is not None else str(e))
if message:
refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
message = db.session.query(Message).filter(Message.id == message.id).first()
err_desc = self._error_to_desc(err)
message.status = 'error'
message.error = err_desc
if refetch_message:
err_desc = self._error_to_desc(err)
refetch_message.status = 'error'
refetch_message.error = err_desc
db.session.commit()
db.session.commit()
return err
def _error_to_desc(self, e: Exception) -> str:
def _error_to_desc(cls, e: Exception) -> str:
"""
Error to desc.
:param e: exception
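One side of this hunk uses a refetch-then-update pattern, presumably so the error status is written to a Message instance bound to the current session and a concurrently deleted row is tolerated; the pattern in isolation:

# Sketch of the refetch-then-update pattern from the hunk above.
refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
if refetch_message:
    refetch_message.status = 'error'
    refetch_message.error = self._error_to_desc(err)
    db.session.commit()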

View File

@ -8,6 +8,7 @@ from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
@ -15,11 +16,11 @@ from core.app.entities.queue_entities import (
QueueRetrieverResourcesEvent,
)
from core.app.entities.task_entities import (
AdvancedChatTaskState,
EasyUITaskState,
MessageFileStreamResponse,
MessageReplaceStreamResponse,
MessageStreamResponse,
WorkflowTaskState,
)
from core.llm_generator.llm_generator import LLMGenerator
from core.tools.tool_file_manager import ToolFileManager
@ -35,7 +36,7 @@ class MessageCycleManage:
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
]
_task_state: Union[EasyUITaskState, WorkflowTaskState]
_task_state: Union[EasyUITaskState, AdvancedChatTaskState]
def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]:
"""
@ -44,9 +45,6 @@ class MessageCycleManage:
:param query: query
:return: thread
"""
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
return None
is_first_message = self._application_generate_entity.conversation_id is None
extras = self._application_generate_entity.extras
auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True)
@ -54,7 +52,7 @@ class MessageCycleManage:
if auto_generate_conversation_name and is_first_message:
# start generate thread
thread = Thread(target=self._generate_conversation_name_worker, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore
'flask_app': current_app._get_current_object(),
'conversation_id': conversation.id,
'query': query
})
@ -77,9 +75,6 @@ class MessageCycleManage:
.first()
)
if not conversation:
return
if conversation.mode != AppMode.COMPLETION.value:
app_model = conversation.app
if not app_model:
@ -126,13 +121,34 @@ class MessageCycleManage:
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata['retriever_resources'] = event.retriever_resources
def _get_response_metadata(self) -> dict:
"""
Get response metadata by invoke from.
:return:
"""
metadata = {}
# show_retrieve_source
if 'retriever_resources' in self._task_state.metadata:
metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
# show annotation reply
if 'annotation_reply' in self._task_state.metadata:
metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']
# show usage
if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
metadata['usage'] = self._task_state.metadata['usage']
return metadata
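For a DEBUGGER or SERVICE_API call with every branch populated, the returned metadata would take roughly this shape (illustrative values only):

metadata = {
    'retriever_resources': [],  # filled when show_retrieve_source is enabled
    'annotation_reply': {},     # filled when an annotation reply fired
    'usage': {},                # added only for DEBUGGER / SERVICE_API calls
}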
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
"""
Message file to stream response.
:param event: event
:return:
"""
message_file = (
message_file: MessageFile = (
db.session.query(MessageFile)
.filter(MessageFile.id == event.message_file_id)
.first()

View File

@ -1,41 +1,33 @@
import json
import time
from datetime import datetime, timezone
from typing import Any, Optional, Union, cast
from typing import Optional, Union, cast
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeExecutionInfo,
NodeFinishStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
@ -49,56 +41,54 @@ from models.workflow import (
WorkflowRunStatus,
WorkflowRunTriggeredFrom,
)
from services.workflow_service import WorkflowService
class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariableKey, Any]
def _handle_workflow_run_start(self) -> WorkflowRun:
max_sequence = (
db.session.query(db.func.max(WorkflowRun.sequence_number))
.filter(WorkflowRun.tenant_id == self._workflow.tenant_id)
.filter(WorkflowRun.app_id == self._workflow.app_id)
.scalar()
or 0
)
class WorkflowCycleManage(WorkflowIterationCycleManage):
def _init_workflow_run(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],
user_inputs: dict,
system_inputs: Optional[dict] = None) -> WorkflowRun:
"""
Init workflow run
:param workflow: Workflow instance
:param triggered_from: triggered from
:param user: account or end user
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:return:
"""
max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
.filter(WorkflowRun.tenant_id == workflow.tenant_id) \
.filter(WorkflowRun.app_id == workflow.app_id) \
.scalar() or 0
new_sequence_number = max_sequence + 1
inputs = {**self._application_generate_entity.inputs}
for key, value in (self._workflow_system_variables or {}).items():
inputs = {**user_inputs}
for key, value in (system_inputs or {}).items():
if key.value == 'conversation':
continue
inputs[f'sys.{key.value}'] = value
inputs = WorkflowEntry.handle_special_values(inputs)
triggered_from = (
WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN
)
inputs = WorkflowEngineManager.handle_special_values(inputs)
# init workflow run
workflow_run = WorkflowRun()
workflow_run.tenant_id = self._workflow.tenant_id
workflow_run.app_id = self._workflow.app_id
workflow_run.sequence_number = new_sequence_number
workflow_run.workflow_id = self._workflow.id
workflow_run.type = self._workflow.type
workflow_run.triggered_from = triggered_from.value
workflow_run.version = self._workflow.version
workflow_run.graph = self._workflow.graph
workflow_run.inputs = json.dumps(inputs)
workflow_run.status = WorkflowRunStatus.RUNNING.value
workflow_run.created_by_role = (
CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value
workflow_run = WorkflowRun(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
sequence_number=new_sequence_number,
workflow_id=workflow.id,
type=workflow.type,
triggered_from=triggered_from.value,
version=workflow.version,
graph=workflow.graph,
inputs=json.dumps(inputs),
status=WorkflowRunStatus.RUNNING.value,
created_by_role=(CreatedByRole.ACCOUNT.value
if isinstance(user, Account) else CreatedByRole.END_USER.value),
created_by=user.id
)
workflow_run.created_by = self._user.id
db.session.add(workflow_run)
db.session.commit()
@ -107,37 +97,33 @@ class WorkflowCycleManage:
return workflow_run
def _handle_workflow_run_success(
self,
workflow_run: WorkflowRun,
start_at: float,
def _workflow_run_success(
self, workflow_run: WorkflowRun,
total_tokens: int,
total_steps: int,
outputs: Optional[str] = None,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> WorkflowRun:
"""
Workflow run success
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:param conversation_id: conversation id
:return:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
workflow_run.outputs = outputs
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
db.session.refresh(workflow_run)
db.session.close()
if trace_manager:
trace_manager.add_trace_task(
@ -149,58 +135,34 @@ class WorkflowCycleManage:
)
)
db.session.close()
return workflow_run
def _handle_workflow_run_failed(
self,
workflow_run: WorkflowRun,
start_at: float,
def _workflow_run_failed(
self, workflow_run: WorkflowRun,
total_tokens: int,
total_steps: int,
status: WorkflowRunStatus,
error: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> WorkflowRun:
"""
Workflow run failed
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param status: status
:param error: error message
:return:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
workflow_run.status = status.value
workflow_run.error = error
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value
).all()
for workflow_node_execution in running_workflow_node_executions:
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds()
db.session.commit()
db.session.refresh(workflow_run)
db.session.close()
@ -216,24 +178,39 @@ class WorkflowCycleManage:
return workflow_run
def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun,
node_id: str,
node_type: NodeType,
node_title: str,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution:
"""
Init workflow node execution from workflow run
:param workflow_run: workflow run
:param node_id: node id
:param node_type: node type
:param node_title: node title
:param node_run_index: run index
:param predecessor_node_id: predecessor node id if exists
:return:
"""
# init workflow node execution
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
workflow_node_execution.index = event.node_run_index
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution = WorkflowNodeExecution(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
workflow_run_id=workflow_run.id,
predecessor_node_id=predecessor_node_id,
index=node_run_index,
node_id=node_id,
node_type=node_type.value,
title=node_title,
status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by,
created_at=datetime.now(timezone.utc).replace(tzinfo=None)
)
db.session.add(workflow_node_execution)
db.session.commit()
@ -242,26 +219,33 @@ class WorkflowCycleManage:
return workflow_node_execution
def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
start_at: float,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution:
"""
Workflow node execution success
:param event: queue node succeeded event
:param workflow_node_execution: workflow node execution
:param start_at: start time
:param inputs: inputs
:param process_data: process data
:param outputs: outputs
:param execution_metadata: execution metadata
:return:
"""
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
inputs = WorkflowEngineManager.handle_special_values(inputs)
outputs = WorkflowEngineManager.handle_special_values(outputs)
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
if execution_metadata else None
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
db.session.commit()
db.session.refresh(workflow_node_execution)
@ -269,24 +253,33 @@ class WorkflowCycleManage:
return workflow_node_execution
def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
start_at: float,
error: str,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
:param workflow_node_execution: workflow node execution
:param start_at: start time
:param error: error message
:return:
"""
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
inputs = WorkflowEngineManager.handle_special_values(inputs)
outputs = WorkflowEngineManager.handle_special_values(outputs)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = event.error
workflow_node_execution.error = error
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
if execution_metadata else None
db.session.commit()
db.session.refresh(workflow_node_execution)
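This hunk shows two elapsed-time computations side by side: a perf_counter delta and a timestamp difference; the timestamp variant in isolation:

from datetime import datetime, timezone

# Both timestamps are naive UTC datetimes, so the subtraction is the
# wall-clock duration of the node execution.
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()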
@ -294,13 +287,8 @@ class WorkflowCycleManage:
return workflow_node_execution
#################################################
# to stream responses #
#################################################
def _workflow_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun
) -> WorkflowStartStreamResponse:
def _workflow_start_to_stream_response(self, task_id: str,
workflow_run: WorkflowRun) -> WorkflowStartStreamResponse:
"""
Workflow start to stream response.
:param task_id: task id
@ -314,14 +302,13 @@ class WorkflowCycleManage:
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
inputs=workflow_run.inputs_dict or {},
created_at=int(workflow_run.created_at.timestamp()),
),
inputs=workflow_run.inputs_dict,
created_at=int(workflow_run.created_at.timestamp())
)
)
def _workflow_finish_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun
) -> WorkflowFinishStreamResponse:
def _workflow_finish_to_stream_response(self, task_id: str,
workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse:
"""
Workflow finish to stream response.
:param task_id: task id
@ -333,16 +320,16 @@ class WorkflowCycleManage:
created_by_account = workflow_run.created_by_account
if created_by_account:
created_by = {
'id': created_by_account.id,
'name': created_by_account.name,
'email': created_by_account.email,
"id": created_by_account.id,
"name": created_by_account.name,
"email": created_by_account.email,
}
else:
created_by_end_user = workflow_run.created_by_end_user
if created_by_end_user:
created_by = {
'id': created_by_end_user.id,
'user': created_by_end_user.session_id,
"id": created_by_end_user.id,
"user": created_by_end_user.session_id,
}
return WorkflowFinishStreamResponse(
@ -361,13 +348,14 @@ class WorkflowCycleManage:
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}),
),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict)
)
)
def _workflow_node_start_to_stream_response(
self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution
) -> Optional[NodeStartStreamResponse]:
def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution) \
-> NodeStartStreamResponse:
"""
Workflow node start to stream response.
:param event: queue node started event
@ -375,9 +363,6 @@ class WorkflowCycleManage:
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
return None
response = NodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
@ -389,12 +374,8 @@ class WorkflowCycleManage:
index=workflow_node_execution.index,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
created_at=int(workflow_node_execution.created_at.timestamp()),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
),
created_at=int(workflow_node_execution.created_at.timestamp())
)
)
# extras logic
@ -403,27 +384,19 @@ class WorkflowCycleManage:
response.data.extras['icon'] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=node_data.provider_type,
provider_id=node_data.provider_id,
provider_id=node_data.provider_id
)
return response
def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution
) -> Optional[NodeFinishStreamResponse]:
def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \
-> NodeFinishStreamResponse:
"""
Workflow node finish to stream response.
:param event: queue node succeeded or failed event
:param task_id: task id
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
return None
return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
@ -443,154 +416,181 @@ class WorkflowCycleManage:
execution_metadata=workflow_node_execution.execution_metadata_dict,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
),
)
def _workflow_parallel_branch_start_to_stream_response(
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
"""
Workflow parallel branch start to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: parallel branch run started event
:return:
"""
return ParallelBranchStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=ParallelBranchStartStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
created_at=int(time.time()),
)
)
def _workflow_parallel_branch_finished_to_stream_response(
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent
) -> ParallelBranchFinishedStreamResponse:
"""
Workflow parallel branch finished to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: parallel branch run succeeded or failed event
:return:
"""
return ParallelBranchFinishedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=ParallelBranchFinishedStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed',
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
created_at=int(time.time()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict)
)
)
def _workflow_iteration_start_to_stream_response(
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse:
"""
Workflow iteration start to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration start event
:return:
"""
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
def _handle_workflow_start(self) -> WorkflowRun:
self._task_state.start_at = time.perf_counter()
workflow_run = self._init_workflow_run(
workflow=self._workflow,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN,
user=self._user,
user_inputs=self._application_generate_entity.inputs,
system_inputs=self._workflow_system_variables
)
def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse:
"""
Workflow iteration next to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration next event
:return:
"""
return IterationNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
self._task_state.workflow_run_id = workflow_run.id
db.session.close()
return workflow_run
def _handle_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
workflow_node_execution = self._init_node_execution_from_workflow_run(
workflow_run=workflow_run,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_data.title,
node_run_index=event.node_run_index,
predecessor_node_id=event.predecessor_node_id
)
def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse:
"""
Workflow iteration completed to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration completed event
:return:
"""
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=event.node_type,
start_at=time.perf_counter()
)
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
self._task_state.total_steps += 1
db.session.close()
return workflow_node_execution
def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None
if self._iteration_state and self._iteration_state.current_iterations:
if not execution_metadata:
execution_metadata = {}
current_iteration_data = None
for iteration_node_id in self._iteration_state.current_iterations:
data = self._iteration_state.current_iterations[iteration_node_id]
if data.parent_iteration_id is None:
current_iteration_data = data
break
if current_iteration_data:
execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id
execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index
if isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_node_execution_success(
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
inputs=event.inputs,
process_data=event.process_data,
outputs=event.outputs,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0,
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
execution_metadata=execution_metadata
)
)
if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
self._task_state.total_tokens += (
int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
if self._iteration_state:
for iteration_node_id in self._iteration_state.current_iterations:
data = self._iteration_state.current_iterations[iteration_node_id]
if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))
if workflow_node_execution.node_type == NodeType.LLM.value:
outputs = workflow_node_execution.outputs_dict
usage_dict = outputs.get('usage', {})
self._task_state.metadata['usage'] = usage_dict
else:
workflow_node_execution = self._workflow_node_execution_failed(
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
error=event.error,
inputs=event.inputs,
process_data=event.process_data,
outputs=event.outputs,
execution_metadata=execution_metadata
)
db.session.close()
return workflow_node_execution
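A condensed sketch of the iteration bookkeeping in _handle_node_finished above: when a node finishes while an iteration is open, its execution metadata is stamped with the owning iteration's id and index. The dict shapes and key names here are stand-ins for the real entities and NodeRunMetadataKey values.

from typing import Optional

def stamp_iteration(metadata: Optional[dict], current_iterations: dict) -> dict:
    metadata = dict(metadata or {})
    # pick the outermost open iteration (no parent), as the loop above does
    current = next(
        (d for d in current_iterations.values() if d["parent_iteration_id"] is None),
        None,
    )
    if current:
        metadata["iteration_id"] = current["iteration_id"]
        metadata["iteration_index"] = current["current_index"]
    return metadata

print(stamp_iteration(None, {"n1": {"parent_iteration_id": None, "iteration_id": "n1", "current_index": 2}}))
# {'iteration_id': 'n1', 'iteration_index': 2}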
def _handle_workflow_finished(
self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> Optional[WorkflowRun]:
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == self._task_state.workflow_run_id).first()
if not workflow_run:
return None
if conversation_id is None:
conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id')
if isinstance(event, QueueStopEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
status=WorkflowRunStatus.STOPPED,
error='Workflow stopped.',
conversation_id=conversation_id,
trace_manager=trace_manager
)
latest_node_execution_info = self._task_state.latest_node_execution_info
if latest_node_execution_info:
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first()
if (workflow_node_execution
and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value):
self._workflow_node_execution_failed(
workflow_node_execution=workflow_node_execution,
start_at=latest_node_execution_info.start_at,
error='Workflow stopped.'
)
elif isinstance(event, QueueWorkflowFailedEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
status=WorkflowRunStatus.FAILED,
error=event.error,
conversation_id=conversation_id,
trace_manager=trace_manager
)
else:
if self._task_state.latest_node_execution_info:
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first()
outputs = workflow_node_execution.outputs
else:
outputs = None
workflow_run = self._workflow_run_success(
workflow_run=workflow_run,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
outputs=outputs,
conversation_id=conversation_id,
trace_manager=trace_manager
)
self._task_state.workflow_run_id = workflow_run.id
db.session.close()
return workflow_run
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
"""
@ -647,40 +647,3 @@ class WorkflowCycleManage:
return value.to_dict()
return None
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run
:param workflow_run_id: workflow run id
:return:
"""
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == workflow_run_id).first()
if not workflow_run:
raise Exception(f'Workflow run not found: {workflow_run_id}')
return workflow_run
def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
"""
Refetch workflow node execution
:param node_execution_id: workflow node execution id
:return:
"""
workflow_node_execution = (
db.session.query(WorkflowNodeExecution)
.filter(
WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id,
WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id,
WorkflowNodeExecution.workflow_id == self._workflow.id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.node_execution_id == node_execution_id,
)
.first()
)
if not workflow_node_execution:
raise Exception(f'Workflow node execution not found: {node_execution_id}')
return workflow_node_execution

View File

@ -0,0 +1,16 @@
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
from core.workflow.enums import SystemVariableKey
from models.account import Account
from models.model import EndUser
from models.workflow import Workflow
class WorkflowCycleStateManager:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariableKey, Any]

View File

@ -0,0 +1,290 @@
import json
import time
from collections.abc import Generator
from datetime import datetime, timezone
from typing import Optional, Union
from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
)
from core.app.entities.task_entities import (
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeExecutionInfo,
WorkflowIterationState,
)
from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.workflow import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
)
class WorkflowIterationCycleManage(WorkflowCycleStateManager):
_iteration_state: Optional[WorkflowIterationState] = None
def _init_iteration_state(self) -> WorkflowIterationState:
    if not self._iteration_state:
        self._iteration_state = WorkflowIterationState(
            current_iterations={}
        )
    return self._iteration_state
def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \
-> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]:
"""
Handle iteration to stream response
:param task_id: task id
:param event: iteration event
:return:
"""
if isinstance(event, QueueIterationStartEvent):
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs,
metadata=event.metadata
)
)
elif isinstance(event, QueueIterationNextEvent):
current_iteration = self._iteration_state.current_iterations[event.node_id]
return IterationNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=current_iteration.node_data.title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={}
)
)
elif isinstance(event, QueueIterationCompletedEvent):
current_iteration = self._iteration_state.current_iterations[event.node_id]
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=current_iteration.node_data.title,
outputs=event.outputs,
created_at=int(time.time()),
extras={},
inputs=current_iteration.inputs,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None,
elapsed_time=time.perf_counter() - current_iteration.started_at,
total_tokens=current_iteration.total_tokens,
execution_metadata={
'total_tokens': current_iteration.total_tokens,
},
finished_at=int(time.time()),
steps=current_iteration.current_index
)
)
def _init_iteration_execution_from_workflow_run(self,
workflow_run: WorkflowRun,
node_id: str,
node_type: NodeType,
node_title: str,
node_run_index: int = 1,
inputs: Optional[dict] = None,
predecessor_node_id: Optional[str] = None
) -> WorkflowNodeExecution:
workflow_node_execution = WorkflowNodeExecution(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
workflow_run_id=workflow_run.id,
predecessor_node_id=predecessor_node_id,
index=node_run_index,
node_id=node_id,
node_type=node_type.value,
inputs=json.dumps(inputs) if inputs else None,
title=node_title,
status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by,
execution_metadata=json.dumps({
'started_run_index': node_run_index + 1,
'current_index': 0,
'steps_boundary': [],
}),
created_at=datetime.now(timezone.utc).replace(tzinfo=None)
)
db.session.add(workflow_node_execution)
db.session.commit()
db.session.refresh(workflow_node_execution)
db.session.close()
return workflow_node_execution
def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> Optional[WorkflowNodeExecution]:
if isinstance(event, QueueIterationStartEvent):
return self._handle_iteration_started(event)
elif isinstance(event, QueueIterationNextEvent):
return self._handle_iteration_next(event)
elif isinstance(event, QueueIterationCompletedEvent):
return self._handle_iteration_completed(event)
def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution:
self._init_iteration_state()
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
workflow_node_execution = self._init_iteration_execution_from_workflow_run(
workflow_run=workflow_run,
node_id=event.node_id,
node_type=NodeType.ITERATION,
node_title=event.node_data.title,
node_run_index=event.node_run_index,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id
)
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=NodeType.ITERATION,
start_at=time.perf_counter()
)
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data(
parent_iteration_id=None,
iteration_id=event.node_id,
current_index=0,
iteration_steps_boundary=[],
node_execution_id=workflow_node_execution.id,
started_at=time.perf_counter(),
inputs=event.inputs,
total_tokens=0,
node_data=event.node_data
)
db.session.close()
return workflow_node_execution
def _handle_iteration_next(self, event: QueueIterationNextEvent) -> Optional[WorkflowNodeExecution]:
    if event.node_id not in self._iteration_state.current_iterations:
        return None
current_iteration = self._iteration_state.current_iterations[event.node_id]
current_iteration.current_index = event.index
current_iteration.iteration_steps_boundary.append(event.node_run_index)
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
if original_node_execution_metadata:
original_node_execution_metadata['current_index'] = event.index
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
db.session.commit()
db.session.close()
return workflow_node_execution
def _handle_iteration_completed(self, event: QueueIterationCompletedEvent):
if event.node_id not in self._iteration_state.current_iterations:
return
current_iteration = self._iteration_state.current_iterations[event.node_id]
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.outputs = json.dumps(WorkflowEngineManager.handle_special_values(event.outputs)) if event.outputs else None
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
if original_node_execution_metadata:
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
db.session.commit()
# remove current iteration
self._iteration_state.current_iterations.pop(event.node_id, None)
# set latest node execution info
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=NodeType.ITERATION,
start_at=time.perf_counter()
)
self._task_state.latest_node_execution_info = latest_node_execution_info
db.session.close()
def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]:
"""
Handle iteration exception
"""
if not self._iteration_state or not self._iteration_state.current_iterations:
return
for node_id, current_iteration in self._iteration_state.current_iterations.items():
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
db.session.commit()
db.session.close()
yield IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeCompletedStreamResponse.Data(
id=node_id,
node_id=node_id,
node_type=NodeType.ITERATION.value,
title=current_iteration.node_data.title,
outputs={},
created_at=int(time.time()),
extras={},
inputs=current_iteration.inputs,
status=WorkflowNodeExecutionStatus.FAILED,
error=error,
elapsed_time=time.perf_counter() - current_iteration.started_at,
total_tokens=current_iteration.total_tokens,
execution_metadata={
'total_tokens': current_iteration.total_tokens,
},
finished_at=int(time.time()),
steps=current_iteration.current_index
)
)

View File

@ -65,7 +65,7 @@ class Extensible:
if os.path.exists(builtin_file_path):
with open(builtin_file_path, encoding='utf-8') as f:
position = int(f.read().strip())
position_map[extension_name] = position
position_map[extension_name] = position
if (extension_name + '.py') not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")

View File

@ -16,7 +16,9 @@ from configs import dify_config
from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
@ -253,8 +255,11 @@ class IndexingRunner:
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
tokens = 0
preview_texts = []
total_segments = 0
total_price = 0
currency = 'USD'
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
all_text_docs = []
@ -281,22 +286,54 @@ class IndexingRunner:
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
if indexing_technique == 'high_quality' or embedding_model_instance:
tokens += embedding_model_instance.get_text_embedding_num_tokens(
texts=[self.filter_string(document.page_content)]
)
if doc_form and doc_form == 'qa_model':
model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM
)
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language)
document_qa_list = self.format_split_text(response)
price_info = model_type_instance.get_price(
model=model_instance.model,
credentials=model_instance.credentials,
price_type=PriceType.INPUT,
tokens=total_segments * 2000,
)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(price_info.total_amount),
"currency": price_info.currency,
"qa_preview": document_qa_list,
"preview": preview_texts
}
if embedding_model_instance:
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
embedding_price_info = embedding_model_type_instance.get_price(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
total_price = '{:f}'.format(embedding_price_info.total_amount)
currency = embedding_price_info.currency
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": total_price,
"currency": currency,
"preview": preview_texts
}
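A hedged sketch of the flat-rate arithmetic behind the get_price() calls above, following the pricing yaml convention visible elsewhere in this diff (input price, unit, currency). The formula tokens * unit_price * unit and all numbers are illustrative assumptions, not the actual implementation.

from decimal import Decimal

def estimate_total_price(tokens: int, unit_price: Decimal, unit: Decimal) -> Decimal:
    # 'unit' rescales the list price: unit '0.001' with input '0.0001'
    # gives an effective per-token price of 0.0000001
    return Decimal(tokens) * unit_price * unit

# 50 segments at the assumed 2,000 tokens each
print('{:f}'.format(estimate_total_price(50 * 2000, Decimal('0.0001'), Decimal('0.001'))))  # 0.0100000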

View File

@ -63,39 +63,6 @@ class LLMUsage(ModelUsage):
latency=0.0
)
def plus(self, other: 'LLMUsage') -> 'LLMUsage':
"""
Add two LLMUsage instances together.
:param other: Another LLMUsage instance to add
:return: A new LLMUsage instance with summed values
"""
if self.total_tokens == 0:
return other
else:
return LLMUsage(
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
prompt_unit_price=other.prompt_unit_price,
prompt_price_unit=other.prompt_price_unit,
prompt_price=self.prompt_price + other.prompt_price,
completion_tokens=self.completion_tokens + other.completion_tokens,
completion_unit_price=other.completion_unit_price,
completion_price_unit=other.completion_price_unit,
completion_price=self.completion_price + other.completion_price,
total_tokens=self.total_tokens + other.total_tokens,
total_price=self.total_price + other.total_price,
currency=other.currency,
latency=self.latency + other.latency
)
def __add__(self, other: 'LLMUsage') -> 'LLMUsage':
"""
Overload the + operator to add two LLMUsage instances.
:param other: Another LLMUsage instance to add
:return: A new LLMUsage instance with summed values
"""
return self.plus(other)
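For reference, a minimal standalone sketch of the accumulation pattern the removed plus()/__add__ helpers implemented; Usage is a simplified stand-in for the real LLMUsage model (prices, unit fields and currency omitted):

from dataclasses import dataclass

@dataclass
class Usage:
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

    def plus(self, other: "Usage") -> "Usage":
        if self.total_tokens == 0:
            return other  # an empty accumulator adopts the other side wholesale
        return Usage(
            prompt_tokens=self.prompt_tokens + other.prompt_tokens,
            completion_tokens=self.completion_tokens + other.completion_tokens,
            total_tokens=self.total_tokens + other.total_tokens,
        )

    def __add__(self, other: "Usage") -> "Usage":
        return self.plus(other)

total = Usage() + Usage(10, 5, 15) + Usage(3, 2, 5)
print(total.total_tokens)  # 20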
class LLMResult(BaseModel):
"""

View File

@ -150,9 +150,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
except json.JSONDecodeError as e:
raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') from e
if (completion_type is LLMMode.CHAT and json_result.get('object','') == ''):
if (completion_type is LLMMode.CHAT and json_result['object'] == ''):
json_result['object'] = 'chat.completion'
elif (completion_type is LLMMode.COMPLETION and json_result.get('object','') == ''):
elif (completion_type is LLMMode.COMPLETION and json_result['object'] == ''):
json_result['object'] = 'text_completion'
if (completion_type is LLMMode.CHAT

View File

@ -71,24 +71,11 @@ class ArkClientV3:
args = {
"base_url": credentials['api_endpoint_host'],
"region": credentials['volc_region'],
"ak": credentials['volc_access_key_id'],
"sk": credentials['volc_secret_access_key'],
}
if credentials.get("auth_method") == "api_key":
args = {
**args,
"api_key": credentials['volc_api_key'],
}
else:
args = {
**args,
"ak": credentials['volc_access_key_id'],
"sk": credentials['volc_secret_access_key'],
}
if cls.is_compatible_with_legacy(credentials):
args = {
**args,
"base_url": DEFAULT_V3_ENDPOINT
}
args["base_url"] = DEFAULT_V3_ENDPOINT
client = ArkClientV3(
**args
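A minimal sketch of the credential-driven argument assembly above: api-key auth when selected, AK/SK otherwise. Field names follow the yaml in this diff; the Ark client itself is not constructed here.

def build_client_args(credentials: dict) -> dict:
    args = {
        "base_url": credentials["api_endpoint_host"],
        "region": credentials["volc_region"],
    }
    if credentials.get("auth_method") == "api_key":
        args["api_key"] = credentials["volc_api_key"]
    else:
        args["ak"] = credentials["volc_access_key_id"]
        args["sk"] = credentials["volc_secret_access_key"]
    return args

print(build_client_args({
    "api_endpoint_host": "https://ark.example.com",
    "volc_region": "cn-beijing",
    "auth_method": "api_key",
    "volc_api_key": "key-123",
}))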

View File

@ -30,28 +30,8 @@ model_credential_schema:
en_US: Enter your Model Name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: auth_method
required: true
label:
en_US: Authentication Method
zh_Hans: 鉴权方式
type: select
default: aksk
options:
- label:
en_US: API Key
value: api_key
- label:
en_US: Access Key / Secret Access Key
value: aksk
placeholder:
en_US: Enter your Authentication Method
zh_Hans: 选择鉴权方式
- variable: volc_access_key_id
required: true
show_on:
- variable: auth_method
value: aksk
label:
en_US: Access Key
zh_Hans: Access Key
@ -61,9 +41,6 @@ model_credential_schema:
zh_Hans: 输入您的 Access Key
- variable: volc_secret_access_key
required: true
show_on:
- variable: auth_method
value: aksk
label:
en_US: Secret Access Key
zh_Hans: Secret Access Key
@ -71,17 +48,6 @@ model_credential_schema:
placeholder:
en_US: Enter your Secret Access Key
zh_Hans: 输入您的 Secret Access Key
- variable: volc_api_key
required: true
show_on:
- variable: auth_method
value: api_key
label:
en_US: API Key
type: secret-input
placeholder:
en_US: Enter your API Key
zh_Hans: 输入您的 API Key
- variable: volc_region
required: true
label:

View File

@ -174,11 +174,6 @@ class XinferenceText2SpeechModel(TTSModel):
return voices[language]
elif 'all' in voices:
return voices['all']
else:
all_voices = []
for lang, lang_voices in voices.items():
all_voices.extend(lang_voices)
return all_voices
return self.model_voices['__default']['all']
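The removed else-branch above flattened every language's voice list into one; an equivalent single-expression sketch ('voices' maps language code to a list of voice names, values illustrative):

voices = {"en": ["eva", "max"], "zh": ["lin"]}
all_voices = [v for lang_voices in voices.values() for v in lang_voices]
print(all_voices)  # ['eva', 'max', 'lin']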

View File

@ -38,7 +38,7 @@ parameter_rules:
min: 1
max: 8192
pricing:
input: '0'
output: '0'
input: '0.0001'
output: '0.0001'
unit: '0.001'
currency: RMB

View File

@ -37,8 +37,3 @@ parameter_rules:
default: 1024
min: 1
max: 8192
pricing:
input: '0.001'
output: '0.001'
unit: '0.001'
currency: RMB

View File

@ -37,8 +37,3 @@ parameter_rules:
default: 1024
min: 1
max: 8192
pricing:
input: '0.1'
output: '0.1'
unit: '0.001'
currency: RMB

View File

@ -30,9 +30,4 @@ parameter_rules:
use_template: max_tokens
default: 1024
min: 1
max: 8192
pricing:
input: '0.001'
output: '0.001'
unit: '0.001'
currency: RMB
max: 4096

View File

@ -1,44 +0,0 @@
model: glm-4-plus
label:
en_US: glm-4-plus
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
parameter_rules:
- name: temperature
use_template: temperature
default: 0.95
min: 0.0
max: 1.0
help:
zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
- name: top_p
use_template: top_p
default: 0.7
help:
zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
- name: incremental
label:
zh_Hans: 增量返回
en_US: Incremental
type: boolean
help:
zh_Hans: SSE接口调用时用于控制每次返回内容方式是增量还是全量不提供此参数时默认为增量返回true 为增量返回false 为全量返回。
en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return.
required: false
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 8192
pricing:
input: '0.05'
output: '0.05'
unit: '0.001'
currency: RMB

View File

@ -34,9 +34,4 @@ parameter_rules:
use_template: max_tokens
default: 1024
min: 1
max: 1024
pricing:
input: '0.05'
output: '0.05'
unit: '0.001'
currency: RMB
max: 8192

View File

@ -1,42 +0,0 @@
model: glm-4v-plus
label:
en_US: glm-4v-plus
model_type: llm
model_properties:
mode: chat
features:
- vision
parameter_rules:
- name: temperature
use_template: temperature
default: 0.95
min: 0.0
max: 1.0
help:
zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
- name: top_p
use_template: top_p
default: 0.7
help:
zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
- name: incremental
label:
zh_Hans: 增量返回
en_US: Incremental
type: boolean
help:
zh_Hans: SSE接口调用时用于控制每次返回内容方式是增量还是全量不提供此参数时默认为增量返回true 为增量返回false 为全量返回。
en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return.
required: false
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 1024
pricing:
input: '0.01'
output: '0.01'
unit: '0.001'
currency: RMB

View File

@ -153,8 +153,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
extra_model_kwargs = {}
# request to glm-4v-plus with stop words will always response "finish_reason":"network_error"
if stop and model != 'glm-4v-plus':
if stop:
extra_model_kwargs['stop'] = stop
client = ZhipuAI(
@ -175,7 +174,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]:
if isinstance(copy_prompt_message.content, list):
# check if model is 'glm-4v'
if model not in ('glm-4v', 'glm-4v-plus'):
if model != 'glm-4v':
# not support list message
continue
# get image and
@ -208,7 +207,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
else:
new_prompt_messages.append(copy_prompt_message)
if model == 'glm-4v' or model == 'glm-4v-plus':
if model == 'glm-4v':
params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
else:
params = {
@ -305,7 +304,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
return params
def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]:
def _construct_glm_4v_messages(self, prompt_message: Union[str | list[PromptMessageContent]]) -> list[dict]:
if isinstance(prompt_message, str):
return [{'type': 'text', 'text': prompt_message}]

View File

@ -34,13 +34,13 @@ class OutputModeration(BaseModel):
final_output: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
def should_direct_output(self) -> bool:
def should_direct_output(self):
return self.final_output is not None
def get_final_output(self) -> str:
return self.final_output or ""
def get_final_output(self):
return self.final_output
def append_new_token(self, token: str) -> None:
def append_new_token(self, token: str):
self.buffer += token
if not self.thread:

View File

@ -204,7 +204,6 @@ class LangFuseDataTrace(BaseTraceInstance):
node_generation_data = LangfuseGeneration(
name="llm",
trace_id=trace_id,
model=process_data.get("model_name"),
parent_observation_id=node_execution_id,
start_time=created_at,
end_time=finished_at,

View File

@ -139,7 +139,8 @@ class LangSmithDataTrace(BaseTraceInstance):
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
)
node_total_tokens = execution_metadata.get("total_tokens", 0)
metadata = execution_metadata.copy()
metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
metadata.update(
{
"workflow_run_id": trace_info.workflow_run_id,
@ -155,12 +156,6 @@ class LangSmithDataTrace(BaseTraceInstance):
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
if process_data and process_data.get("model_mode") == "chat":
run_type = LangSmithRunType.llm
metadata.update(
{
'ls_provider': process_data.get('model_provider', ''),
'ls_model_name': process_data.get('model_name', ''),
}
)
elif node_type == "knowledge-retrieval":
run_type = LangSmithRunType.retriever
else:

View File

@ -146,7 +146,7 @@ class RetrievalService:
)
if documents:
if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value:
if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
RerankMode.RERANKING_MODEL.value,
reranking_model, None, False)
@ -180,7 +180,7 @@ class RetrievalService:
top_k=top_k
)
if documents:
if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value:
if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
RerankMode.RERANKING_MODEL.value,
reranking_model, None, False)
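The tightened guard on one side of the hunks above only reranks when the config actually names both a model and a provider; a minimal predicate sketch (dict keys follow the code above, the method string is illustrative):

from typing import Optional

def should_rerank(reranking_model: Optional[dict], retrival_method: str) -> bool:
    return bool(
        reranking_model
        and reranking_model.get('reranking_model_name')
        and reranking_model.get('reranking_provider_name')
        and retrival_method == 'semantic_search'
    )

print(should_rerank({'reranking_model_name': 'bge-reranker', 'reranking_provider_name': 'xinference'}, 'semantic_search'))  # True
print(should_rerank({}, 'semantic_search'))  # False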

View File

@ -281,25 +281,20 @@ class NotionExtractor(BaseExtractor):
for table_header_cell_text in tabel_header_cell:
text = table_header_cell_text["text"]["content"]
table_header_cell_texts.append(text)
else:
table_header_cell_texts.append('')
# Initialize Markdown table with headers
markdown_table = "| " + " | ".join(table_header_cell_texts) + " |\n"
markdown_table += "| " + " | ".join(['---'] * len(table_header_cell_texts)) + " |\n"
# Process data to format each row in Markdown table format
# get table columns text and format
results = data["results"]
for i in range(len(results) - 1):
column_texts = []
table_column_cells = data["results"][i + 1]['table_row']['cells']
for j in range(len(table_column_cells)):
if table_column_cells[j]:
for table_column_cell_text in table_column_cells[j]:
tabel_column_cells = data["results"][i + 1]['table_row']['cells']
for j in range(len(tabel_column_cells)):
if tabel_column_cells[j]:
for table_column_cell_text in tabel_column_cells[j]:
column_text = table_column_cell_text["text"]["content"]
column_texts.append(column_text)
# Add row to Markdown table
markdown_table += "| " + " | ".join(column_texts) + " |\n"
result_lines_arr.append(markdown_table)
column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
cur_result_text = "\n".join(column_texts)
result_lines_arr.append(cur_result_text)
if data["next_cursor"] is None:
done = True
break
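One side of the hunk above assembles a Markdown table from header cells and row cells; a minimal standalone sketch of that assembly (plain lists instead of Notion API payloads):

def rows_to_markdown(headers: list[str], rows: list[list[str]]) -> str:
    table = "| " + " | ".join(headers) + " |\n"
    table += "| " + " | ".join(["---"] * len(headers)) + " |\n"  # alignment row
    for row in rows:
        table += "| " + " | ".join(row) + " |\n"
    return table

print(rows_to_markdown(["name", "qty"], [["apple", "3"], ["pear", "5"]]))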

View File

@ -170,8 +170,6 @@ class WordExtractor(BaseExtractor):
if run.element.xpath('.//a:blip'):
for blip in run.element.xpath('.//a:blip'):
image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
if not image_id:
continue
image_part = paragraph.part.rels[image_id].target_part
if image_part in image_map:
@ -258,6 +256,6 @@ class WordExtractor(BaseExtractor):
content.append(parsed_paragraph)
elif isinstance(element.tag, str) and element.tag.endswith('tbl'): # table
table = tables.pop(0)
content.append(self._table_to_markdown(table, image_map))
content.append(self._table_to_markdown(table,image_map))
return '\n'.join(content)

View File

@ -30,14 +30,15 @@ def _split_text_with_regex(
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({re.escape(separator)})", text)
splits = [_splits[i - 1] + _splits[i] for i in range(1, len(_splits), 2)]
if len(_splits) % 2 != 0:
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
if len(_splits) % 2 == 0:
splits += _splits[-1:]
splits = [_splits[0]] + splits
else:
splits = re.split(separator, text)
else:
splits = list(text)
return [s for s in splits if (s != "" and s != '\n')]
return [s for s in splits if s != ""]
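Both branches above rely on re.split with a capturing group so the delimiters survive the split; a minimal sketch of the attach-to-preceding-chunk variant:

import re

def split_keep_separator(text: str, separator: str) -> list[str]:
    parts = re.split(f"({re.escape(separator)})", text)
    # parts alternates [chunk, sep, chunk, sep, ...]; glue each separator
    # onto the chunk before it
    merged = [parts[i - 1] + parts[i] for i in range(1, len(parts), 2)]
    if len(parts) % 2 != 0:
        merged.append(parts[-1])
    return [s for s in merged if s]

print(split_keep_separator("a.b.c", "."))  # ['a.', 'b.', 'c']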
class TextSplitter(BaseDocumentTransformer, ABC):
@ -108,7 +109,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
else:
return text
def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
# We now want to combine these smaller pieces into medium size
# chunks to send to the LLM.
separator_len = self._length_function(separator)
@ -116,9 +117,8 @@ class TextSplitter(BaseDocumentTransformer, ABC):
docs = []
current_doc: list[str] = []
total = 0
index = 0
for d in splits:
_len = lengths[index]
_len = self._length_function(d)
if (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> self._chunk_size
@ -146,7 +146,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
current_doc = current_doc[1:]
current_doc.append(d)
total += _len + (separator_len if len(current_doc) > 1 else 0)
index += 1
doc = self._join_docs(current_doc, separator)
if doc is not None:
docs.append(doc)
@ -495,10 +494,11 @@ class RecursiveCharacterTextSplitter(TextSplitter):
self._separators = separators or ["\n\n", "\n", " ", ""]
def _split_text(self, text: str, separators: list[str]) -> list[str]:
"""Split incoming text and return chunks."""
final_chunks = []
# Get appropriate separator to use
separator = separators[-1]
new_separators = []
for i, _s in enumerate(separators):
if _s == "":
separator = _s
@ -509,31 +509,25 @@ class RecursiveCharacterTextSplitter(TextSplitter):
break
splits = _split_text_with_regex(text, separator, self._keep_separator)
# Now go merging things, recursively splitting longer texts.
_good_splits = []
_good_splits_lengths = [] # cache the lengths of the splits
_separator = "" if self._keep_separator else separator
for s in splits:
s_len = self._length_function(s)
if s_len < self._chunk_size:
if self._length_function(s) < self._chunk_size:
_good_splits.append(s)
_good_splits_lengths.append(s_len)
else:
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
_good_splits = []
_good_splits_lengths = []
if not new_separators:
final_chunks.append(s)
else:
other_info = self._split_text(s, new_separators)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
return final_chunks
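The _good_splits_lengths bookkeeping on one side of the hunks above exists so each split's length is measured exactly once; a compact sketch of that measure-once idea (length_fn stands in for the splitter's _length_function):

def measure_once(splits: list[str], length_fn) -> tuple[list[str], list[int]]:
    lengths = [length_fn(s) for s in splits]  # single pass over the text
    return splits, lengths

splits, lengths = measure_once(["alpha", "beta", "gamma"], len)
print(sum(lengths))  # 14, reused later without re-measuring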
def split_text(self, text: str) -> list[str]:

View File

@ -1,6 +1,5 @@
- google
- bing
- perplexity
- duckduckgo
- searchapi
- serper
@ -11,7 +10,6 @@
- wikipedia
- nominatim
- yahoo
- alphavantage
- arxiv
- pubmed
- stablediffusion

View File

@ -1,7 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="56px" height="56px" viewBox="0 0 56 56" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>形状结合</title>
<g id="设计规范" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<path d="M56,0 L56,56 L0,56 L0,0 L56,0 Z M31.6063018,12 L24.3936982,12 L24.1061064,12.7425499 L12.6071308,42.4324141 L12,44 L19.7849972,44 L20.0648488,43.2391815 L22.5196173,36.5567427 L33.4780427,36.5567427 L35.9351512,43.2391815 L36.2150028,44 L44,44 L43.3928692,42.4324141 L31.8938936,12.7425499 L31.6063018,12 Z M28.0163803,21.5755126 L31.1613993,30.2523823 L24.8432808,30.2523823 L28.0163803,21.5755126 Z" id="形状结合" fill="#2F4F4F"></path>
</g>
</svg>


View File

@ -1,22 +0,0 @@
from typing import Any
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.alphavantage.tools.query_stock import QueryStockTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class AlphaVantageProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
QueryStockTool().fork_tool_runtime(
runtime={
"credentials": credentials,
}
).invoke(
user_id='',
tool_parameters={
"code": "AAPL", # Apple Inc.
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -1,31 +0,0 @@
identity:
author: zhuhao
name: alphavantage
label:
en_US: AlphaVantage
zh_Hans: AlphaVantage
pt_BR: AlphaVantage
description:
en_US: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis.
zh_Hans: AlphaVantage是一个在线平台它提供金融市场数据和API便于个人投资者和开发者获取股票报价、技术指标和股票分析。
pt_BR: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis.
icon: icon.svg
tags:
- finance
credentials_for_provider:
api_key:
type: secret-input
required: true
label:
en_US: AlphaVantage API key
zh_Hans: AlphaVantage API key
pt_BR: AlphaVantage API key
placeholder:
en_US: Please input your AlphaVantage API key
zh_Hans: 请输入你的 AlphaVantage API key
pt_BR: Please input your AlphaVantage API key
help:
en_US: Get your AlphaVantage API key from AlphaVantage
zh_Hans: 从 AlphaVantage 获取您的 AlphaVantage API key
pt_BR: Get your AlphaVantage API key from AlphaVantage
url: https://www.alphavantage.co/support/#api-key

View File

@ -1,49 +0,0 @@
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
ALPHAVANTAGE_API_URL = "https://www.alphavantage.co/query"
class QueryStockTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
stock_code = tool_parameters.get('code', '')
if not stock_code:
return self.create_text_message('Please tell me your stock code')
if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
return self.create_text_message("Alpha Vantage API key is required.")
params = {
"function": "TIME_SERIES_DAILY",
"symbol": stock_code,
"outputsize": "compact",
"datatype": "json",
"apikey": self.runtime.credentials['api_key']
}
response = requests.get(url=ALPHAVANTAGE_API_URL, params=params)
response.raise_for_status()
result = self._handle_response(response.json())
return self.create_json_message(result)
def _handle_response(self, response: dict[str, Any]) -> dict[str, Any]:
result = response.get('Time Series (Daily)', {})
if not result:
return {}
stock_result = {}
for k, v in result.items():
stock_result[k] = {}
stock_result[k]['open'] = v.get('1. open')
stock_result[k]['high'] = v.get('2. high')
stock_result[k]['low'] = v.get('3. low')
stock_result[k]['close'] = v.get('4. close')
stock_result[k]['volume'] = v.get('5. volume')
return stock_result
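For orientation, an illustrative (made-up) Alpha Vantage payload and the shape _handle_response reduces it to:

sample = {
    "Time Series (Daily)": {
        "2024-08-27": {
            "1. open": "228.00", "2. high": "229.86", "3. low": "224.63",
            "4. close": "228.03", "5. volume": "35934559",
        }
    }
}
renamed = {
    day: {
        "open": v.get("1. open"), "high": v.get("2. high"), "low": v.get("3. low"),
        "close": v.get("4. close"), "volume": v.get("5. volume"),
    }
    for day, v in sample["Time Series (Daily)"].items()
}
print(renamed)  # {'2024-08-27': {'open': '228.00', 'high': '229.86', ...}}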

View File

@ -1,27 +0,0 @@
identity:
name: query_stock
author: zhuhao
label:
en_US: query_stock
zh_Hans: query_stock
pt_BR: query_stock
description:
human:
en_US: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol.
zh_Hans: 获取指定股票代码的每日开盘价、每日最高价、每日最低价、每日收盘价和每日交易量等信息。
pt_BR: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol
llm: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol
parameters:
- name: code
type: string
required: true
label:
en_US: stock code
zh_Hans: 股票代码
pt_BR: stock code
human_description:
en_US: stock code
zh_Hans: 股票代码
pt_BR: stock code
llm_description: stock code for query from alphavantage
form: llm

View File

@ -1,3 +0,0 @@
<svg width="400" height="400" viewBox="0 0 400 400" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M101.008 42L190.99 124.905L190.99 124.886L190.99 42.1913H208.506L208.506 125.276L298.891 42V136.524L336 136.524V272.866H299.005V357.035L208.506 277.525L208.506 357.948H190.99L190.99 278.836L101.11 358V272.866H64V136.524H101.008V42ZM177.785 153.826H81.5159V255.564H101.088V223.472L177.785 153.826ZM118.625 231.149V319.392L190.99 255.655L190.99 165.421L118.625 231.149ZM209.01 254.812V165.336L281.396 231.068V272.866H281.489V318.491L209.01 254.812ZM299.005 255.564H318.484V153.826L222.932 153.826L299.005 222.751V255.564ZM281.375 136.524V81.7983L221.977 136.524L281.375 136.524ZM177.921 136.524H118.524V81.7983L177.921 136.524Z" fill="black"/>
</svg>


View File

@ -1,46 +0,0 @@
from typing import Any
import requests
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.perplexity.tools.perplexity_search import PERPLEXITY_API_URL
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class PerplexityProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
headers = {
"Authorization": f"Bearer {credentials.get('perplexity_api_key')}",
"Content-Type": "application/json"
}
payload = {
"model": "llama-3.1-sonar-small-128k-online",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Hello"
}
],
"max_tokens": 5,
"temperature": 0.1,
"top_p": 0.9,
"stream": False
}
try:
response = requests.post(PERPLEXITY_API_URL, json=payload, headers=headers)
response.raise_for_status()
except requests.RequestException as e:
raise ToolProviderCredentialValidationError(
f"Failed to validate Perplexity API key: {str(e)}"
)
if response.status_code != 200:
raise ToolProviderCredentialValidationError(
f"Perplexity API key is invalid. Status code: {response.status_code}"
)

View File

@ -1,26 +0,0 @@
identity:
author: Dify
name: perplexity
label:
en_US: Perplexity
zh_Hans: Perplexity
description:
en_US: Perplexity.AI
zh_Hans: Perplexity.AI
icon: icon.svg
tags:
- search
credentials_for_provider:
perplexity_api_key:
type: secret-input
required: true
label:
en_US: Perplexity API key
zh_Hans: Perplexity API key
placeholder:
en_US: Please input your Perplexity API key
zh_Hans: 请输入你的 Perplexity API key
help:
en_US: Get your Perplexity API key from Perplexity
zh_Hans: 从 Perplexity 获取您的 Perplexity API key
url: https://www.perplexity.ai/settings/api

View File

@ -1,72 +0,0 @@
import json
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
PERPLEXITY_API_URL = "https://api.perplexity.ai/chat/completions"
class PerplexityAITool(BuiltinTool):
def _parse_response(self, response: dict) -> dict:
"""Parse the response from Perplexity AI API"""
if 'choices' in response and len(response['choices']) > 0:
message = response['choices'][0]['message']
return {
'content': message.get('content', ''),
'role': message.get('role', ''),
'citations': response.get('citations', [])
}
else:
return {'content': 'Unable to get a valid response', 'role': 'assistant', 'citations': []}
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
headers = {
"Authorization": f"Bearer {self.runtime.credentials['perplexity_api_key']}",
"Content-Type": "application/json"
}
payload = {
"model": tool_parameters.get('model', 'llama-3.1-sonar-small-128k-online'),
"messages": [
{
"role": "system",
"content": "Be precise and concise."
},
{
"role": "user",
"content": tool_parameters['query']
}
],
"max_tokens": tool_parameters.get('max_tokens', 4096),
"temperature": tool_parameters.get('temperature', 0.7),
"top_p": tool_parameters.get('top_p', 1),
"top_k": tool_parameters.get('top_k', 5),
"presence_penalty": tool_parameters.get('presence_penalty', 0),
"frequency_penalty": tool_parameters.get('frequency_penalty', 1),
"stream": False
}
if 'search_recency_filter' in tool_parameters:
payload['search_recency_filter'] = tool_parameters['search_recency_filter']
if 'return_citations' in tool_parameters:
payload['return_citations'] = tool_parameters['return_citations']
if 'search_domain_filter' in tool_parameters:
if isinstance(tool_parameters['search_domain_filter'], str):
payload['search_domain_filter'] = [tool_parameters['search_domain_filter']]
elif isinstance(tool_parameters['search_domain_filter'], list):
payload['search_domain_filter'] = tool_parameters['search_domain_filter']
response = requests.post(url=PERPLEXITY_API_URL, json=payload, headers=headers)
response.raise_for_status()
valuable_res = self._parse_response(response.json())
return [
self.create_json_message(valuable_res),
self.create_text_message(json.dumps(valuable_res, ensure_ascii=False, indent=2))
]

View File

@ -1,178 +0,0 @@
identity:
name: perplexity
author: Dify
label:
en_US: Perplexity Search
description:
human:
en_US: Search information using Perplexity AI's language models.
llm: This tool is used to search information using Perplexity AI's language models.
parameters:
- name: query
type: string
required: true
label:
en_US: Query
zh_Hans: 查询
human_description:
en_US: The text query to be processed by the AI model.
zh_Hans: 要由 AI 模型处理的文本查询。
form: llm
- name: model
type: select
required: false
label:
en_US: Model Name
zh_Hans: 模型名称
human_description:
en_US: The Perplexity AI model to use for generating the response.
zh_Hans: 用于生成响应的 Perplexity AI 模型。
form: form
default: "llama-3.1-sonar-small-128k-online"
options:
- value: llama-3.1-sonar-small-128k-online
label:
en_US: llama-3.1-sonar-small-128k-online
zh_Hans: llama-3.1-sonar-small-128k-online
- value: llama-3.1-sonar-large-128k-online
label:
en_US: llama-3.1-sonar-large-128k-online
zh_Hans: llama-3.1-sonar-large-128k-online
- value: llama-3.1-sonar-huge-128k-online
label:
en_US: llama-3.1-sonar-huge-128k-online
zh_Hans: llama-3.1-sonar-huge-128k-online
- name: max_tokens
type: number
required: false
label:
en_US: Max Tokens
zh_Hans: 最大令牌数
pt_BR: Máximo de Tokens
human_description:
en_US: The maximum number of tokens to generate in the response.
zh_Hans: 在响应中生成的最大令牌数。
pt_BR: O número máximo de tokens a serem gerados na resposta.
form: form
default: 4096
min: 1
max: 4096
- name: temperature
type: number
required: false
label:
en_US: Temperature
zh_Hans: 温度
pt_BR: Temperatura
human_description:
en_US: Controls randomness in the output. Lower values make the output more focused and deterministic.
zh_Hans: 控制输出的随机性。较低的值使输出更加集中和确定。
form: form
default: 0.7
min: 0
max: 1
- name: top_k
type: number
required: false
label:
en_US: Top K
zh_Hans: 取样数量
human_description:
en_US: The number of top results to consider for response generation.
zh_Hans: 用于生成响应的顶部结果数量。
form: form
default: 5
min: 1
max: 100
- name: top_p
type: number
required: false
label:
en_US: Top P
zh_Hans: Top P
human_description:
en_US: Controls diversity via nucleus sampling.
zh_Hans: 通过核心采样控制多样性。
form: form
default: 1
min: 0.1
max: 1
step: 0.1
- name: presence_penalty
type: number
required: false
label:
en_US: Presence Penalty
zh_Hans: 存在惩罚
human_description:
en_US: Positive values penalize new tokens based on whether they appear in the text so far.
zh_Hans: 正值会根据新词元是否已经出现在文本中来对其进行惩罚。
form: form
default: 0
min: -1.0
max: 1.0
step: 0.1
- name: frequency_penalty
type: number
required: false
label:
en_US: Frequency Penalty
zh_Hans: 频率惩罚
human_description:
en_US: Positive values penalize new tokens based on their existing frequency in the text so far.
zh_Hans: 正值会根据新词元在文本中已经出现的频率来对其进行惩罚。
form: form
default: 1
min: 0.1
max: 1.0
step: 0.1
- name: return_citations
type: boolean
required: false
label:
en_US: Return Citations
zh_Hans: 返回引用
human_description:
en_US: Whether to return citations in the response.
zh_Hans: 是否在响应中返回引用。
form: form
default: true
- name: search_domain_filter
type: string
required: false
label:
en_US: Search Domain Filter
zh_Hans: 搜索域过滤器
human_description:
en_US: Domain to filter the search results.
zh_Hans: 用于过滤搜索结果的域名。
form: form
default: ""
- name: search_recency_filter
type: select
required: false
label:
en_US: Search Recency Filter
zh_Hans: 搜索时间过滤器
human_description:
en_US: Filter for search results based on recency.
zh_Hans: 基于时间筛选搜索结果。
form: form
default: "month"
options:
- value: day
label:
en_US: Day
zh_Hans: 天
- value: week
label:
en_US: Week
zh_Hans: 周
- value: month
label:
en_US: Month
zh_Hans: 月
- value: year
label:
en_US: Year
zh_Hans: 年
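The defaults declared in this spec mirror the .get(...) fallbacks in the tool's invoke code above. A hedged consistency check, assuming PyYAML is installed and the spec is saved as perplexity.yaml:

import yaml

with open("perplexity.yaml") as f:
    spec = yaml.safe_load(f)

defaults = {p["name"]: p.get("default") for p in spec["parameters"]}
assert defaults["max_tokens"] == 4096
assert defaults["temperature"] == 0.7
assert defaults["top_k"] == 5
assert defaults["frequency_penalty"] == 1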

View File

@ -17,8 +17,11 @@ class StepfunTool(BuiltinTool):
"""
invoke tools
"""
base_url = self.runtime.credentials.get('stepfun_base_url', 'https://api.stepfun.com')
base_url = str(URL(base_url) / 'v1')
base_url = self.runtime.credentials.get('stepfun_base_url', None)
if not base_url:
base_url = None
else:
base_url = str(URL(base_url) / 'v1')
client = OpenAI(
api_key=self.runtime.credentials['stepfun_api_key'],
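The net effect of this change: an explicitly configured base URL still gets the v1 path appended, while a missing credential now passes None so the OpenAI SDK falls back to its default endpoint. A sketch of that normalization as a standalone helper (the helper name is invented, and URL is assumed to be yarl.URL as in the original file):

from typing import Optional

from yarl import URL

def normalize_base_url(raw: Optional[str]) -> Optional[str]:
    # None lets the OpenAI client use its built-in default endpoint.
    if not raw:
        return None
    return str(URL(raw) / "v1")

assert normalize_base_url(None) is None
assert normalize_base_url("https://api.stepfun.com") == "https://api.stepfun.com/v1"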

View File

@ -1,7 +1,7 @@
import json
import logging
from copy import deepcopy
from typing import Any, Optional, Union
from typing import Any, Union
from core.file.file_obj import FileTransferMethod, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
@ -18,7 +18,6 @@ class WorkflowTool(Tool):
version: str
workflow_entities: dict[str, Any]
workflow_call_depth: int
thread_pool_id: Optional[str] = None
label: str
@ -58,7 +57,6 @@ class WorkflowTool(Tool):
invoke_from=self.runtime.invoke_from,
stream=False,
call_depth=self.workflow_call_depth + 1,
workflow_thread_pool_id=self.thread_pool_id
)
data = result.get('data', {})

View File

@ -128,7 +128,6 @@ class ToolEngine:
user_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler,
workflow_call_depth: int,
thread_pool_id: Optional[str] = None
) -> list[ToolInvokeMessage]:
"""
Workflow invokes the tool with the given arguments.
@ -142,7 +141,6 @@ class ToolEngine:
if isinstance(tool, WorkflowTool):
tool.workflow_call_depth = workflow_call_depth + 1
tool.thread_pool_id = thread_pool_id
if tool.runtime and tool.runtime.runtime_parameters:
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
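One detail worth noting in the merge above: later keys in a dict literal win, so caller-supplied tool_parameters override the runtime defaults. This is plain Python semantics:

runtime_defaults = {"top_k": 5, "temperature": 0.7}
caller_params = {"temperature": 0.2}
merged = {**runtime_defaults, **caller_params}
assert merged == {"top_k": 5, "temperature": 0.2}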

View File

@ -25,6 +25,7 @@ from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from core.workflow.nodes.tool.entities import ToolEntity
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
@ -248,7 +249,7 @@ class ToolManager:
return tool_entity
@classmethod
def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
"""
get the workflow tool runtime
"""

View File

@ -7,7 +7,6 @@ from core.tools.tool_file_manager import ToolFileManager
logger = logging.getLogger(__name__)
class ToolFileMessageTransformer:
@classmethod
def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],

View File

@ -1,15 +1,116 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
from core.workflow.graph_engine.entities.event import GraphEngineEvent
from core.app.entities.queue_entities import AppQueueEvent
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
class WorkflowCallback(ABC):
@abstractmethod
def on_event(
self,
event: GraphEngineEvent
) -> None:
def on_workflow_run_started(self) -> None:
"""
Published event
Workflow run started
"""
raise NotImplementedError
@abstractmethod
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
raise NotImplementedError
@abstractmethod
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
raise NotImplementedError
@abstractmethod
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
raise NotImplementedError
@abstractmethod
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
raise NotImplementedError
@abstractmethod
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
"""
Workflow node execute failed
"""
raise NotImplementedError
@abstractmethod
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
raise NotImplementedError
@abstractmethod
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: Optional[dict] = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
raise NotImplementedError
@abstractmethod
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any],
) -> None:
"""
Publish iteration next
"""
raise NotImplementedError
@abstractmethod
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
raise NotImplementedError
@abstractmethod
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
raise NotImplementedError
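A hedged sketch of a concrete subclass; the class name is invented and only three hooks are shown. A real implementation must override every abstract method before the class can be instantiated:

import logging

logger = logging.getLogger(__name__)

class LoggingWorkflowCallback(WorkflowCallback):
    def on_workflow_run_started(self) -> None:
        logger.info("workflow run started")

    def on_workflow_run_succeeded(self) -> None:
        logger.info("workflow run succeeded")

    def on_workflow_run_failed(self, error: str) -> None:
        logger.error("workflow run failed: %s", error)

    # ...the node, iteration, and event hooks would be implemented the same way...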

View File

@ -9,7 +9,7 @@ class BaseNodeData(ABC, BaseModel):
desc: Optional[str] = None
class BaseIterationNodeData(BaseNodeData):
start_node_id: Optional[str] = None
start_node_id: str
class BaseIterationState(BaseModel):
iteration_node_id: str

View File

@ -1,9 +1,9 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMUsage
from models import WorkflowNodeExecutionStatus
@ -28,7 +28,6 @@ class NodeType(Enum):
VARIABLE_ASSIGNER = 'variable-assigner'
LOOP = 'loop'
ITERATION = 'iteration'
ITERATION_START = 'iteration-start' # fake start node for iteration
PARAMETER_EXTRACTOR = 'parameter-extractor'
CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
@ -57,10 +56,6 @@ class NodeRunMetadataKey(Enum):
TOOL_INFO = 'tool_info'
ITERATION_ID = 'iteration_id'
ITERATION_INDEX = 'iteration_index'
PARALLEL_ID = 'parallel_id'
PARALLEL_START_NODE_ID = 'parallel_start_node_id'
PARENT_PARALLEL_ID = 'parent_parallel_id'
PARENT_PARALLEL_START_NODE_ID = 'parent_parallel_start_node_id'
class NodeRunResult(BaseModel):
@ -70,32 +65,11 @@ class NodeRunResult(BaseModel):
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[dict[str, Any]] = None # node inputs
process_data: Optional[dict[str, Any]] = None # process data
outputs: Optional[dict[str, Any]] = None # node outputs
inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[dict] = None # process data
outputs: Optional[Mapping[str, Any]] = None # node outputs
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
llm_usage: Optional[LLMUsage] = None # llm usage
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
error: Optional[str] = None # error message if status is failed
class UserFrom(Enum):
"""
User from
"""
ACCOUNT = "account"
END_USER = "end-user"
@classmethod
def value_of(cls, value: str) -> "UserFrom":
"""
Value of
:param value: value
:return:
"""
for item in cls:
if item.value == value:
return item
raise ValueError(f"Invalid value: {value}")
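For reference, value_of resolves an enum member from its serialized value and raises on anything unknown:

assert UserFrom.value_of("account") is UserFrom.ACCOUNT
assert UserFrom.value_of("end-user") is UserFrom.END_USER

try:
    UserFrom.value_of("robot")
except ValueError as e:
    print(e)  # Invalid value: robot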

View File

@ -2,7 +2,6 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union
from pydantic import BaseModel, Field, model_validator
from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory
@ -17,52 +16,43 @@ ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
class VariablePool(BaseModel):
# Variable dictionary is a dictionary for looking up variables by their selector.
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: dict[str, dict[int, Segment]] = Field(
description='Variables mapping',
default=defaultdict(dict)
)
class VariablePool:
def __init__(
self,
system_variables: Mapping[SystemVariableKey, Any],
user_inputs: Mapping[str, Any],
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable] | None = None,
) -> None:
# system variables
# for example:
# {
# 'query': 'abc',
# 'files': []
# }
# TODO: These user inputs are not used by the pool.
user_inputs: Mapping[str, Any] = Field(
description='User inputs',
)
# Variable dictionary is a dictionary for looking up variables by their selector.
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict)
system_variables: Mapping[SystemVariableKey, Any] = Field(
description='System variables',
)
# TODO: These user inputs are not used by the pool.
self.user_inputs = user_inputs
environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
default_factory=list
)
conversation_variables: Sequence[Variable] | None = None
@model_validator(mode="after")
def val_model_after(self):
"""
Append system variables
:return:
"""
# Add system variables to the variable pool
for key, value in self.system_variables.items():
self.system_variables = system_variables
for key, value in system_variables.items():
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
# Add environment variables to the variable pool
for var in self.environment_variables or []:
for var in environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
# Add conversation variables to the variable pool
for var in self.conversation_variables or []:
for var in conversation_variables or []:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
return self
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""
Adds a variable to the variable pool.
@ -89,7 +79,7 @@ class VariablePool(BaseModel):
v = factory.build_segment(value)
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]][hash_key] = v
self._variable_dictionary[selector[0]][hash_key] = v
def get(self, selector: Sequence[str], /) -> Segment | None:
"""
@ -107,7 +97,7 @@ class VariablePool(BaseModel):
if len(selector) < 2:
raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:]))
value = self.variable_dictionary[selector[0]].get(hash_key)
value = self._variable_dictionary[selector[0]].get(hash_key)
return value
@ -128,7 +118,7 @@ class VariablePool(BaseModel):
if len(selector) < 2:
raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:]))
value = self.variable_dictionary[selector[0]].get(hash_key)
value = self._variable_dictionary[selector[0]].get(hash_key)
return value.to_object() if value else None
def remove(self, selector: Sequence[str], /):
@ -144,19 +134,7 @@ class VariablePool(BaseModel):
if not selector:
return
if len(selector) == 1:
self.variable_dictionary[selector[0]] = {}
self._variable_dictionary[selector[0]] = {}
return
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]].pop(hash_key, None)
def remove_node(self, node_id: str, /):
"""
Remove all variables associated with a given node id.
Args:
node_id (str): The node id to remove.
Returns:
None
"""
self.variable_dictionary.pop(node_id, None)
self._variable_dictionary[selector[0]].pop(hash_key, None)
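A hedged usage sketch of the selector convention (the first element is the node id, the remaining elements are hashed into a second-level key). The node id and variable name are invented, it assumes factory.build_segment wraps a plain string losslessly, and the constructor keywords match both variants shown in this hunk:

pool = VariablePool(
    system_variables={},
    user_inputs={},
    environment_variables=[],
)
pool.add(("node_1", "output", "text"), "hello")

segment = pool.get(("node_1", "output", "text"))
assert segment is not None and segment.to_object() == "hello"

pool.remove(("node_1", "output", "text"))
assert pool.get(("node_1", "output", "text")) is None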

View File

@ -66,7 +66,8 @@ class WorkflowRunState:
self.variable_pool = variable_pool
self.total_tokens = 0
self.workflow_nodes_and_results = []
self.workflow_node_steps = 1
self.workflow_node_runs = []
self.current_iteration_state = None
self.workflow_node_steps = 1
self.workflow_node_runs = []

View File

@ -1,8 +1,10 @@
from core.workflow.nodes.base_node import BaseNode
from core.workflow.entities.node_entities import NodeType
class WorkflowNodeRunFailedError(Exception):
def __init__(self, node_instance: BaseNode, error: str):
self.node_instance = node_instance
def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str):
self.node_id = node_id
self.node_type = node_type
self.node_title = node_title
self.error = error
super().__init__(f"Node {node_instance.node_data.title} run failed: {error}")
super().__init__(f"Node {node_title} run failed: {error}")

View File

@ -1,31 +0,0 @@
from abc import ABC, abstractmethod
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class RunConditionHandler(ABC):
def __init__(self,
init_params: GraphInitParams,
graph: Graph,
condition: RunCondition):
self.init_params = init_params
self.graph = graph
self.condition = condition
@abstractmethod
def check(self,
graph_runtime_state: GraphRuntimeState,
previous_route_node_state: RouteNodeState
) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param previous_route_node_state: previous route node state
:return: bool
"""
raise NotImplementedError

View File

@ -1,28 +0,0 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class BranchIdentifyRunConditionHandler(RunConditionHandler):
def check(self,
graph_runtime_state: GraphRuntimeState,
previous_route_node_state: RouteNodeState) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param previous_route_node_state: previous route node state
:return: bool
"""
if not self.condition.branch_identify:
raise Exception("Branch identify is required")
run_result = previous_route_node_state.node_run_result
if not run_result:
return False
if not run_result.edge_source_handle:
return False
return self.condition.branch_identify == run_result.edge_source_handle

View File

@ -1,32 +0,0 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.utils.condition.processor import ConditionProcessor
class ConditionRunConditionHandlerHandler(RunConditionHandler):
def check(self,
graph_runtime_state: GraphRuntimeState,
previous_route_node_state: RouteNodeState
) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param previous_route_node_state: previous route node state
:return: bool
"""
if not self.condition.conditions:
return True
# process condition
condition_processor = ConditionProcessor()
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=graph_runtime_state.variable_pool,
conditions=self.condition.conditions
)
# Apply the logical operator for the current case
compare_result = all(group_result)
return compare_result
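The early return plus all(group_result) means a branch with no conditions always passes, and otherwise every condition group must hold. all() over an empty list mirrors the early return:

assert all([]) is True             # no conditions: the branch may run
assert all([True, True]) is True   # every group holds
assert all([True, False]) is False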

View File

@ -1,35 +0,0 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.run_condition import RunCondition
class ConditionManager:
@staticmethod
def get_condition_handler(
init_params: GraphInitParams,
graph: Graph,
run_condition: RunCondition
) -> RunConditionHandler:
"""
Get condition handler
:param init_params: init params
:param graph: graph
:param run_condition: run condition
:return: condition handler
"""
if run_condition.type == "branch_identify":
return BranchIdentifyRunConditionHandler(
init_params=init_params,
graph=graph,
condition=run_condition
)
else:
return ConditionRunConditionHandlerHandler(
init_params=init_params,
graph=graph,
condition=run_condition
)

View File

@ -1,163 +0,0 @@
from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel, Field
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class GraphEngineEvent(BaseModel):
pass
###########################################
# Graph Events
###########################################
class BaseGraphEvent(GraphEngineEvent):
pass
class GraphRunStartedEvent(BaseGraphEvent):
pass
class GraphRunSucceededEvent(BaseGraphEvent):
outputs: Optional[dict[str, Any]] = None
"""outputs"""
class GraphRunFailedEvent(BaseGraphEvent):
error: str = Field(..., description="failed reason")
###########################################
# Node Events
###########################################
class BaseNodeEvent(GraphEngineEvent):
id: str = Field(..., description="node execution id")
node_id: str = Field(..., description="node id")
node_type: NodeType = Field(..., description="node type")
node_data: BaseNodeData = Field(..., description="node data")
route_node_state: RouteNodeState = Field(..., description="route node state")
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class NodeRunStartedEvent(BaseNodeEvent):
predecessor_node_id: Optional[str] = None
"""predecessor node id"""
class NodeRunStreamChunkEvent(BaseNodeEvent):
chunk_content: str = Field(..., description="chunk content")
from_variable_selector: Optional[list[str]] = None
"""from variable selector"""
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
retriever_resources: list[dict] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
class NodeRunSucceededEvent(BaseNodeEvent):
pass
class NodeRunFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
###########################################
# Parallel Branch Events
###########################################
class BaseParallelBranchEvent(GraphEngineEvent):
parallel_id: str = Field(..., description="parallel id")
"""parallel id"""
parallel_start_node_id: str = Field(..., description="parallel start node id")
"""parallel start node id"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
pass
class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
pass
class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
error: str = Field(..., description="failed reason")
###########################################
# Iteration Events
###########################################
class BaseIterationEvent(GraphEngineEvent):
iteration_id: str = Field(..., description="iteration node execution id")
iteration_node_id: str = Field(..., description="iteration node id")
iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
iteration_node_data: BaseNodeData = Field(..., description="node data")
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
class IterationRunStartedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
predecessor_node_id: Optional[str] = None
class IterationRunNextEvent(BaseIterationEvent):
index: int = Field(..., description="index")
pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output")
class IterationRunSucceededEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
steps: int = 0
class IterationRunFailedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
steps: int = 0
error: str = Field(..., description="failed reason")
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent
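A hedged sketch of how a consumer of these events might dispatch on them; the handle function is invented and only a few event types are covered:

def handle(event: GraphEngineEvent) -> None:
    if isinstance(event, NodeRunStreamChunkEvent):
        print(event.chunk_content, end="")
    elif isinstance(event, GraphRunSucceededEvent):
        print("outputs:", event.outputs)
    elif isinstance(event, GraphRunFailedEvent):
        print("run failed:", event.error)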

View File

@ -1,672 +0,0 @@
import uuid
from collections.abc import Mapping
from typing import Any, Optional, cast
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
from core.workflow.nodes.end.entities import EndStreamParam
class GraphEdge(BaseModel):
source_node_id: str = Field(..., description="source node id")
target_node_id: str = Field(..., description="target node id")
run_condition: Optional[RunCondition] = None
"""run condition"""
class GraphParallel(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
start_from_node_id: str = Field(..., description="start from node id")
parent_parallel_id: Optional[str] = None
"""parent parallel id"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id"""
end_to_node_id: Optional[str] = None
"""end to node id"""
class Graph(BaseModel):
root_node_id: str = Field(..., description="root node id of the graph")
node_ids: list[str] = Field(default_factory=list, description="graph node ids")
node_id_config_mapping: dict[str, dict] = Field(
default_factory=dict,
description="node configs mapping (node id: node config)"
)
edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict,
description="graph edge mapping (source node id: edges)"
)
reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict,
description="reverse graph edge mapping (target node id: edges)"
)
parallel_mapping: dict[str, GraphParallel] = Field(
default_factory=dict,
description="graph parallel mapping (parallel id: parallel)"
)
node_parallel_mapping: dict[str, str] = Field(
default_factory=dict,
description="graph node parallel mapping (node id: parallel id)"
)
answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(
...,
description="answer stream generate routes"
)
end_stream_param: EndStreamParam = Field(
...,
description="end stream param"
)
@classmethod
def init(cls,
graph_config: Mapping[str, Any],
root_node_id: Optional[str] = None) -> "Graph":
"""
Init graph
:param graph_config: graph config
:param root_node_id: root node id
:return: graph
"""
# edge configs
edge_configs = graph_config.get('edges')
if edge_configs is None:
edge_configs = []
edge_configs = cast(list, edge_configs)
# reorganize edges mapping
edge_mapping: dict[str, list[GraphEdge]] = {}
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
target_edge_ids = set()
for edge_config in edge_configs:
source_node_id = edge_config.get('source')
if not source_node_id:
continue
if source_node_id not in edge_mapping:
edge_mapping[source_node_id] = []
target_node_id = edge_config.get('target')
if not target_node_id:
continue
if target_node_id not in reverse_edge_mapping:
reverse_edge_mapping[target_node_id] = []
# is target node id in source node id edge mapping
if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]):
continue
target_edge_ids.add(target_node_id)
# parse run condition
run_condition = None
if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source':
run_condition = RunCondition(
type='branch_identify',
branch_identify=edge_config.get('sourceHandle')
)
graph_edge = GraphEdge(
source_node_id=source_node_id,
target_node_id=target_node_id,
run_condition=run_condition
)
edge_mapping[source_node_id].append(graph_edge)
reverse_edge_mapping[target_node_id].append(graph_edge)
# node configs
node_configs = graph_config.get('nodes')
if not node_configs:
raise ValueError("Graph must have at least one node")
node_configs = cast(list, node_configs)
# fetch nodes that have no predecessor node
root_node_configs = []
all_node_id_config_mapping: dict[str, dict] = {}
for node_config in node_configs:
node_id = node_config.get('id')
if not node_id:
continue
if node_id not in target_edge_ids:
root_node_configs.append(node_config)
all_node_id_config_mapping[node_id] = node_config
root_node_ids = [node_config.get('id') for node_config in root_node_configs]
# fetch root node
if not root_node_id:
# if no root node id, use the START type node as root node
root_node_id = next((node_config.get("id") for node_config in root_node_configs
if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)
if not root_node_id or root_node_id not in root_node_ids:
raise ValueError(f"Root node id {root_node_id} not found in the graph")
# Check whether it is connected to the previous node
cls._check_connected_to_previous_node(
route=[root_node_id],
edge_mapping=edge_mapping
)
# fetch all node ids from root node
node_ids = [root_node_id]
cls._recursively_add_node_ids(
node_ids=node_ids,
edge_mapping=edge_mapping,
node_id=root_node_id
)
node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}
# init parallel mapping
parallel_mapping: dict[str, GraphParallel] = {}
node_parallel_mapping: dict[str, str] = {}
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=root_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping
)
# Check if it exceeds N layers of parallel
for parallel in parallel_mapping.values():
if parallel.parent_parallel_id:
cls._check_exceed_parallel_limit(
parallel_mapping=parallel_mapping,
level_limit=3,
parent_parallel_id=parallel.parent_parallel_id
)
# init answer stream generate routes
answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping
)
# init end stream param
end_stream_param = EndStreamGeneratorRouter.init(
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
node_parallel_mapping=node_parallel_mapping
)
# init graph
graph = cls(
root_node_id=root_node_id,
node_ids=node_ids,
node_id_config_mapping=node_id_config_mapping,
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
answer_stream_generate_routes=answer_stream_generate_routes,
end_stream_param=end_stream_param
)
return graph
def add_extra_edge(self, source_node_id: str,
target_node_id: str,
run_condition: Optional[RunCondition] = None) -> None:
"""
Add extra edge to the graph
:param source_node_id: source node id
:param target_node_id: target node id
:param run_condition: run condition
"""
if source_node_id not in self.node_ids or target_node_id not in self.node_ids:
return
if source_node_id not in self.edge_mapping:
self.edge_mapping[source_node_id] = []
if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]:
return
graph_edge = GraphEdge(
source_node_id=source_node_id,
target_node_id=target_node_id,
run_condition=run_condition
)
self.edge_mapping[source_node_id].append(graph_edge)
def get_leaf_node_ids(self) -> list[str]:
"""
Get leaf node ids of the graph
:return: leaf node ids
"""
leaf_node_ids = []
for node_id in self.node_ids:
if node_id not in self.edge_mapping:
leaf_node_ids.append(node_id)
elif (len(self.edge_mapping[node_id]) == 1
and self.edge_mapping[node_id][0].target_node_id == self.root_node_id):
leaf_node_ids.append(node_id)
return leaf_node_ids
@classmethod
def _recursively_add_node_ids(cls,
node_ids: list[str],
edge_mapping: dict[str, list[GraphEdge]],
node_id: str) -> None:
"""
Recursively add node ids
:param node_ids: node ids
:param edge_mapping: edge mapping
:param node_id: node id
"""
for graph_edge in edge_mapping.get(node_id, []):
if graph_edge.target_node_id in node_ids:
continue
node_ids.append(graph_edge.target_node_id)
cls._recursively_add_node_ids(
node_ids=node_ids,
edge_mapping=edge_mapping,
node_id=graph_edge.target_node_id
)
@classmethod
def _check_connected_to_previous_node(
cls,
route: list[str],
edge_mapping: dict[str, list[GraphEdge]]
) -> None:
"""
Check whether it is connected to the previous node
"""
last_node_id = route[-1]
for graph_edge in edge_mapping.get(last_node_id, []):
if not graph_edge.target_node_id:
continue
if graph_edge.target_node_id in route:
raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.")
new_route = route[:]
new_route.append(graph_edge.target_node_id)
cls._check_connected_to_previous_node(
route=new_route,
edge_mapping=edge_mapping,
)
@classmethod
def _recursively_add_parallels(
cls,
edge_mapping: dict[str, list[GraphEdge]],
reverse_edge_mapping: dict[str, list[GraphEdge]],
start_node_id: str,
parallel_mapping: dict[str, GraphParallel],
node_parallel_mapping: dict[str, str],
parent_parallel: Optional[GraphParallel] = None
) -> None:
"""
Recursively add parallel ids
:param edge_mapping: edge mapping
:param start_node_id: start from node id
:param parallel_mapping: parallel mapping
:param node_parallel_mapping: node parallel mapping
:param parent_parallel: parent parallel
"""
target_node_edges = edge_mapping.get(start_node_id, [])
parallel = None
if len(target_node_edges) > 1:
# fetch all node ids in current parallels
parallel_branch_node_ids = []
condition_edge_mappings = {}
for graph_edge in target_node_edges:
if graph_edge.run_condition is None:
parallel_branch_node_ids.append(graph_edge.target_node_id)
else:
condition_hash = graph_edge.run_condition.hash
if condition_hash not in condition_edge_mappings:
condition_edge_mappings[condition_hash] = []
condition_edge_mappings[condition_hash].append(graph_edge)
for _, graph_edges in condition_edge_mappings.items():
if len(graph_edges) > 1:
for graph_edge in graph_edges:
parallel_branch_node_ids.append(graph_edge.target_node_id)
# any target node id in node_parallel_mapping
if parallel_branch_node_ids:
parent_parallel_id = parent_parallel.id if parent_parallel else None
parallel = GraphParallel(
start_from_node_id=start_node_id,
parent_parallel_id=parent_parallel.id if parent_parallel else None,
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None
)
parallel_mapping[parallel.id] = parallel
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_branch_node_ids=parallel_branch_node_ids
)
# collect all branches node ids
parallel_node_ids = []
for _, node_ids in in_branch_node_ids.items():
for node_id in node_ids:
in_parent_parallel = True
if parent_parallel_id:
in_parent_parallel = False
for parallel_node_id, parallel_id in node_parallel_mapping.items():
if parallel_id == parent_parallel_id and parallel_node_id == node_id:
in_parent_parallel = True
break
if in_parent_parallel:
parallel_node_ids.append(node_id)
node_parallel_mapping[node_id] = parallel.id
outside_parallel_target_node_ids = set()
for node_id in parallel_node_ids:
if node_id == parallel.start_from_node_id:
continue
node_edges = edge_mapping.get(node_id)
if not node_edges:
continue
if len(node_edges) > 1:
continue
target_node_id = node_edges[0].target_node_id
if target_node_id in parallel_node_ids:
continue
if parent_parallel_id:
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
continue
if (
(node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id)
or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id)
or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
):
outside_parallel_target_node_ids.add(target_node_id)
if len(outside_parallel_target_node_ids) == 1:
if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id:
parallel.end_to_node_id = None
else:
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
for graph_edge in target_node_edges:
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=parallel if parallel else parent_parallel
)
@classmethod
def _check_exceed_parallel_limit(
cls,
parallel_mapping: dict[str, GraphParallel],
level_limit: int,
parent_parallel_id: str,
current_level: int = 1
) -> None:
"""
Check if it exceeds N layers of parallel
"""
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
return
current_level += 1
if current_level > level_limit:
raise ValueError(f"Exceeds {level_limit} layers of parallel")
if parent_parallel.parent_parallel_id:
cls._check_exceed_parallel_limit(
parallel_mapping=parallel_mapping,
level_limit=level_limit,
parent_parallel_id=parent_parallel.parent_parallel_id,
current_level=current_level
)
@classmethod
def _recursively_add_parallel_node_ids(cls,
branch_node_ids: list[str],
edge_mapping: dict[str, list[GraphEdge]],
merge_node_id: str,
start_node_id: str) -> None:
"""
Recursively add node ids
:param branch_node_ids: in branch node ids
:param edge_mapping: edge mapping
:param merge_node_id: merge node id
:param start_node_id: start node id
"""
for graph_edge in edge_mapping.get(start_node_id, []):
if (graph_edge.target_node_id != merge_node_id
and graph_edge.target_node_id not in branch_node_ids):
branch_node_ids.append(graph_edge.target_node_id)
cls._recursively_add_parallel_node_ids(
branch_node_ids=branch_node_ids,
edge_mapping=edge_mapping,
merge_node_id=merge_node_id,
start_node_id=graph_edge.target_node_id
)
@classmethod
def _fetch_all_node_ids_in_parallels(cls,
edge_mapping: dict[str, list[GraphEdge]],
reverse_edge_mapping: dict[str, list[GraphEdge]],
parallel_branch_node_ids: list[str]) -> dict[str, list[str]]:
"""
Fetch all node ids in parallels
"""
routes_node_ids: dict[str, list[str]] = {}
for parallel_branch_node_id in parallel_branch_node_ids:
routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id]
# fetch routes node ids
cls._recursively_fetch_routes(
edge_mapping=edge_mapping,
start_node_id=parallel_branch_node_id,
routes_node_ids=routes_node_ids[parallel_branch_node_id]
)
# fetch leaf node ids from routes node ids
leaf_node_ids: dict[str, list[str]] = {}
merge_branch_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items():
for node_id in node_ids:
if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0:
if branch_node_id not in leaf_node_ids:
leaf_node_ids[branch_node_id] = []
leaf_node_ids[branch_node_id].append(node_id)
for branch_node_id2, inner_route2 in routes_node_ids.items():
if (
branch_node_id != branch_node_id2
and node_id in inner_route2
and len(reverse_edge_mapping.get(node_id, [])) > 1
and cls._is_node_in_routes(
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=node_id,
routes_node_ids=routes_node_ids
)
):
if node_id not in merge_branch_node_ids:
merge_branch_node_ids[node_id] = []
if branch_node_id2 not in merge_branch_node_ids[node_id]:
merge_branch_node_ids[node_id].append(branch_node_id2)
# sorted merge_branch_node_ids by branch_node_ids length desc
merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))
duplicate_end_node_ids = {}
for node_id, branch_node_ids in merge_branch_node_ids.items():
for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids:
duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids
for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
# check which node is after
if cls._is_node2_after_node1(
node1_id=node_id,
node2_id=node_id2,
edge_mapping=edge_mapping
):
if node_id in merge_branch_node_ids:
del merge_branch_node_ids[node_id2]
elif cls._is_node2_after_node1(
node1_id=node_id2,
node2_id=node_id,
edge_mapping=edge_mapping
):
if node_id2 in merge_branch_node_ids:
del merge_branch_node_ids[node_id]
branches_merge_node_ids: dict[str, str] = {}
for node_id, branch_node_ids in merge_branch_node_ids.items():
if len(branch_node_ids) <= 1:
continue
for branch_node_id in branch_node_ids:
if branch_node_id in branches_merge_node_ids:
continue
branches_merge_node_ids[branch_node_id] = node_id
in_branch_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items():
in_branch_node_ids[branch_node_id] = []
if branch_node_id not in branches_merge_node_ids:
# all node ids in current branch is in this thread
in_branch_node_ids[branch_node_id].append(branch_node_id)
in_branch_node_ids[branch_node_id].extend(node_ids)
else:
merge_node_id = branches_merge_node_ids[branch_node_id]
if merge_node_id != branch_node_id:
in_branch_node_ids[branch_node_id].append(branch_node_id)
# fetch all node ids from branch_node_id and merge_node_id
cls._recursively_add_parallel_node_ids(
branch_node_ids=in_branch_node_ids[branch_node_id],
edge_mapping=edge_mapping,
merge_node_id=merge_node_id,
start_node_id=branch_node_id
)
return in_branch_node_ids
@classmethod
def _recursively_fetch_routes(cls,
edge_mapping: dict[str, list[GraphEdge]],
start_node_id: str,
routes_node_ids: list[str]) -> None:
"""
Recursively fetch route
"""
if start_node_id not in edge_mapping:
return
for graph_edge in edge_mapping[start_node_id]:
# find next node ids
if graph_edge.target_node_id not in routes_node_ids:
routes_node_ids.append(graph_edge.target_node_id)
cls._recursively_fetch_routes(
edge_mapping=edge_mapping,
start_node_id=graph_edge.target_node_id,
routes_node_ids=routes_node_ids
)
@classmethod
def _is_node_in_routes(cls,
reverse_edge_mapping: dict[str, list[GraphEdge]],
start_node_id: str,
routes_node_ids: dict[str, list[str]]) -> bool:
"""
Recursively check if the node is in the routes
"""
if start_node_id not in reverse_edge_mapping:
return False
all_routes_node_ids = set()
parallel_start_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items():
for node_id in node_ids:
all_routes_node_ids.add(node_id)
if branch_node_id in reverse_edge_mapping:
for graph_edge in reverse_edge_mapping[branch_node_id]:
if graph_edge.source_node_id not in parallel_start_node_ids:
parallel_start_node_ids[graph_edge.source_node_id] = []
parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
parallel_start_node_id = None
for p_start_node_id, branch_node_ids in parallel_start_node_ids.items():
if set(branch_node_ids) == set(routes_node_ids.keys()):
parallel_start_node_id = p_start_node_id
break
if not parallel_start_node_id:
raise Exception("Parallel start node id not found")
for graph_edge in reverse_edge_mapping[start_node_id]:
if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id:
return False
return True
@classmethod
def _is_node2_after_node1(
cls,
node1_id: str,
node2_id: str,
edge_mapping: dict[str, list[GraphEdge]]
) -> bool:
"""
is node2 after node1
"""
if node1_id not in edge_mapping:
return False
for graph_edge in edge_mapping[node1_id]:
if graph_edge.target_node_id == node2_id:
return True
if cls._is_node2_after_node1(
node1_id=graph_edge.target_node_id,
node2_id=node2_id,
edge_mapping=edge_mapping
):
return True
return False
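A hedged sketch of the minimal config shape Graph.init expects: nodes with an id and a data.type, plus source/target edges. The two-node graph below is invented, and it assumes NodeType.START.value == "start" and that the answer/end stream-router initializers tolerate a graph containing neither node type:

graph_config = {
    "nodes": [
        {"id": "start", "data": {"type": "start"}},
        {"id": "llm", "data": {"type": "llm"}},
    ],
    "edges": [
        # sourceHandle == "source" means the edge carries no run condition
        {"source": "start", "target": "llm", "sourceHandle": "source"},
    ],
}

graph = Graph.init(graph_config=graph_config)
assert graph.root_node_id == "start"
assert graph.node_ids == ["start", "llm"]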

View File

@ -1,21 +0,0 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from models.workflow import WorkflowType
class GraphInitParams(BaseModel):
# init params
tenant_id: str = Field(..., description="tenant / workspace id")
app_id: str = Field(..., description="app id")
workflow_type: WorkflowType = Field(..., description="workflow type")
workflow_id: str = Field(..., description="workflow id")
graph_config: Mapping[str, Any] = Field(..., description="graph config")
user_id: str = Field(..., description="user id")
user_from: UserFrom = Field(..., description="user from, account or end-user")
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
call_depth: int = Field(..., description="call depth")

View File

@ -1,27 +0,0 @@
from typing import Any
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
class GraphRuntimeState(BaseModel):
variable_pool: VariablePool = Field(..., description="variable pool")
"""variable pool"""
start_at: float = Field(..., description="start time")
"""start time"""
total_tokens: int = 0
"""total tokens"""
llm_usage: LLMUsage = LLMUsage.empty_usage()
"""llm usage info"""
outputs: dict[str, Any] = {}
"""outputs"""
node_run_steps: int = 0
"""node run steps"""
node_run_state: RuntimeRouteState = RuntimeRouteState()
"""node run state"""

View File

@ -1,13 +0,0 @@
from typing import Optional
from pydantic import BaseModel
from core.workflow.graph_engine.entities.graph import GraphParallel
class NextGraphNode(BaseModel):
node_id: str
"""next node id"""
parallel: Optional[GraphParallel] = None
"""parallel"""

Some files were not shown because too many files have changed in this diff.