mirror of
https://github.com/langgenius/dify.git
synced 2026-01-20 20:19:28 +08:00
Compare commits
39 Commits
fix/reset-
...
feat/llm-f
| Author | SHA1 | Date | |
|---|---|---|---|
| bfc1583626 | |||
| 5db06175de | |||
| bdd8d5b470 | |||
| 4955de5905 | |||
| 3bee2ee067 | |||
| 328897f81c | |||
| ab078380a3 | |||
| a33ac77a22 | |||
| d3923e7b56 | |||
| 2f633de45e | |||
| 98c88cec34 | |||
| c6999fb5be | |||
| f7f9a08fa5 | |||
| 5008f5e89b | |||
| 1dd89a02ea | |||
| 5bf4114d6f | |||
| a56e94ba8e | |||
| 11f1782df0 | |||
| 8cf5d9a6a1 | |||
| 0ec2b12e65 | |||
| f33b1a3332 | |||
| 08026f7399 | |||
| 18e051bd66 | |||
| 42f991dbef | |||
| b1b2c9636f | |||
| 01f17b7ddc | |||
| 14b2e5bd0d | |||
| d095bd413b | |||
| 3473ff7ad1 | |||
| 138c56bd6e | |||
| c327d0bb44 | |||
| e4b97fba29 | |||
| 7f9884e7a1 | |||
| e389cd1665 | |||
| 87f348a0de | |||
| 206706987d | |||
| 91da784f84 | |||
| a129e684cc | |||
| fe07c810ba |
1
.agent/skills
Symbolic link
1
.agent/skills
Symbolic link
@ -0,0 +1 @@
|
||||
../.claude/skills
|
||||
46
.claude/skills/orpc-contract-first/SKILL.md
Normal file
46
.claude/skills/orpc-contract-first/SKILL.md
Normal file
@ -0,0 +1,46 @@
|
||||
---
|
||||
name: orpc-contract-first
|
||||
description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Triggers when creating new API contracts, adding service endpoints, integrating TanStack Query with typed contracts, or migrating legacy service calls to oRPC. Use for all API layer work in web/contract and web/service directories.
|
||||
---
|
||||
|
||||
# oRPC Contract-First Development
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
web/contract/
|
||||
├── base.ts # Base contract (inputStructure: 'detailed')
|
||||
├── router.ts # Router composition & type exports
|
||||
├── marketplace.ts # Marketplace contracts
|
||||
└── console/ # Console contracts by domain
|
||||
├── system.ts
|
||||
└── billing.ts
|
||||
```
|
||||
|
||||
## Workflow
|
||||
|
||||
1. **Create contract** in `web/contract/console/{domain}.ts`
|
||||
- Import `base` from `../base` and `type` from `@orpc/contract`
|
||||
- Define route with `path`, `method`, `input`, `output`
|
||||
|
||||
2. **Register in router** at `web/contract/router.ts`
|
||||
- Import directly from domain file (no barrel files)
|
||||
- Nest by API prefix: `billing: { invoices, bindPartnerStack }`
|
||||
|
||||
3. **Create hooks** in `web/service/use-{domain}.ts`
|
||||
- Use `consoleQuery.{group}.{contract}.queryKey()` for query keys
|
||||
- Use `consoleClient.{group}.{contract}()` for API calls
|
||||
|
||||
## Key Rules
|
||||
|
||||
- **Input structure**: Always use `{ params, query?, body? }` format
|
||||
- **Path params**: Use `{paramName}` in path, match in `params` object
|
||||
- **Router nesting**: Group by API prefix (e.g., `/billing/*` → `billing: {}`)
|
||||
- **No barrel files**: Import directly from specific files
|
||||
- **Types**: Import from `@/types/`, use `type<T>()` helper
|
||||
|
||||
## Type Export
|
||||
|
||||
```typescript
|
||||
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
|
||||
```
|
||||
2
.github/workflows/style.yml
vendored
2
.github/workflows/style.yml
vendored
@ -90,7 +90,7 @@ jobs:
|
||||
uses: actions/setup-node@v6
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 22
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
|
||||
8
.github/workflows/tool-test-sdks.yaml
vendored
8
.github/workflows/tool-test-sdks.yaml
vendored
@ -16,10 +16,6 @@ jobs:
|
||||
name: unit test for Node.js SDK
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
node-version: [16, 18, 20, 22]
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: sdks/nodejs-client
|
||||
@ -29,10 +25,10 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Use Node.js ${{ matrix.node-version }}
|
||||
- name: Use Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
node-version: 24
|
||||
cache: ''
|
||||
cache-dependency-path: 'pnpm-lock.yaml'
|
||||
|
||||
|
||||
2
.github/workflows/translate-i18n-claude.yml
vendored
2
.github/workflows/translate-i18n-claude.yml
vendored
@ -57,7 +57,7 @@ jobs:
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 'lts/*'
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
|
||||
2
.github/workflows/web-tests.yml
vendored
2
.github/workflows/web-tests.yml
vendored
@ -31,7 +31,7 @@ jobs:
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
|
||||
@ -417,6 +417,8 @@ SMTP_USERNAME=123
|
||||
SMTP_PASSWORD=abc
|
||||
SMTP_USE_TLS=true
|
||||
SMTP_OPPORTUNISTIC_TLS=false
|
||||
# Optional: override the local hostname used for SMTP HELO/EHLO
|
||||
SMTP_LOCAL_HOSTNAME=
|
||||
# Sendgid configuration
|
||||
SENDGRID_API_KEY=
|
||||
# Sentry configuration
|
||||
@ -713,3 +715,4 @@ ANNOTATION_IMPORT_MAX_CONCURRENT=5
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
|
||||
|
||||
|
||||
@ -949,6 +949,12 @@ class MailConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
SMTP_LOCAL_HOSTNAME: str | None = Field(
|
||||
description="Override the local hostname used in SMTP HELO/EHLO. "
|
||||
"Useful behind NAT or when the default hostname causes rejections.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field(
|
||||
description="Maximum number of emails allowed to be sent from the same IP address in a minute",
|
||||
default=50,
|
||||
@ -959,16 +965,6 @@ class MailConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
ENABLE_TRIAL_APP: bool = Field(
|
||||
description="Enable trial app",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ENABLE_EXPLORE_BANNER: bool = Field(
|
||||
description="Enable explore banner",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class RagEtlConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
@ -4,7 +4,7 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class VolcengineTOSStorageConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for Volcengine Tinder Object Storage (TOS)
|
||||
Configuration settings for Volcengine Torch Object Storage (TOS)
|
||||
"""
|
||||
|
||||
VOLCENGINE_TOS_BUCKET_NAME: str | None = Field(
|
||||
|
||||
@ -107,12 +107,10 @@ from .datasets.rag_pipeline import (
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
banner,
|
||||
installed_app,
|
||||
parameter,
|
||||
recommended_app,
|
||||
saved_message,
|
||||
trial,
|
||||
)
|
||||
|
||||
# Import tag controllers
|
||||
@ -147,7 +145,6 @@ __all__ = [
|
||||
"apikey",
|
||||
"app",
|
||||
"audio",
|
||||
"banner",
|
||||
"billing",
|
||||
"bp",
|
||||
"completion",
|
||||
@ -201,7 +198,6 @@ __all__ = [
|
||||
"statistic",
|
||||
"tags",
|
||||
"tool_providers",
|
||||
"trial",
|
||||
"trigger_providers",
|
||||
"version",
|
||||
"website",
|
||||
|
||||
@ -15,7 +15,7 @@ from controllers.console.wraps import only_edition_cloud
|
||||
from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@ -32,8 +32,6 @@ class InsertExploreAppPayload(BaseModel):
|
||||
language: str = Field(...)
|
||||
category: str = Field(...)
|
||||
position: int = Field(...)
|
||||
can_trial: bool = Field(default=False)
|
||||
trial_limit: int = Field(default=0)
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
@ -41,33 +39,11 @@ class InsertExploreAppPayload(BaseModel):
|
||||
return supported_language(value)
|
||||
|
||||
|
||||
class InsertExploreBannerPayload(BaseModel):
|
||||
category: str = Field(...)
|
||||
title: str = Field(...)
|
||||
description: str = Field(...)
|
||||
img_src: str = Field(..., alias="img-src")
|
||||
language: str = Field(default="en-US")
|
||||
link: str = Field(...)
|
||||
sort: int = Field(...)
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
def validate_language(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
InsertExploreAppPayload.__name__,
|
||||
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
InsertExploreBannerPayload.__name__,
|
||||
InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def admin_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
@ -133,20 +109,6 @@ class InsertExploreAppListApi(Resource):
|
||||
)
|
||||
|
||||
db.session.add(recommended_app)
|
||||
if payload.can_trial:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=payload.app_id,
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=payload.trial_limit,
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = payload.trial_limit
|
||||
|
||||
app.is_public = True
|
||||
db.session.commit()
|
||||
@ -161,20 +123,6 @@ class InsertExploreAppListApi(Resource):
|
||||
recommended_app.category = payload.category
|
||||
recommended_app.position = payload.position
|
||||
|
||||
if payload.can_trial:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=payload.app_id,
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=payload.trial_limit,
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = payload.trial_limit
|
||||
app.is_public = True
|
||||
|
||||
db.session.commit()
|
||||
@ -220,62 +168,7 @@ class InsertExploreAppApi(Resource):
|
||||
for installed_app in installed_apps:
|
||||
session.delete(installed_app)
|
||||
|
||||
trial_app = session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
|
||||
).scalar_one_or_none()
|
||||
if trial_app:
|
||||
session.delete(trial_app)
|
||||
|
||||
db.session.delete(recommended_app)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-banner")
|
||||
class InsertExploreBannerApi(Resource):
|
||||
@console_ns.doc("insert_explore_banner")
|
||||
@console_ns.doc(description="Insert an explore banner")
|
||||
@console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__])
|
||||
@console_ns.response(201, "Banner inserted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
|
||||
|
||||
content = {
|
||||
"category": payload.category,
|
||||
"title": payload.title,
|
||||
"description": payload.description,
|
||||
"img-src": payload.img_src,
|
||||
}
|
||||
|
||||
banner = ExporleBanner(
|
||||
content=content,
|
||||
link=payload.link,
|
||||
sort=payload.sort,
|
||||
language=payload.language,
|
||||
)
|
||||
db.session.add(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-banner/<uuid:banner_id>")
|
||||
class DeleteExploreBannerApi(Resource):
|
||||
@console_ns.doc("delete_explore_banner")
|
||||
@console_ns.doc(description="Delete an explore banner")
|
||||
@console_ns.doc(params={"banner_id": "Banner ID to delete"})
|
||||
@console_ns.response(204, "Banner deleted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, banner_id):
|
||||
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
|
||||
if not banner:
|
||||
raise NotFound(f"Banner '{banner_id}' is not found")
|
||||
|
||||
db.session.delete(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@ -272,7 +272,6 @@ class AnnotationExportApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||
response_data = {"data": marshal(annotation_list, annotation_fields)}
|
||||
@ -360,6 +359,7 @@ class AnnotationBatchImportApi(Resource):
|
||||
file.seek(0, 2) # Seek to end of file
|
||||
file_size = file.tell()
|
||||
file.seek(0) # Reset to beginning
|
||||
|
||||
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
|
||||
if file_size > max_size_bytes:
|
||||
abort(
|
||||
|
||||
@ -592,9 +592,12 @@ def _get_conversation(app_model, conversation_id):
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
if not conversation.read_at:
|
||||
conversation.read_at = naive_utc_now()
|
||||
conversation.read_account_id = current_user.id
|
||||
db.session.commit()
|
||||
db.session.execute(
|
||||
sa.update(Conversation)
|
||||
.where(Conversation.id == conversation_id, Conversation.read_at.is_(None))
|
||||
.values(read_at=naive_utc_now(), read_account_id=current_user.id)
|
||||
)
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
|
||||
return conversation
|
||||
|
||||
@ -115,9 +115,3 @@ class InvokeRateLimitError(BaseHTTPException):
|
||||
error_code = "rate_limit_error"
|
||||
description = "Rate Limit Error"
|
||||
code = 429
|
||||
|
||||
|
||||
class NeedAddIdsError(BaseHTTPException):
|
||||
error_code = "need_add_ids"
|
||||
description = "Need to add ids."
|
||||
code = 400
|
||||
|
||||
@ -202,7 +202,6 @@ message_detail_model = console_ns.model(
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"generation_detail": fields.Raw,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -23,11 +23,6 @@ def _load_app_model(app_id: str) -> App | None:
|
||||
return app_model
|
||||
|
||||
|
||||
def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
|
||||
return app_model
|
||||
|
||||
|
||||
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P1, R1]):
|
||||
@wraps(view_func)
|
||||
@ -67,44 +62,3 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
|
||||
app_id = kwargs.get("app_id")
|
||||
app_id = str(app_id)
|
||||
|
||||
del kwargs["app_id"]
|
||||
|
||||
app_model = _load_app_model_with_trial(app_id)
|
||||
|
||||
if not app_model:
|
||||
raise AppNotFoundError()
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
||||
if mode is not None:
|
||||
if isinstance(mode, list):
|
||||
modes = mode
|
||||
else:
|
||||
modes = [mode]
|
||||
|
||||
if app_mode not in modes:
|
||||
mode_values = {m.value for m in modes}
|
||||
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
|
||||
|
||||
kwargs["app_model"] = app_model
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
|
||||
@ -161,7 +161,10 @@ class OAuthCallback(Resource):
|
||||
ip_address=extract_remote_ip(request),
|
||||
)
|
||||
|
||||
response = redirect(f"{dify_config.CONSOLE_WEB_URL}?oauth_new_user={str(oauth_new_user).lower()}")
|
||||
base_url = dify_config.CONSOLE_WEB_URL
|
||||
query_char = "&" if "?" in base_url else "?"
|
||||
target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
|
||||
response = redirect(target_url)
|
||||
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
|
||||
@ -146,7 +146,6 @@ class DatasetUpdatePayload(BaseModel):
|
||||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
retrieval_model: dict[str, Any] | None = None
|
||||
summary_index_setting: dict[str, Any] | None = None
|
||||
partial_member_list: list[dict[str, str]] | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
external_knowledge_id: str | None = None
|
||||
|
||||
@ -7,7 +7,7 @@ from typing import Literal, cast
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import asc, desc, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@ -39,10 +39,9 @@ from fields.document_fields import (
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog, DocumentSegmentSummary
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
|
||||
from ..app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
@ -105,8 +104,13 @@ class DocumentRenamePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class GenerateSummaryPayload(BaseModel):
|
||||
document_list: list[str]
|
||||
class DocumentDatasetListParam(BaseModel):
|
||||
page: int = Field(1, title="Page", description="Page number.")
|
||||
limit: int = Field(20, title="Limit", description="Page size.")
|
||||
search: str | None = Field(None, alias="keyword", title="Search", description="Search keyword.")
|
||||
sort_by: str = Field("-created_at", alias="sort", title="SortBy", description="Sort by field.")
|
||||
status: str | None = Field(None, title="Status", description="Document status.")
|
||||
fetch_val: str = Field("false", alias="fetch")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
@ -116,7 +120,6 @@ register_schema_models(
|
||||
RetrievalModel,
|
||||
DocumentRetryPayload,
|
||||
DocumentRenamePayload,
|
||||
GenerateSummaryPayload,
|
||||
)
|
||||
|
||||
|
||||
@ -231,14 +234,16 @@ class DatasetDocumentListApi(Resource):
|
||||
def get(self, dataset_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
sort = request.args.get("sort", default="-created_at", type=str)
|
||||
status = request.args.get("status", default=None, type=str)
|
||||
raw_args = request.args.to_dict()
|
||||
param = DocumentDatasetListParam.model_validate(raw_args)
|
||||
page = param.page
|
||||
limit = param.limit
|
||||
search = param.search
|
||||
sort = param.sort_by
|
||||
status = param.status
|
||||
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
|
||||
try:
|
||||
fetch_val = request.args.get("fetch", default="false")
|
||||
fetch_val = param.fetch_val
|
||||
if isinstance(fetch_val, bool):
|
||||
fetch = fetch_val
|
||||
else:
|
||||
@ -301,97 +306,6 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
documents = paginated_documents.items
|
||||
|
||||
# Check if dataset has summary index enabled
|
||||
has_summary_index = (
|
||||
dataset.summary_index_setting
|
||||
and dataset.summary_index_setting.get("enable") is True
|
||||
)
|
||||
|
||||
# Filter documents that need summary calculation
|
||||
documents_need_summary = [doc for doc in documents if doc.need_summary is True]
|
||||
document_ids_need_summary = [str(doc.id) for doc in documents_need_summary]
|
||||
|
||||
# Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled)
|
||||
summary_status_map = {}
|
||||
if has_summary_index and document_ids_need_summary:
|
||||
# Get all segments for these documents (excluding qa_model and re_segment)
|
||||
segments = (
|
||||
db.session.query(DocumentSegment.id, DocumentSegment.document_id)
|
||||
.where(
|
||||
DocumentSegment.document_id.in_(document_ids_need_summary),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.tenant_id == current_tenant_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Group segments by document_id
|
||||
document_segments_map = {}
|
||||
for segment in segments:
|
||||
doc_id = str(segment.document_id)
|
||||
if doc_id not in document_segments_map:
|
||||
document_segments_map[doc_id] = []
|
||||
document_segments_map[doc_id].append(segment.id)
|
||||
|
||||
# Get all summary records for these segments
|
||||
all_segment_ids = [seg.id for seg in segments]
|
||||
summaries = {}
|
||||
if all_segment_ids:
|
||||
summary_records = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id.in_(all_segment_ids),
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.enabled == True, # Only count enabled summaries
|
||||
)
|
||||
.all()
|
||||
)
|
||||
summaries = {summary.chunk_id: summary.status for summary in summary_records}
|
||||
|
||||
# Calculate summary_index_status for each document
|
||||
for doc_id in document_ids_need_summary:
|
||||
segment_ids = document_segments_map.get(doc_id, [])
|
||||
if not segment_ids:
|
||||
# No segments, status is "GENERATING" (waiting to generate)
|
||||
summary_status_map[doc_id] = "GENERATING"
|
||||
continue
|
||||
|
||||
# Count summary statuses for this document's segments
|
||||
status_counts = {"completed": 0, "generating": 0, "error": 0, "not_started": 0}
|
||||
for segment_id in segment_ids:
|
||||
status = summaries.get(segment_id, "not_started")
|
||||
if status in status_counts:
|
||||
status_counts[status] += 1
|
||||
else:
|
||||
status_counts["not_started"] += 1
|
||||
|
||||
total_segments = len(segment_ids)
|
||||
completed_count = status_counts["completed"]
|
||||
generating_count = status_counts["generating"]
|
||||
error_count = status_counts["error"]
|
||||
|
||||
# Determine overall status (only three states: GENERATING, COMPLETED, ERROR)
|
||||
if completed_count == total_segments:
|
||||
summary_status_map[doc_id] = "COMPLETED"
|
||||
elif error_count > 0:
|
||||
# Has errors (even if some are completed or generating)
|
||||
summary_status_map[doc_id] = "ERROR"
|
||||
elif generating_count > 0 or status_counts["not_started"] > 0:
|
||||
# Still generating or not started
|
||||
summary_status_map[doc_id] = "GENERATING"
|
||||
else:
|
||||
# Default to generating
|
||||
summary_status_map[doc_id] = "GENERATING"
|
||||
|
||||
# Add summary_index_status to each document
|
||||
for document in documents:
|
||||
if has_summary_index and document.need_summary is True:
|
||||
document.summary_index_status = summary_status_map.get(str(document.id), "GENERATING")
|
||||
else:
|
||||
# Return null if summary index is not enabled or document doesn't need summary
|
||||
document.summary_index_status = None
|
||||
|
||||
if fetch:
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
@ -490,7 +404,6 @@ class DatasetDocumentListApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
|
||||
@console_ns.route("/datasets/init")
|
||||
class DatasetInitApi(Resource):
|
||||
@console_ns.doc("init_dataset")
|
||||
@ -878,7 +791,6 @@ class DocumentApi(DocumentResource):
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
"need_summary": document.need_summary if document.need_summary is not None else False,
|
||||
}
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
@ -914,7 +826,6 @@ class DocumentApi(DocumentResource):
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
"need_summary": document.need_summary if document.need_summary is not None else False,
|
||||
}
|
||||
|
||||
return response, 200
|
||||
@ -1282,211 +1193,3 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
"input_data": log.input_data,
|
||||
"datasource_node_id": log.datasource_node_id,
|
||||
}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/generate-summary")
|
||||
class DocumentGenerateSummaryApi(Resource):
|
||||
@console_ns.doc("generate_summary_for_documents")
|
||||
@console_ns.doc(description="Generate summary index for documents")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
|
||||
@console_ns.response(200, "Summary generation started successfully")
|
||||
@console_ns.response(400, "Invalid request or dataset configuration")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id):
|
||||
"""
|
||||
Generate summary index for specified documents.
|
||||
|
||||
This endpoint checks if the dataset configuration supports summary generation
|
||||
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
|
||||
then asynchronously generates summary indexes for the provided documents.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
|
||||
# Get dataset
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# Check permissions
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
# Validate request payload
|
||||
payload = GenerateSummaryPayload.model_validate(console_ns.payload or {})
|
||||
document_list = payload.document_list
|
||||
|
||||
if not document_list:
|
||||
raise ValueError("document_list cannot be empty.")
|
||||
|
||||
# Check if dataset configuration supports summary generation
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
raise ValueError(
|
||||
f"Summary generation is only available for 'high_quality' indexing technique. "
|
||||
f"Current indexing technique: {dataset.indexing_technique}"
|
||||
)
|
||||
|
||||
summary_index_setting = dataset.summary_index_setting
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
raise ValueError(
|
||||
"Summary index is not enabled for this dataset. "
|
||||
"Please enable it in the dataset settings."
|
||||
)
|
||||
|
||||
# Verify all documents exist and belong to the dataset
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.filter(
|
||||
Document.id.in_(document_list),
|
||||
Document.dataset_id == dataset_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(documents) != len(document_list):
|
||||
found_ids = {doc.id for doc in documents}
|
||||
missing_ids = set(document_list) - found_ids
|
||||
raise NotFound(f"Some documents not found: {list(missing_ids)}")
|
||||
|
||||
# Dispatch async tasks for each document
|
||||
for document in documents:
|
||||
# Skip qa_model documents as they don't generate summaries
|
||||
if document.doc_form == "qa_model":
|
||||
logger.info(
|
||||
f"Skipping summary generation for qa_model document {document.id}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Dispatch async task
|
||||
generate_summary_index_task(dataset_id, document.id)
|
||||
logger.info(
|
||||
f"Dispatched summary generation task for document {document.id} in dataset {dataset_id}"
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/summary-status")
|
||||
class DocumentSummaryStatusApi(DocumentResource):
|
||||
@console_ns.doc("get_document_summary_status")
|
||||
@console_ns.doc(description="Get summary index generation status for a document")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.response(200, "Summary status retrieved successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
"""
|
||||
Get summary index generation status for a document.
|
||||
|
||||
Returns:
|
||||
- total_segments: Total number of segments in the document
|
||||
- summary_status: Dictionary with status counts
|
||||
- completed: Number of summaries completed
|
||||
- generating: Number of summaries being generated
|
||||
- error: Number of summaries with errors
|
||||
- not_started: Number of segments without summary records
|
||||
- summaries: List of summary records with status and content preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
# Get document
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
# Get dataset
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# Check permissions
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
# Get all segments for this document
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
total_segments = len(segments)
|
||||
|
||||
# Get all summary records for these segments
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
summaries = []
|
||||
if segment_ids:
|
||||
summaries = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.filter(
|
||||
DocumentSegmentSummary.document_id == document_id,
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.chunk_id.in_(segment_ids),
|
||||
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Create a mapping of chunk_id to summary
|
||||
summary_map = {summary.chunk_id: summary for summary in summaries}
|
||||
|
||||
# Count statuses
|
||||
status_counts = {
|
||||
"completed": 0,
|
||||
"generating": 0,
|
||||
"error": 0,
|
||||
"not_started": 0,
|
||||
}
|
||||
|
||||
summary_list = []
|
||||
for segment in segments:
|
||||
summary = summary_map.get(segment.id)
|
||||
if summary:
|
||||
status = summary.status
|
||||
status_counts[status] = status_counts.get(status, 0) + 1
|
||||
summary_list.append({
|
||||
"segment_id": segment.id,
|
||||
"segment_position": segment.position,
|
||||
"status": summary.status,
|
||||
"summary_preview": summary.summary_content[:100] + "..." if summary.summary_content and len(summary.summary_content) > 100 else summary.summary_content,
|
||||
"error": summary.error,
|
||||
"created_at": int(summary.created_at.timestamp()) if summary.created_at else None,
|
||||
"updated_at": int(summary.updated_at.timestamp()) if summary.updated_at else None,
|
||||
})
|
||||
else:
|
||||
status_counts["not_started"] += 1
|
||||
summary_list.append({
|
||||
"segment_id": segment.id,
|
||||
"segment_position": segment.position,
|
||||
"status": "not_started",
|
||||
"summary_preview": None,
|
||||
"error": None,
|
||||
"created_at": None,
|
||||
"updated_at": None,
|
||||
})
|
||||
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"summary_status": status_counts,
|
||||
"summaries": summary_list,
|
||||
}, 200
|
||||
|
||||
@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from libs.helper import escape_like_pattern
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import ChildChunk, DocumentSegment, DocumentSegmentSummary
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
|
||||
@ -41,23 +41,6 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
|
||||
|
||||
def _get_segment_with_summary(segment, dataset_id):
|
||||
"""Helper function to marshal segment and add summary information."""
|
||||
segment_dict = marshal(segment, segment_fields)
|
||||
# Query summary for this segment (only enabled summaries)
|
||||
summary = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
|
||||
)
|
||||
.first()
|
||||
)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
|
||||
|
||||
|
||||
class SegmentListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
status: list[str] = Field(default_factory=list)
|
||||
@ -80,7 +63,6 @@ class SegmentUpdatePayload(BaseModel):
|
||||
keywords: list[str] | None = None
|
||||
regenerate_child_chunks: bool = False
|
||||
attachment_ids: list[str] | None = None
|
||||
summary: str | None = None # Summary content for summary index
|
||||
|
||||
|
||||
class BatchImportPayload(BaseModel):
|
||||
@ -198,34 +180,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
|
||||
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
|
||||
# Query summaries for all segments in this page (batch query for efficiency)
|
||||
segment_ids = [segment.id for segment in segments.items]
|
||||
summaries = {}
|
||||
if segment_ids:
|
||||
summary_records = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id.in_(segment_ids),
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
# Only include enabled summaries
|
||||
summaries = {
|
||||
summary.chunk_id: summary.summary_content
|
||||
for summary in summary_records
|
||||
if summary.enabled is True
|
||||
}
|
||||
|
||||
# Add summary to each segment
|
||||
segments_with_summary = []
|
||||
for segment in segments.items:
|
||||
segment_dict = marshal(segment, segment_fields)
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
segments_with_summary.append(segment_dict)
|
||||
|
||||
response = {
|
||||
"data": segments_with_summary,
|
||||
"data": marshal(segments.items, segment_fields),
|
||||
"limit": limit,
|
||||
"total": segments.total,
|
||||
"total_pages": segments.pages,
|
||||
@ -371,7 +327,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
segment = SegmentService.create_segment(payload_dict, document, dataset)
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
|
||||
@ -433,12 +389,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
|
||||
# Update segment (summary update with change detection is handled in SegmentService.update_segment)
|
||||
segment = SegmentService.update_segment(
|
||||
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
|
||||
)
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@ -81,7 +81,7 @@ class ExternalKnowledgeApiPayload(BaseModel):
|
||||
class ExternalDatasetCreatePayload(BaseModel):
|
||||
external_knowledge_api_id: str
|
||||
external_knowledge_id: str
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
description: str | None = Field(None, max_length=400)
|
||||
external_retrieval_model: dict[str, object] | None = None
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from flask_restx import Resource, fields
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from libs.login import login_required
|
||||
@ -10,56 +10,17 @@ from ..wraps import (
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
setup_required,
|
||||
)
|
||||
from fields.hit_testing_fields import (
|
||||
child_chunk_fields,
|
||||
document_fields,
|
||||
files_fields,
|
||||
hit_testing_record_fields,
|
||||
segment_fields,
|
||||
)
|
||||
|
||||
register_schema_model(console_ns, HitTestingPayload)
|
||||
|
||||
|
||||
def _get_or_create_model(model_name: str, field_def):
|
||||
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
|
||||
existing = console_ns.models.get(model_name)
|
||||
if existing is None:
|
||||
existing = console_ns.model(model_name, field_def)
|
||||
return existing
|
||||
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
document_model = _get_or_create_model("HitTestingDocument", document_fields)
|
||||
|
||||
segment_fields_copy = segment_fields.copy()
|
||||
segment_fields_copy["document"] = fields.Nested(document_model)
|
||||
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
|
||||
|
||||
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
|
||||
files_model = _get_or_create_model("HitTestingFile", files_fields)
|
||||
|
||||
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
|
||||
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
|
||||
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
|
||||
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
|
||||
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
|
||||
|
||||
# Response model for hit testing API
|
||||
hit_testing_response_fields = {
|
||||
"query": fields.String,
|
||||
"records": fields.List(fields.Nested(hit_testing_record_model)),
|
||||
}
|
||||
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@console_ns.doc("test_dataset_retrieval")
|
||||
@console_ns.doc(description="Test dataset knowledge retrieval")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
|
||||
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
|
||||
@console_ns.response(200, "Hit testing completed successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@setup_required
|
||||
|
||||
@ -1,43 +0,0 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import explore_banner_enabled
|
||||
from extensions.ext_database import db
|
||||
from models.model import ExporleBanner
|
||||
|
||||
|
||||
class BannerApi(Resource):
|
||||
"""Resource for banner list."""
|
||||
|
||||
@explore_banner_enabled
|
||||
def get(self):
|
||||
"""Get banner list."""
|
||||
language = request.args.get("language", "en-US")
|
||||
|
||||
# Build base query for enabled banners
|
||||
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
|
||||
|
||||
# Try to get banners in the requested language
|
||||
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
|
||||
|
||||
# Fallback to en-US if no banners found and language is not en-US
|
||||
if not banners and language != "en-US":
|
||||
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
|
||||
# Convert banners to serializable format
|
||||
result = []
|
||||
for banner in banners:
|
||||
banner_data = {
|
||||
"id": banner.id,
|
||||
"content": banner.content, # Already parsed as JSON by SQLAlchemy
|
||||
"link": banner.link,
|
||||
"sort": banner.sort,
|
||||
"status": banner.status,
|
||||
"created_at": banner.created_at.isoformat() if banner.created_at else None,
|
||||
}
|
||||
result.append(banner_data)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
api.add_resource(BannerApi, "/explore/banners")
|
||||
@ -29,25 +29,3 @@ class AppAccessDeniedError(BaseHTTPException):
|
||||
error_code = "access_denied"
|
||||
description = "App access denied."
|
||||
code = 403
|
||||
|
||||
|
||||
class TrialAppNotAllowed(BaseHTTPException):
|
||||
"""*403* `Trial App Not Allowed`
|
||||
|
||||
Raise if the user has reached the trial app limit.
|
||||
"""
|
||||
|
||||
error_code = "trial_app_not_allowed"
|
||||
code = 403
|
||||
description = "the app is not allowed to be trial."
|
||||
|
||||
|
||||
class TrialAppLimitExceeded(BaseHTTPException):
|
||||
"""*403* `Trial App Limit Exceeded`
|
||||
|
||||
Raise if the user has exceeded the trial app limit.
|
||||
"""
|
||||
|
||||
error_code = "trial_app_limit_exceeded"
|
||||
code = 403
|
||||
description = "The user has exceeded the trial app limit."
|
||||
|
||||
@ -29,7 +29,6 @@ recommended_app_fields = {
|
||||
"category": fields.String,
|
||||
"position": fields.Integer,
|
||||
"is_listed": fields.Boolean,
|
||||
"can_trial": fields.Boolean,
|
||||
}
|
||||
|
||||
recommended_app_list_fields = {
|
||||
|
||||
@ -1,512 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import Parameters as ParametersResponse
|
||||
from controllers.common.fields import Site as SiteResponse
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
NeedAddIdsError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model_with_trial
|
||||
from controllers.console.explore.error import (
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
NotWorkflowAppError,
|
||||
)
|
||||
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_fields_with_site
|
||||
from fields.dataset_fields import dataset_fields
|
||||
from fields.workflow_fields import workflow_fields
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.account import TenantStatus
|
||||
from models.model import AppMode, Site
|
||||
from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_service import AppService
|
||||
from services.audio_service import AudioService
|
||||
from services.dataset_service import DatasetService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.errors.message import (
|
||||
MessageNotExistsError,
|
||||
SuggestedQuestionsAfterAnswerDisabledError,
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
def post(self, trial_app):
|
||||
"""
|
||||
Run workflow
|
||||
"""
|
||||
app_model = trial_app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
assert current_user is not None
|
||||
try:
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
def post(self, trial_app, task_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
"""
|
||||
app_model = trial_app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
assert current_user is not None
|
||||
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class TrialChatApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialMessageSuggestedQuestionApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def get(self, trial_app, message_id):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
|
||||
|
||||
class TrialChatAudioApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialChatTextApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialCompletionApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
||||
)
|
||||
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialSitApi(Resource):
|
||||
"""Resource for trial app sites."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
def get(self, app_model):
|
||||
"""Retrieve app site info.
|
||||
|
||||
Returns the site configuration for the application including theme, icons, and text.
|
||||
"""
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
||||
assert app_model.tenant
|
||||
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
return SiteResponse.model_validate(site).model_dump(mode="json")
|
||||
|
||||
|
||||
class TrialAppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
def get(self, app_model):
|
||||
"""Retrieve app parameters."""
|
||||
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return ParametersResponse.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
class AppApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.get_app(app_model)
|
||||
|
||||
return app_model
|
||||
|
||||
|
||||
class AppWorkflowApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
@marshal_with(workflow_fields)
|
||||
def get(self, app_model):
|
||||
"""Get workflow detail"""
|
||||
if not app_model.workflow_id:
|
||||
raise AppUnavailableError()
|
||||
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.id == app_model.workflow_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return workflow
|
||||
|
||||
|
||||
class DatasetListApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
def get(self, app_model):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
ids = request.args.getlist("ids")
|
||||
|
||||
tenant_id = app_model.tenant_id
|
||||
if ids:
|
||||
datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
|
||||
else:
|
||||
raise NeedAddIdsError()
|
||||
|
||||
data = cast(list[dict[str, Any]], marshal(datasets, dataset_fields))
|
||||
|
||||
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
|
||||
|
||||
api.add_resource(
|
||||
TrialMessageSuggestedQuestionApi,
|
||||
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||
endpoint="trial_app_suggested_question",
|
||||
)
|
||||
|
||||
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
|
||||
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
|
||||
|
||||
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
|
||||
|
||||
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
|
||||
|
||||
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
|
||||
|
||||
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
|
||||
|
||||
api.add_resource(TrialAppWorkflowRunApi, "/trial-apps/<uuid:app_id>/workflows/run", endpoint="trial_app_workflow_run")
|
||||
api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
|
||||
|
||||
api.add_resource(AppWorkflowApi, "/trial-apps/<uuid:app_id>/workflows", endpoint="trial_app_workflow")
|
||||
api.add_resource(DatasetListApi, "/trial-apps/<uuid:app_id>/datasets", endpoint="trial_app_datasets")
|
||||
@ -2,15 +2,14 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import abort
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
|
||||
from controllers.console.explore.error import AppAccessDeniedError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
|
||||
from models import InstalledApp
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@ -72,61 +71,6 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[App, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
|
||||
|
||||
if trial_app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
app = trial_app.app
|
||||
|
||||
if app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
|
||||
account_trial_app_record = (
|
||||
db.session.query(AccountTrialAppRecord)
|
||||
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
|
||||
.first()
|
||||
)
|
||||
if account_trial_app_record:
|
||||
if account_trial_app_record.count >= trial_app.trial_limit:
|
||||
raise TrialAppLimitExceeded()
|
||||
|
||||
return view(app, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.enable_trial_app:
|
||||
abort(403, "Trial app feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.enable_explore_banner:
|
||||
abort(403, "Explore banner feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
class InstalledAppResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
|
||||
@ -136,13 +80,3 @@ class InstalledAppResource(Resource):
|
||||
account_initialization_required,
|
||||
login_required,
|
||||
]
|
||||
|
||||
|
||||
class TrialAppResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
|
||||
method_decorators = [
|
||||
trial_app_required,
|
||||
account_initialization_required,
|
||||
login_required,
|
||||
]
|
||||
|
||||
@ -358,12 +358,14 @@ def annotation_import_rate_limit(view: Callable[P, R]):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
current_time = int(time.time() * 1000)
|
||||
|
||||
# Check per-minute rate limit
|
||||
minute_key = f"annotation_import_rate_limit:{current_tenant_id}:1min"
|
||||
redis_client.zadd(minute_key, {current_time: current_time})
|
||||
redis_client.zremrangebyscore(minute_key, 0, current_time - 60000)
|
||||
minute_count = redis_client.zcard(minute_key)
|
||||
redis_client.expire(minute_key, 120) # 2 minutes TTL
|
||||
|
||||
if minute_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE:
|
||||
abort(
|
||||
429,
|
||||
@ -377,6 +379,7 @@ def annotation_import_rate_limit(view: Callable[P, R]):
|
||||
redis_client.zremrangebyscore(hour_key, 0, current_time - 3600000)
|
||||
hour_count = redis_client.zcard(hour_key)
|
||||
redis_client.expire(hour_key, 7200) # 2 hours TTL
|
||||
|
||||
if hour_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:
|
||||
abort(
|
||||
429,
|
||||
|
||||
@ -1,380 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentResult
|
||||
from core.agent.patterns.strategy_factory import StrategyFactory
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.file import file_manager
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentAppRunner(BaseAgentRunner):
|
||||
def _create_tool_invoke_hook(self, message: Message):
|
||||
"""
|
||||
Create a tool invoke hook that uses ToolEngine.agent_invoke.
|
||||
This hook handles file creation and returns proper meta information.
|
||||
"""
|
||||
# Get trace manager from app generate entity
|
||||
trace_manager = self.application_generate_entity.trace_manager
|
||||
|
||||
def tool_invoke_hook(
|
||||
tool: Tool, tool_args: dict[str, Any], tool_name: str
|
||||
) -> tuple[str, list[str], ToolInvokeMeta]:
|
||||
"""Hook that uses agent_invoke for proper file and meta handling."""
|
||||
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
|
||||
tool=tool,
|
||||
tool_parameters=tool_args,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
message=message,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
agent_tool_callback=self.agent_callback,
|
||||
trace_manager=trace_manager,
|
||||
app_id=self.application_generate_entity.app_config.app_id,
|
||||
message_id=message.id,
|
||||
conversation_id=self.conversation.id,
|
||||
)
|
||||
|
||||
# Publish files and track IDs
|
||||
for message_file_id in message_files:
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
self._current_message_file_ids.append(message_file_id)
|
||||
|
||||
return tool_invoke_response, message_files, tool_invoke_meta
|
||||
|
||||
return tool_invoke_hook
|
||||
|
||||
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Run Agent application
|
||||
"""
|
||||
self.query = query
|
||||
app_generate_entity = self.application_generate_entity
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config is not None, "app_config is required"
|
||||
assert app_config.agent is not None, "app_config.agent is required"
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, _ = self._init_prompt_tools()
|
||||
|
||||
assert app_config.agent
|
||||
|
||||
# Create tool invoke hook for agent_invoke
|
||||
tool_invoke_hook = self._create_tool_invoke_hook(message)
|
||||
|
||||
# Get instruction for ReAct strategy
|
||||
instruction = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
|
||||
# Use factory to create appropriate strategy
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=self.model_features,
|
||||
model_instance=self.model_instance,
|
||||
tools=list(tool_instances.values()),
|
||||
files=list(self.files),
|
||||
max_iterations=app_config.agent.max_iteration,
|
||||
context=self.build_execution_context(),
|
||||
agent_strategy=self.config.strategy,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
# Initialize state variables
|
||||
current_agent_thought_id = None
|
||||
has_published_thought = False
|
||||
current_tool_name: str | None = None
|
||||
self._current_message_file_ids: list[str] = []
|
||||
|
||||
# organize prompt messages
|
||||
prompt_messages = self._organize_prompt_messages()
|
||||
|
||||
# Run strategy
|
||||
generator = strategy.run(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
stop=app_generate_entity.model_conf.stop,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Consume generator and collect result
|
||||
result: AgentResult | None = None
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
output = next(generator)
|
||||
except StopIteration as e:
|
||||
# Generator finished, get the return value
|
||||
result = e.value
|
||||
break
|
||||
|
||||
if isinstance(output, LLMResultChunk):
|
||||
# Handle LLM chunk
|
||||
if current_agent_thought_id and not has_published_thought:
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
has_published_thought = True
|
||||
|
||||
yield output
|
||||
|
||||
elif isinstance(output, AgentLog):
|
||||
# Handle Agent Log using log_type for type-safe dispatch
|
||||
if output.status == AgentLog.LogStatus.START:
|
||||
if output.log_type == AgentLog.LogType.ROUND:
|
||||
# Start of a new round
|
||||
message_file_ids: list[str] = []
|
||||
current_agent_thought_id = self.create_agent_thought(
|
||||
message_id=message.id,
|
||||
message="",
|
||||
tool_name="",
|
||||
tool_input="",
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
has_published_thought = False
|
||||
|
||||
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Tool call start - extract data from structured fields
|
||||
current_tool_name = output.data.get("tool_name", "")
|
||||
tool_input = output.data.get("tool_args", {})
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=current_tool_name,
|
||||
tool_input=tool_input,
|
||||
thought=None,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=None,
|
||||
messages_ids=[],
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.status == AgentLog.LogStatus.SUCCESS:
|
||||
if output.log_type == AgentLog.LogType.THOUGHT:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
thought_text = output.data.get("thought")
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=thought_text,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=None,
|
||||
messages_ids=[],
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Tool call finished
|
||||
tool_output = output.data.get("output")
|
||||
# Get meta from strategy output (now properly populated)
|
||||
tool_meta = output.data.get("meta")
|
||||
|
||||
# Wrap tool_meta with tool_name as key (required by agent_service)
|
||||
if tool_meta and current_tool_name:
|
||||
tool_meta = {current_tool_name: tool_meta}
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=None,
|
||||
observation=tool_output,
|
||||
tool_invoke_meta=tool_meta,
|
||||
answer=None,
|
||||
messages_ids=self._current_message_file_ids,
|
||||
)
|
||||
# Clear message file ids after saving
|
||||
self._current_message_file_ids = []
|
||||
current_tool_name = None
|
||||
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.log_type == AgentLog.LogType.ROUND:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Round finished - save LLM usage and answer
|
||||
llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE)
|
||||
llm_result = output.data.get("llm_result")
|
||||
final_answer = output.data.get("final_answer")
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=llm_result,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=final_answer,
|
||||
messages_ids=[],
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# Re-raise any other exceptions
|
||||
raise
|
||||
|
||||
# Process final result
|
||||
if isinstance(result, AgentResult):
|
||||
final_answer = result.text
|
||||
usage = result.usage or LLMUsage.empty_usage()
|
||||
|
||||
# Publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=self.model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=usage,
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
"""
|
||||
if not prompt_template:
|
||||
return prompt_messages or []
|
||||
|
||||
prompt_messages = prompt_messages or []
|
||||
|
||||
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
|
||||
prompt_messages[0] = SystemPromptMessage(content=prompt_template)
|
||||
return prompt_messages
|
||||
|
||||
if not prompt_messages:
|
||||
return [SystemPromptMessage(content=prompt_template)]
|
||||
|
||||
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
||||
return prompt_messages
|
||||
|
||||
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
# get image detail config
|
||||
image_detail_config = (
|
||||
self.application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
self.application_generate_entity.file_upload_config
|
||||
and self.application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in self.files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
As for now, gpt supports both fc and vision at the first iteration.
|
||||
We need to remove the image messages from the prompt messages at the first iteration.
|
||||
"""
|
||||
prompt_messages = deepcopy(prompt_messages)
|
||||
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message.content = "\n".join(
|
||||
[
|
||||
content.data
|
||||
if content.type == PromptMessageContentType.TEXT
|
||||
else "[image]"
|
||||
if content.type == PromptMessageContentType.IMAGE
|
||||
else "[file]"
|
||||
for content in prompt_message.content
|
||||
]
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_prompt_messages(self):
|
||||
# For ReAct strategy, use the agent prompt template
|
||||
if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt:
|
||||
prompt_template = self.config.prompt.first_prompt
|
||||
else:
|
||||
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
|
||||
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
|
||||
query_prompt_messages = self._organize_user_query(self.query or "", [])
|
||||
|
||||
self.history_prompt_messages = AgentHistoryPromptTransform(
|
||||
model_config=self.model_config,
|
||||
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
|
||||
history_messages=self.history_prompt_messages,
|
||||
memory=self.memory,
|
||||
).get_prompt()
|
||||
|
||||
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
|
||||
if len(self._current_thoughts) != 0:
|
||||
# clear messages after the first iteration
|
||||
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
|
||||
return prompt_messages
|
||||
@ -6,7 +6,7 @@ from typing import Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -116,20 +116,9 @@ class BaseAgentRunner(AppRunner):
|
||||
features = model_schema.features if model_schema and model_schema.features else []
|
||||
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
|
||||
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
|
||||
self.model_features = features
|
||||
self.query: str | None = ""
|
||||
self._current_thoughts: list[PromptMessage] = []
|
||||
|
||||
def build_execution_context(self) -> ExecutionContext:
|
||||
"""Build execution context."""
|
||||
return ExecutionContext(
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_config.app_id,
|
||||
conversation_id=self.conversation.id,
|
||||
message_id=self.message.id,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
|
||||
def _repack_app_generate_entity(
|
||||
self, app_generate_entity: AgentChatAppGenerateEntity
|
||||
) -> AgentChatAppGenerateEntity:
|
||||
|
||||
437
api/core/agent/cot_agent_runner.py
Normal file
437
api/core/agent/cot_agent_runner.py
Normal file
@ -0,0 +1,437 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.workflow.nodes.agent.exc import AgentMaxIterationError
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
_is_first_iteration = True
|
||||
_ignore_observation_providers = ["wenxin"]
|
||||
_historic_prompt_messages: list[PromptMessage]
|
||||
_agent_scratchpad: list[AgentScratchpadUnit]
|
||||
_instruction: str
|
||||
_query: str
|
||||
_prompt_messages_tools: Sequence[PromptMessageTool]
|
||||
|
||||
def run(
|
||||
self,
|
||||
message: Message,
|
||||
query: str,
|
||||
inputs: Mapping[str, str],
|
||||
) -> Generator:
|
||||
"""
|
||||
Run Cot agent application
|
||||
"""
|
||||
|
||||
app_generate_entity = self.application_generate_entity
|
||||
self._repack_app_generate_entity(app_generate_entity)
|
||||
self._init_react_state(query)
|
||||
|
||||
trace_manager = app_generate_entity.trace_manager
|
||||
|
||||
# check model mode
|
||||
if "Observation" not in app_generate_entity.model_conf.stop:
|
||||
if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
|
||||
app_generate_entity.model_conf.stop.append("Observation")
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config.agent
|
||||
|
||||
# init instruction
|
||||
inputs = inputs or {}
|
||||
instruction = app_config.prompt_template.simple_prompt_template or ""
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
self._prompt_messages_tools = prompt_messages_tools
|
||||
|
||||
function_call_state = True
|
||||
llm_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
final_answer = ""
|
||||
prompt_messages: list = [] # Initialize prompt_messages
|
||||
agent_thought_id = "" # Initialize agent_thought_id
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
llm_usage = final_llm_usage_dict["usage"]
|
||||
llm_usage.prompt_tokens += usage.prompt_tokens
|
||||
llm_usage.completion_tokens += usage.completion_tokens
|
||||
llm_usage.total_tokens += usage.total_tokens
|
||||
llm_usage.prompt_price += usage.prompt_price
|
||||
llm_usage.completion_price += usage.completion_price
|
||||
llm_usage.total_price += usage.total_price
|
||||
|
||||
model_instance = self.model_instance
|
||||
|
||||
while function_call_state and iteration_step <= max_iteration_steps:
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = False
|
||||
|
||||
if iteration_step == max_iteration_steps:
|
||||
# the last iteration, remove all tools
|
||||
self._prompt_messages_tools = []
|
||||
|
||||
message_file_ids: list[str] = []
|
||||
|
||||
agent_thought_id = self.create_agent_thought(
|
||||
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
|
||||
)
|
||||
|
||||
if iteration_step > 1:
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
# recalc llm max tokens
|
||||
prompt_messages = self._organize_prompt_messages()
|
||||
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
|
||||
# invoke model
|
||||
chunks = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
tools=[],
|
||||
stop=app_generate_entity.model_conf.stop,
|
||||
stream=True,
|
||||
user=self.user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
usage_dict: dict[str, LLMUsage | None] = {}
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
thought="",
|
||||
action_str="",
|
||||
observation="",
|
||||
action=None,
|
||||
)
|
||||
|
||||
# publish agent thought if it's first iteration
|
||||
if iteration_step == 1:
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
for chunk in react_chunks:
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
action = chunk
|
||||
# detect action
|
||||
assert scratchpad.agent_response is not None
|
||||
scratchpad.agent_response += json.dumps(chunk.model_dump())
|
||||
scratchpad.action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.action = action
|
||||
else:
|
||||
assert scratchpad.agent_response is not None
|
||||
scratchpad.agent_response += chunk
|
||||
assert scratchpad.thought is not None
|
||||
scratchpad.thought += chunk
|
||||
yield LLMResultChunk(
|
||||
model=self.model_config.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint="",
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
|
||||
)
|
||||
|
||||
assert scratchpad.thought is not None
|
||||
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
|
||||
self._agent_scratchpad.append(scratchpad)
|
||||
|
||||
# Check if max iteration is reached and model still wants to call tools
|
||||
if iteration_step == max_iteration_steps and scratchpad.action:
|
||||
if scratchpad.action.action_name.lower() != "final answer":
|
||||
raise AgentMaxIterationError(app_config.agent.max_iteration)
|
||||
|
||||
# get llm usage
|
||||
if "usage" in usage_dict:
|
||||
if usage_dict["usage"] is not None:
|
||||
increase_usage(llm_usage, usage_dict["usage"])
|
||||
else:
|
||||
usage_dict["usage"] = LLMUsage.empty_usage()
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=agent_thought_id,
|
||||
tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
|
||||
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
|
||||
tool_invoke_meta={},
|
||||
thought=scratchpad.thought or "",
|
||||
observation="",
|
||||
answer=scratchpad.agent_response or "",
|
||||
messages_ids=[],
|
||||
llm_usage=usage_dict["usage"],
|
||||
)
|
||||
|
||||
if not scratchpad.is_final():
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
if not scratchpad.action:
|
||||
# failed to extract action, return final answer directly
|
||||
final_answer = ""
|
||||
else:
|
||||
if scratchpad.action.action_name.lower() == "final answer":
|
||||
# action is final answer, return final answer directly
|
||||
try:
|
||||
if isinstance(scratchpad.action.action_input, dict):
|
||||
final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
|
||||
elif isinstance(scratchpad.action.action_input, str):
|
||||
final_answer = scratchpad.action.action_input
|
||||
else:
|
||||
final_answer = f"{scratchpad.action.action_input}"
|
||||
except TypeError:
|
||||
final_answer = f"{scratchpad.action.action_input}"
|
||||
else:
|
||||
function_call_state = True
|
||||
# action is tool call, invoke tool
|
||||
tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
|
||||
action=scratchpad.action,
|
||||
tool_instances=tool_instances,
|
||||
message_file_ids=message_file_ids,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
scratchpad.observation = tool_invoke_response
|
||||
scratchpad.agent_response = tool_invoke_response
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=agent_thought_id,
|
||||
tool_name=scratchpad.action.action_name,
|
||||
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
|
||||
thought=scratchpad.thought or "",
|
||||
observation={scratchpad.action.action_name: tool_invoke_response},
|
||||
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
|
||||
answer=scratchpad.agent_response,
|
||||
messages_ids=message_file_ids,
|
||||
llm_usage=usage_dict["usage"],
|
||||
)
|
||||
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
# update prompt tool message
|
||||
for prompt_tool in self._prompt_messages_tools:
|
||||
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
|
||||
),
|
||||
system_fingerprint="",
|
||||
)
|
||||
|
||||
# save agent thought
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=agent_thought_id,
|
||||
tool_name="",
|
||||
tool_input={},
|
||||
tool_invoke_meta={},
|
||||
thought=final_answer,
|
||||
observation={},
|
||||
answer=final_answer,
|
||||
messages_ids=[],
|
||||
)
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _handle_invoke_action(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
tool_instances: Mapping[str, Tool],
|
||||
message_file_ids: list[str],
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
) -> tuple[str, ToolInvokeMeta]:
|
||||
"""
|
||||
handle invoke action
|
||||
:param action: action
|
||||
:param tool_instances: tool instances
|
||||
:param message_file_ids: message file ids
|
||||
:param trace_manager: trace manager
|
||||
:return: observation, meta
|
||||
"""
|
||||
# action is tool call, invoke tool
|
||||
tool_call_name = action.action_name
|
||||
tool_call_args = action.action_input
|
||||
tool_instance = tool_instances.get(tool_call_name)
|
||||
|
||||
if not tool_instance:
|
||||
answer = f"there is not a tool named {tool_call_name}"
|
||||
return answer, ToolInvokeMeta.error_instance(answer)
|
||||
|
||||
if isinstance(tool_call_args, str):
|
||||
try:
|
||||
tool_call_args = json.loads(tool_call_args)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# invoke tool
|
||||
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_call_args,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
message=self.message,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
agent_tool_callback=self.agent_callback,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# publish files
|
||||
for message_file_id in message_files:
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file_id)
|
||||
|
||||
return tool_invoke_response, tool_invoke_meta
|
||||
|
||||
def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
|
||||
"""
|
||||
convert dict to action
|
||||
"""
|
||||
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
|
||||
|
||||
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
fill in inputs from external data tools
|
||||
"""
|
||||
for key, value in inputs.items():
|
||||
try:
|
||||
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return instruction
|
||||
|
||||
def _init_react_state(self, query):
|
||||
"""
|
||||
init agent scratchpad
|
||||
"""
|
||||
self._query = query
|
||||
self._agent_scratchpad = []
|
||||
self._historic_prompt_messages = self._organize_historic_prompt_messages()
|
||||
|
||||
@abstractmethod
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
organize prompt messages
|
||||
"""
|
||||
|
||||
def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
|
||||
"""
|
||||
format assistant message
|
||||
"""
|
||||
message = ""
|
||||
for scratchpad in agent_scratchpad:
|
||||
if scratchpad.is_final():
|
||||
message += f"Final Answer: {scratchpad.agent_response}"
|
||||
else:
|
||||
message += f"Thought: {scratchpad.thought}\n\n"
|
||||
if scratchpad.action_str:
|
||||
message += f"Action: {scratchpad.action_str}\n\n"
|
||||
if scratchpad.observation:
|
||||
message += f"Observation: {scratchpad.observation}\n\n"
|
||||
|
||||
return message
|
||||
|
||||
def _organize_historic_prompt_messages(
|
||||
self, current_session_messages: list[PromptMessage] | None = None
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
organize historic prompt messages
|
||||
"""
|
||||
result: list[PromptMessage] = []
|
||||
scratchpads: list[AgentScratchpadUnit] = []
|
||||
current_scratchpad: AgentScratchpadUnit | None = None
|
||||
|
||||
for message in self.history_prompt_messages:
|
||||
if isinstance(message, AssistantPromptMessage):
|
||||
if not current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content or "I am thinking about how to help you",
|
||||
action_str="",
|
||||
action=None,
|
||||
observation=None,
|
||||
)
|
||||
scratchpads.append(current_scratchpad)
|
||||
if message.tool_calls:
|
||||
try:
|
||||
current_scratchpad.action = AgentScratchpadUnit.Action(
|
||||
action_name=message.tool_calls[0].function.name,
|
||||
action_input=json.loads(message.tool_calls[0].function.arguments),
|
||||
)
|
||||
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
|
||||
except Exception:
|
||||
logger.exception("Failed to parse tool call from assistant message")
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
if current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad.observation = message.content
|
||||
else:
|
||||
raise NotImplementedError("expected str type")
|
||||
elif isinstance(message, UserPromptMessage):
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
scratchpads = []
|
||||
current_scratchpad = None
|
||||
|
||||
result.append(message)
|
||||
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
|
||||
historic_prompts = AgentHistoryPromptTransform(
|
||||
model_config=self.model_config,
|
||||
prompt_messages=current_session_messages or [],
|
||||
history_messages=result,
|
||||
memory=self.memory,
|
||||
).get_prompt()
|
||||
return historic_prompts
|
||||
118
api/core/agent/cot_chat_agent_runner.py
Normal file
118
api/core/agent/cot_chat_agent_runner.py
Normal file
@ -0,0 +1,118 @@
|
||||
import json
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.file import file_manager
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
|
||||
class CotChatAgentRunner(CotAgentRunner):
|
||||
def _organize_system_prompt(self) -> SystemPromptMessage:
|
||||
"""
|
||||
Organize system prompt
|
||||
"""
|
||||
assert self.app_config.agent
|
||||
assert self.app_config.agent.prompt
|
||||
|
||||
prompt_entity = self.app_config.agent.prompt
|
||||
if not prompt_entity:
|
||||
raise ValueError("Agent prompt configuration is not set")
|
||||
first_prompt = prompt_entity.first_prompt
|
||||
|
||||
system_prompt = (
|
||||
first_prompt.replace("{{instruction}}", self._instruction)
|
||||
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
|
||||
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
|
||||
)
|
||||
|
||||
return SystemPromptMessage(content=system_prompt)
|
||||
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
# get image detail config
|
||||
image_detail_config = (
|
||||
self.application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
self.application_generate_entity.file_upload_config
|
||||
and self.application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in self.files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize
|
||||
"""
|
||||
# organize system prompt
|
||||
system_message = self._organize_system_prompt()
|
||||
|
||||
# organize current assistant messages
|
||||
agent_scratchpad = self._agent_scratchpad
|
||||
if not agent_scratchpad:
|
||||
assistant_messages = []
|
||||
else:
|
||||
assistant_message = AssistantPromptMessage(content="")
|
||||
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
|
||||
for unit in agent_scratchpad:
|
||||
if unit.is_final():
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Final Answer: {unit.agent_response}"
|
||||
else:
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Thought: {unit.thought}\n\n"
|
||||
if unit.action_str:
|
||||
assistant_message.content += f"Action: {unit.action_str}\n\n"
|
||||
if unit.observation:
|
||||
assistant_message.content += f"Observation: {unit.observation}\n\n"
|
||||
|
||||
assistant_messages = [assistant_message]
|
||||
|
||||
# query messages
|
||||
query_messages = self._organize_user_query(self._query, [])
|
||||
|
||||
if assistant_messages:
|
||||
# organize historic prompt messages
|
||||
historic_messages = self._organize_historic_prompt_messages(
|
||||
[system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
|
||||
)
|
||||
messages = [
|
||||
system_message,
|
||||
*historic_messages,
|
||||
*query_messages,
|
||||
*assistant_messages,
|
||||
UserPromptMessage(content="continue"),
|
||||
]
|
||||
else:
|
||||
# organize historic prompt messages
|
||||
historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
|
||||
messages = [system_message, *historic_messages, *query_messages]
|
||||
|
||||
# join all messages
|
||||
return messages
|
||||
87
api/core/agent/cot_completion_agent_runner.py
Normal file
87
api/core/agent/cot_completion_agent_runner.py
Normal file
@ -0,0 +1,87 @@
|
||||
import json
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
|
||||
class CotCompletionAgentRunner(CotAgentRunner):
|
||||
def _organize_instruction_prompt(self) -> str:
|
||||
"""
|
||||
Organize instruction prompt
|
||||
"""
|
||||
if self.app_config.agent is None:
|
||||
raise ValueError("Agent configuration is not set")
|
||||
prompt_entity = self.app_config.agent.prompt
|
||||
if prompt_entity is None:
|
||||
raise ValueError("prompt entity is not set")
|
||||
first_prompt = prompt_entity.first_prompt
|
||||
|
||||
system_prompt = (
|
||||
first_prompt.replace("{{instruction}}", self._instruction)
|
||||
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
|
||||
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
|
||||
)
|
||||
|
||||
return system_prompt
|
||||
|
||||
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] | None = None) -> str:
|
||||
"""
|
||||
Organize historic prompt
|
||||
"""
|
||||
historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
|
||||
historic_prompt = ""
|
||||
|
||||
for message in historic_prompt_messages:
|
||||
if isinstance(message, UserPromptMessage):
|
||||
historic_prompt += f"Question: {message.content}\n\n"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
if isinstance(message.content, str):
|
||||
historic_prompt += message.content + "\n\n"
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if not isinstance(content, TextPromptMessageContent):
|
||||
continue
|
||||
historic_prompt += content.data
|
||||
|
||||
return historic_prompt
|
||||
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
"""
|
||||
# organize system prompt
|
||||
system_prompt = self._organize_instruction_prompt()
|
||||
|
||||
# organize historic prompt messages
|
||||
historic_prompt = self._organize_historic_prompt()
|
||||
|
||||
# organize current assistant messages
|
||||
agent_scratchpad = self._agent_scratchpad
|
||||
assistant_prompt = ""
|
||||
for unit in agent_scratchpad or []:
|
||||
if unit.is_final():
|
||||
assistant_prompt += f"Final Answer: {unit.agent_response}"
|
||||
else:
|
||||
assistant_prompt += f"Thought: {unit.thought}\n\n"
|
||||
if unit.action_str:
|
||||
assistant_prompt += f"Action: {unit.action_str}\n\n"
|
||||
if unit.observation:
|
||||
assistant_prompt += f"Observation: {unit.observation}\n\n"
|
||||
|
||||
# query messages
|
||||
query_prompt = f"Question: {self._query}"
|
||||
|
||||
# join all messages
|
||||
prompt = (
|
||||
system_prompt.replace("{{historic_messages}}", historic_prompt)
|
||||
.replace("{{agent_scratchpad}}", assistant_prompt)
|
||||
.replace("{{query}}", query_prompt)
|
||||
)
|
||||
|
||||
return [UserPromptMessage(content=prompt)]
|
||||
@ -1,5 +1,3 @@
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
@ -94,96 +92,3 @@ class AgentInvokeMessage(ToolInvokeMessage):
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ExecutionContext(BaseModel):
|
||||
"""Execution context containing trace and audit information.
|
||||
|
||||
This context carries all the IDs and metadata that are not part of
|
||||
the core business logic but needed for tracing, auditing, and
|
||||
correlation purposes.
|
||||
"""
|
||||
|
||||
user_id: str | None = None
|
||||
app_id: str | None = None
|
||||
conversation_id: str | None = None
|
||||
message_id: str | None = None
|
||||
tenant_id: str | None = None
|
||||
|
||||
@classmethod
|
||||
def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext":
|
||||
"""Create a minimal context with only essential fields."""
|
||||
return cls(user_id=user_id)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for passing to legacy code."""
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"app_id": self.app_id,
|
||||
"conversation_id": self.conversation_id,
|
||||
"message_id": self.message_id,
|
||||
"tenant_id": self.tenant_id,
|
||||
}
|
||||
|
||||
def with_updates(self, **kwargs) -> "ExecutionContext":
|
||||
"""Create a new context with updated fields."""
|
||||
data = self.to_dict()
|
||||
data.update(kwargs)
|
||||
|
||||
return ExecutionContext(
|
||||
user_id=data.get("user_id"),
|
||||
app_id=data.get("app_id"),
|
||||
conversation_id=data.get("conversation_id"),
|
||||
message_id=data.get("message_id"),
|
||||
tenant_id=data.get("tenant_id"),
|
||||
)
|
||||
|
||||
|
||||
class AgentLog(BaseModel):
|
||||
"""
|
||||
Agent Log.
|
||||
"""
|
||||
|
||||
class LogType(StrEnum):
|
||||
"""Type of agent log entry."""
|
||||
|
||||
ROUND = "round" # A complete iteration round
|
||||
THOUGHT = "thought" # LLM thinking/reasoning
|
||||
TOOL_CALL = "tool_call" # Tool invocation
|
||||
|
||||
class LogMetadata(StrEnum):
|
||||
STARTED_AT = "started_at"
|
||||
FINISHED_AT = "finished_at"
|
||||
ELAPSED_TIME = "elapsed_time"
|
||||
TOTAL_PRICE = "total_price"
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
PROVIDER = "provider"
|
||||
CURRENCY = "currency"
|
||||
LLM_USAGE = "llm_usage"
|
||||
ICON = "icon"
|
||||
ICON_DARK = "icon_dark"
|
||||
|
||||
class LogStatus(StrEnum):
|
||||
START = "start"
|
||||
ERROR = "error"
|
||||
SUCCESS = "success"
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="The id of the log")
|
||||
label: str = Field(..., description="The label of the log")
|
||||
log_type: LogType = Field(..., description="The type of the log")
|
||||
parent_id: str | None = Field(default=None, description="Leave empty for root log")
|
||||
error: str | None = Field(default=None, description="The error message")
|
||||
status: LogStatus = Field(..., description="The status of the log")
|
||||
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
||||
metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log")
|
||||
|
||||
|
||||
class AgentResult(BaseModel):
|
||||
"""
|
||||
Agent execution result.
|
||||
"""
|
||||
|
||||
text: str = Field(default="", description="The generated text")
|
||||
files: list[Any] = Field(default_factory=list, description="Files produced during execution")
|
||||
usage: Any | None = Field(default=None, description="LLM usage statistics")
|
||||
finish_reason: str | None = Field(default=None, description="Reason for completion")
|
||||
|
||||
@ -1,55 +0,0 @@
|
||||
# Agent Patterns
|
||||
|
||||
A unified agent pattern module that powers both Agent V2 workflow nodes and agent applications. Strategies share a common execution contract while adapting to model capabilities and tool availability.
|
||||
|
||||
## Overview
|
||||
|
||||
The module applies a strategy pattern around LLM/tool orchestration. `StrategyFactory` auto-selects the best implementation based on model features or an explicit agent strategy, and each strategy streams logs and usage consistently.
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Dual strategies**
|
||||
- `FunctionCallStrategy`: uses native LLM function/tool calling when the model exposes `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL`.
|
||||
- `ReActStrategy`: ReAct (reasoning + acting) flow driven by `CotAgentOutputParser`, used when function calling is unavailable or explicitly requested.
|
||||
- **Explicit or auto selection**
|
||||
- `StrategyFactory.create_strategy` prefers an explicit `AgentEntity.Strategy` (FUNCTION_CALLING or CHAIN_OF_THOUGHT).
|
||||
- Otherwise it falls back to function calling when tool-call features exist, or ReAct when they do not.
|
||||
- **Unified execution contract**
|
||||
- `AgentPattern.run` yields streaming `AgentLog` entries and `LLMResultChunk` data, returning an `AgentResult` with text, files, usage, and `finish_reason`.
|
||||
- Iterations are configurable and hard-capped at 99 rounds; the last round forces a final answer by withholding tools.
|
||||
- **Tool handling and hooks**
|
||||
- Tools convert to `PromptMessageTool` objects before invocation.
|
||||
- Optional `tool_invoke_hook` lets callers override tool execution (e.g., agent apps) while workflow runs use `ToolEngine.generic_invoke`.
|
||||
- Tool outputs support text, links, JSON, variables, blobs, retriever resources, and file attachments; `target=="self"` files are reloaded into model context, others are returned as outputs.
|
||||
- **File-aware arguments**
|
||||
- Tool args accept `[File: <id>]` or `[Files: <id1, id2>]` placeholders that resolve to `File` objects before invocation, enabling models to reference uploaded files safely.
|
||||
- **ReAct prompt shaping**
|
||||
- System prompts replace `{{instruction}}`, `{{tools}}`, and `{{tool_names}}` placeholders.
|
||||
- Adds `Observation` to stop sequences and appends scratchpad text so the model sees prior Thought/Action/Observation history.
|
||||
- **Observability and accounting**
|
||||
- Standardized `AgentLog` entries for rounds, model thoughts, and tool calls, including usage aggregation (`LLMUsage`) across streaming and non-streaming paths.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
agent/patterns/
|
||||
├── base.py # Shared utilities: logging, usage, tool invocation, file handling
|
||||
├── function_call.py # Native function-calling loop with tool execution
|
||||
├── react.py # ReAct loop with CoT parsing and scratchpad wiring
|
||||
└── strategy_factory.py # Strategy selection by model features or explicit override
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
- For auto-selection:
|
||||
- Call `StrategyFactory.create_strategy(model_features, model_instance, context, tools, files, ...)` and run the returned strategy with prompt messages and model params.
|
||||
- For explicit behavior:
|
||||
- Pass `agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING` to force native calls (falls back to ReAct if unsupported), or `CHAIN_OF_THOUGHT` to force ReAct.
|
||||
- Both strategies stream chunks and logs; collect the generator output until it returns an `AgentResult`.
|
||||
|
||||
## Integration Points
|
||||
|
||||
- **Model runtime**: delegates to `ModelInstance.invoke_llm` for both streaming and non-streaming calls.
|
||||
- **Tool system**: defaults to `ToolEngine.generic_invoke`, with `tool_invoke_hook` for custom callers.
|
||||
- **Files**: flows through `File` objects for tool inputs/outputs and model-context attachments.
|
||||
- **Execution context**: `ExecutionContext` fields (user/app/conversation/message) propagate to tool invocations and logging.
|
||||
@ -1,19 +0,0 @@
|
||||
"""Agent patterns module.
|
||||
|
||||
This module provides different strategies for agent execution:
|
||||
- FunctionCallStrategy: Uses native function/tool calling
|
||||
- ReActStrategy: Uses ReAct (Reasoning + Acting) approach
|
||||
- StrategyFactory: Factory for creating strategies based on model features
|
||||
"""
|
||||
|
||||
from .base import AgentPattern
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
from .strategy_factory import StrategyFactory
|
||||
|
||||
__all__ = [
|
||||
"AgentPattern",
|
||||
"FunctionCallStrategy",
|
||||
"ReActStrategy",
|
||||
"StrategyFactory",
|
||||
]
|
||||
@ -1,474 +0,0 @@
|
||||
"""Base class for agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
|
||||
from core.file import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
# Type alias for tool invoke hook
|
||||
# Returns: (response_content, message_file_ids, tool_invoke_meta)
|
||||
ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]]
|
||||
|
||||
|
||||
class AgentPattern(ABC):
|
||||
"""Base class for agent execution strategies."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] = [],
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
):
|
||||
"""Initialize the agent strategy."""
|
||||
self.model_instance = model_instance
|
||||
self.tools = tools
|
||||
self.context = context
|
||||
self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.files: list[File] = files
|
||||
self.tool_invoke_hook = tool_invoke_hook
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the agent strategy."""
|
||||
pass
|
||||
|
||||
def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None:
|
||||
"""Accumulate LLM usage statistics."""
|
||||
if not total_usage.get("usage"):
|
||||
# Create a copy to avoid modifying the original
|
||||
total_usage["usage"] = LLMUsage(
|
||||
prompt_tokens=delta_usage.prompt_tokens,
|
||||
prompt_unit_price=delta_usage.prompt_unit_price,
|
||||
prompt_price_unit=delta_usage.prompt_price_unit,
|
||||
prompt_price=delta_usage.prompt_price,
|
||||
completion_tokens=delta_usage.completion_tokens,
|
||||
completion_unit_price=delta_usage.completion_unit_price,
|
||||
completion_price_unit=delta_usage.completion_price_unit,
|
||||
completion_price=delta_usage.completion_price,
|
||||
total_tokens=delta_usage.total_tokens,
|
||||
total_price=delta_usage.total_price,
|
||||
currency=delta_usage.currency,
|
||||
latency=delta_usage.latency,
|
||||
)
|
||||
else:
|
||||
current: LLMUsage = total_usage["usage"]
|
||||
current.prompt_tokens += delta_usage.prompt_tokens
|
||||
current.completion_tokens += delta_usage.completion_tokens
|
||||
current.total_tokens += delta_usage.total_tokens
|
||||
current.prompt_price += delta_usage.prompt_price
|
||||
current.completion_price += delta_usage.completion_price
|
||||
current.total_price += delta_usage.total_price
|
||||
|
||||
def _extract_content(self, content: Any) -> str:
|
||||
"""Extract text content from message content."""
|
||||
if isinstance(content, list):
|
||||
# Content items are PromptMessageContentUnionTypes
|
||||
text_parts = []
|
||||
for c in content:
|
||||
# Check if it's a TextPromptMessageContent (which has data attribute)
|
||||
if isinstance(c, TextPromptMessageContent):
|
||||
text_parts.append(c.data)
|
||||
return "".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
def _has_tool_calls(self, chunk: LLMResultChunk) -> bool:
|
||||
"""Check if chunk contains tool calls."""
|
||||
# LLMResultChunk always has delta attribute
|
||||
return bool(chunk.delta.message and chunk.delta.message.tool_calls)
|
||||
|
||||
def _has_tool_calls_result(self, result: LLMResult) -> bool:
|
||||
"""Check if result contains tool calls (non-streaming)."""
|
||||
# LLMResult always has message attribute
|
||||
return bool(result.message and result.message.tool_calls)
|
||||
|
||||
def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from streaming chunk."""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
if chunk.delta.message and chunk.delta.message.tool_calls:
|
||||
for tool_call in chunk.delta.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from non-streaming result."""
|
||||
tool_calls = []
|
||||
if result.message and result.message.tool_calls:
|
||||
for tool_call in result.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_text_from_message(self, message: PromptMessage) -> str:
|
||||
"""Extract text content from a prompt message."""
|
||||
# PromptMessage always has content attribute
|
||||
content = message.content
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
# Extract text from content list
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
text_parts.append(item.data)
|
||||
return " ".join(text_parts)
|
||||
return ""
|
||||
|
||||
def _get_tool_metadata(self, tool_instance: Tool) -> dict[AgentLog.LogMetadata, Any]:
|
||||
"""Get metadata for a tool including provider and icon info."""
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
metadata: dict[AgentLog.LogMetadata, Any] = {}
|
||||
if tool_instance.entity and tool_instance.entity.identity:
|
||||
identity = tool_instance.entity.identity
|
||||
if identity.provider:
|
||||
metadata[AgentLog.LogMetadata.PROVIDER] = identity.provider
|
||||
|
||||
# Get icon using ToolManager for proper URL generation
|
||||
tenant_id = self.context.tenant_id
|
||||
if tenant_id and identity.provider:
|
||||
try:
|
||||
provider_type = tool_instance.tool_provider_type()
|
||||
icon = ToolManager.get_tool_icon(tenant_id, provider_type, identity.provider)
|
||||
if isinstance(icon, str):
|
||||
metadata[AgentLog.LogMetadata.ICON] = icon
|
||||
elif isinstance(icon, dict):
|
||||
# Handle icon dict with background/content or light/dark variants
|
||||
metadata[AgentLog.LogMetadata.ICON] = icon
|
||||
except Exception:
|
||||
# Fallback to identity.icon if ToolManager fails
|
||||
if identity.icon:
|
||||
metadata[AgentLog.LogMetadata.ICON] = identity.icon
|
||||
elif identity.icon:
|
||||
metadata[AgentLog.LogMetadata.ICON] = identity.icon
|
||||
return metadata
|
||||
|
||||
def _create_log(
|
||||
self,
|
||||
label: str,
|
||||
log_type: AgentLog.LogType,
|
||||
status: AgentLog.LogStatus,
|
||||
data: dict[str, Any] | None = None,
|
||||
parent_id: str | None = None,
|
||||
extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None,
|
||||
) -> AgentLog:
|
||||
"""Create a new AgentLog with standard metadata."""
|
||||
metadata: dict[AgentLog.LogMetadata, Any] = {
|
||||
AgentLog.LogMetadata.STARTED_AT: time.perf_counter(),
|
||||
}
|
||||
if extra_metadata:
|
||||
metadata.update(extra_metadata)
|
||||
|
||||
return AgentLog(
|
||||
label=label,
|
||||
log_type=log_type,
|
||||
status=status,
|
||||
data=data or {},
|
||||
parent_id=parent_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _finish_log(
|
||||
self,
|
||||
log: AgentLog,
|
||||
data: dict[str, Any] | None = None,
|
||||
usage: LLMUsage | None = None,
|
||||
) -> AgentLog:
|
||||
"""Finish an AgentLog by updating its status and metadata."""
|
||||
log.status = AgentLog.LogStatus.SUCCESS
|
||||
|
||||
if data is not None:
|
||||
log.data = data
|
||||
|
||||
# Calculate elapsed time
|
||||
started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter())
|
||||
finished_at = time.perf_counter()
|
||||
|
||||
# Update metadata
|
||||
log.metadata = {
|
||||
**log.metadata,
|
||||
AgentLog.LogMetadata.FINISHED_AT: finished_at,
|
||||
# Calculate elapsed time in seconds
|
||||
AgentLog.LogMetadata.ELAPSED_TIME: round(finished_at - started_at, 4),
|
||||
}
|
||||
|
||||
# Add usage information if provided
|
||||
if usage:
|
||||
log.metadata.update(
|
||||
{
|
||||
AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price,
|
||||
AgentLog.LogMetadata.CURRENCY: usage.currency,
|
||||
AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens,
|
||||
AgentLog.LogMetadata.LLM_USAGE: usage,
|
||||
}
|
||||
)
|
||||
|
||||
return log
|
||||
|
||||
def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Replace file references in tool arguments with actual File objects.
|
||||
|
||||
Args:
|
||||
tool_args: Dictionary of tool arguments
|
||||
|
||||
Returns:
|
||||
Updated tool arguments with file references replaced
|
||||
"""
|
||||
# Process each argument in the dictionary
|
||||
processed_args: dict[str, Any] = {}
|
||||
for key, value in tool_args.items():
|
||||
processed_args[key] = self._process_file_reference(value)
|
||||
return processed_args
|
||||
|
||||
def _process_file_reference(self, data: Any) -> Any:
|
||||
"""
|
||||
Recursively process data to replace file references.
|
||||
Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...].
|
||||
|
||||
Args:
|
||||
data: The data to process (can be dict, list, str, or other types)
|
||||
|
||||
Returns:
|
||||
Processed data with file references replaced
|
||||
"""
|
||||
single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$")
|
||||
multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$")
|
||||
|
||||
if isinstance(data, dict):
|
||||
# Process dictionary recursively
|
||||
return {key: self._process_file_reference(value) for key, value in data.items()}
|
||||
elif isinstance(data, list):
|
||||
# Process list recursively
|
||||
return [self._process_file_reference(item) for item in data]
|
||||
elif isinstance(data, str):
|
||||
# Check for single file pattern [File: file_id]
|
||||
single_match = single_file_pattern.match(data.strip())
|
||||
if single_match:
|
||||
file_id = single_match.group(1).strip()
|
||||
# Find the file in self.files
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
return file
|
||||
# If file not found, return original value
|
||||
return data
|
||||
|
||||
# Check for multiple files pattern [Files: file_id1, file_id2, ...]
|
||||
multiple_match = multiple_files_pattern.match(data.strip())
|
||||
if multiple_match:
|
||||
file_ids_str = multiple_match.group(1).strip()
|
||||
# Split by comma and strip whitespace
|
||||
file_ids = [fid.strip() for fid in file_ids_str.split(",")]
|
||||
|
||||
# Find all matching files
|
||||
matched_files: list[File] = []
|
||||
for file_id in file_ids:
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
matched_files.append(file)
|
||||
break
|
||||
|
||||
# Return list of files if any were found, otherwise return original
|
||||
return matched_files or data
|
||||
|
||||
return data
|
||||
else:
|
||||
# Return other types as-is
|
||||
return data
|
||||
|
||||
def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk:
|
||||
"""Create a text chunk for streaming."""
|
||||
return LLMResultChunk(
|
||||
model=self.model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
usage=None,
|
||||
),
|
||||
system_fingerprint="",
|
||||
)
|
||||
|
||||
def _invoke_tool(
|
||||
self,
|
||||
tool_instance: Tool,
|
||||
tool_args: dict[str, Any],
|
||||
tool_name: str,
|
||||
) -> tuple[str, list[File], ToolInvokeMeta | None]:
|
||||
"""
|
||||
Invoke a tool and collect its response.
|
||||
|
||||
Args:
|
||||
tool_instance: The tool instance to invoke
|
||||
tool_args: Tool arguments
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
Tuple of (response_content, tool_files, tool_invoke_meta)
|
||||
"""
|
||||
# Process tool_args to replace file references with actual File objects
|
||||
tool_args = self._replace_file_references(tool_args)
|
||||
|
||||
# If a tool invoke hook is set, use it instead of generic_invoke
|
||||
if self.tool_invoke_hook:
|
||||
response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name)
|
||||
# Note: message_file_ids are stored in DB, we don't convert them to File objects here
|
||||
# The caller (AgentAppRunner) handles file publishing
|
||||
return response_content, [], tool_invoke_meta
|
||||
|
||||
# Default: use generic_invoke for workflow scenarios
|
||||
# Import here to avoid circular import
|
||||
from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine
|
||||
|
||||
tool_response = ToolEngine().generic_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_args,
|
||||
user_id=self.context.user_id or "",
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
app_id=self.context.app_id,
|
||||
conversation_id=self.context.conversation_id,
|
||||
message_id=self.context.message_id,
|
||||
)
|
||||
|
||||
# Collect response and files
|
||||
response_content = ""
|
||||
tool_files: list[File] = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
|
||||
response_content += response.message.text
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# Handle link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Link: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
# Handle image URL messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK:
|
||||
# Handle image link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK:
|
||||
# Handle binary file link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
filename = response.meta.get("filename", "file") if response.meta else "file"
|
||||
response_content += f"[File: {filename} - {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
# Handle JSON messages
|
||||
if isinstance(response.message, ToolInvokeMessage.JsonMessage):
|
||||
response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2)
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# Handle blob messages - convert to text representation
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobMessage):
|
||||
mime_type = (
|
||||
response.meta.get("mime_type", "application/octet-stream")
|
||||
if response.meta
|
||||
else "application/octet-stream"
|
||||
)
|
||||
size = len(response.message.blob)
|
||||
response_content += f"[Binary data: {mime_type}, size: {size} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
# Handle variable messages
|
||||
if isinstance(response.message, ToolInvokeMessage.VariableMessage):
|
||||
var_name = response.message.variable_name
|
||||
var_value = response.message.variable_value
|
||||
if isinstance(var_value, str):
|
||||
response_content += var_value
|
||||
else:
|
||||
response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
|
||||
# Handle blob chunk messages - these are parts of a larger blob
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage):
|
||||
response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
|
||||
# Handle retriever resources messages
|
||||
if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage):
|
||||
response_content += response.message.context
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.FILE:
|
||||
# Extract file from meta
|
||||
if response.meta and "file" in response.meta:
|
||||
file = response.meta["file"]
|
||||
if isinstance(file, File):
|
||||
# Check if file is for model or tool output
|
||||
if response.meta.get("target") == "self":
|
||||
# File is for model - add to files for next prompt
|
||||
self.files.append(file)
|
||||
response_content += f"File '{file.filename}' has been loaded into your context."
|
||||
else:
|
||||
# File is tool output
|
||||
tool_files.append(file)
|
||||
|
||||
return response_content, tool_files, None
|
||||
|
||||
def _find_tool_by_name(self, tool_name: str) -> Tool | None:
|
||||
"""Find a tool instance by its name."""
|
||||
for tool in self.tools:
|
||||
if tool.entity.identity.name == tool_name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]:
|
||||
"""Convert tools to prompt message format."""
|
||||
prompt_tools: list[PromptMessageTool] = []
|
||||
for tool in self.tools:
|
||||
prompt_tools.append(tool.to_prompt_message_tool())
|
||||
return prompt_tools
|
||||
|
||||
def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None:
|
||||
"""Initialize usage tracking with empty usage if not set."""
|
||||
if "usage" not in llm_usage or llm_usage["usage"] is None:
|
||||
llm_usage["usage"] = LLMUsage.empty_usage()
|
||||
@ -1,299 +0,0 @@
|
||||
"""Function Call strategy implementation."""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
from .base import AgentPattern
|
||||
|
||||
|
||||
class FunctionCallStrategy(AgentPattern):
|
||||
"""Function Call strategy using model's native tool calling capability."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the function call agent strategy."""
|
||||
# Convert tools to prompt format
|
||||
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
|
||||
|
||||
# Initialize tracking
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
function_call_state: bool = True
|
||||
total_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
|
||||
while function_call_state and iteration_step <= max_iterations:
|
||||
function_call_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
# On last iteration, remove tools to force final answer
|
||||
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=current_tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=self.context.user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log
|
||||
)
|
||||
messages.append(self._create_assistant_message(response_content, tool_calls))
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update final text if no tool calls (this is likely the final answer)
|
||||
if not tool_calls:
|
||||
final_text = response_content
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Process tool calls
|
||||
tool_outputs: dict[str, str] = {}
|
||||
if tool_calls:
|
||||
function_call_state = True
|
||||
# Execute tools
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
tool_response, tool_files, _ = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"llm_result": response_content,
|
||||
"tool_calls": [
|
||||
{"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls
|
||||
]
|
||||
if tool_calls
|
||||
else [],
|
||||
"final_answer": final_text if not function_call_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
return AgentResult(
|
||||
text=final_text,
|
||||
files=output_files,
|
||||
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, LLMUsage | None],
|
||||
start_log: AgentLog,
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[list[tuple[str, str, dict[str, Any]]], str, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract tool calls and content.
|
||||
|
||||
Returns a tuple of (tool_calls, response_content, finish_reason).
|
||||
"""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
response_content: str = ""
|
||||
finish_reason: str | None = None
|
||||
if isinstance(chunks, Generator):
|
||||
# Streaming response
|
||||
for chunk in chunks:
|
||||
# Extract tool calls
|
||||
if self._has_tool_calls(chunk):
|
||||
tool_calls.extend(self._extract_tool_calls(chunk))
|
||||
|
||||
# Extract content
|
||||
if chunk.delta.message and chunk.delta.message.content:
|
||||
response_content += self._extract_content(chunk.delta.message.content)
|
||||
|
||||
# Track usage
|
||||
if chunk.delta.usage:
|
||||
self._accumulate_usage(llm_usage, chunk.delta.usage)
|
||||
|
||||
# Capture finish reason
|
||||
if chunk.delta.finish_reason:
|
||||
finish_reason = chunk.delta.finish_reason
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
# Non-streaming response
|
||||
result: LLMResult = chunks
|
||||
|
||||
if self._has_tool_calls_result(result):
|
||||
tool_calls.extend(self._extract_tool_calls_result(result))
|
||||
|
||||
if result.message and result.message.content:
|
||||
response_content += self._extract_content(result.message.content)
|
||||
|
||||
if result.usage:
|
||||
self._accumulate_usage(llm_usage, result.usage)
|
||||
|
||||
# Convert to streaming format
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
||||
)
|
||||
yield self._finish_log(
|
||||
start_log,
|
||||
data={
|
||||
"result": response_content,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
return tool_calls, response_content, finish_reason
|
||||
|
||||
def _create_assistant_message(
|
||||
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
|
||||
) -> AssistantPromptMessage:
|
||||
"""Create assistant message with tool calls."""
|
||||
if tool_calls is None:
|
||||
return AssistantPromptMessage(content=content)
|
||||
return AssistantPromptMessage(
|
||||
content=content or "",
|
||||
tool_calls=[
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tc[0],
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])),
|
||||
)
|
||||
for tc in tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any],
|
||||
tool_call_id: str,
|
||||
messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None]]:
|
||||
"""Handle a single tool call and return response with files and meta."""
|
||||
# Find tool
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
if not tool_instance:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
|
||||
# Get tool metadata (provider, icon, etc.)
|
||||
tool_metadata = self._get_tool_metadata(tool_instance)
|
||||
|
||||
# Create tool call log
|
||||
tool_call_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata=tool_metadata,
|
||||
)
|
||||
yield tool_call_log
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
|
||||
|
||||
yield self._finish_log(
|
||||
tool_call_log,
|
||||
data={
|
||||
**tool_call_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
final_content = response_content or "Tool executed successfully"
|
||||
# Add tool response to messages
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=final_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return response_content, tool_files, tool_invoke_meta
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_call_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_call_log.error = error_message
|
||||
tool_call_log.data = {
|
||||
**tool_call_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_call_log
|
||||
|
||||
# Add error message to conversation
|
||||
error_content = f"Tool execution failed: {error_message}"
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=error_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return error_content, [], None
|
||||
@ -1,418 +0,0 @@
|
||||
"""ReAct strategy implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.file import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
)
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class ReActStrategy(AgentPattern):
|
||||
"""ReAct strategy using reasoning and acting approach."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] = [],
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
):
|
||||
"""Initialize the ReAct strategy with instruction support."""
|
||||
super().__init__(
|
||||
model_instance=model_instance,
|
||||
tools=tools,
|
||||
context=context,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
files=files,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
self.instruction = instruction
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the ReAct agent strategy."""
|
||||
# Initialize tracking
|
||||
agent_scratchpad: list[AgentScratchpadUnit] = []
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
react_state: bool = True
|
||||
total_usage: dict[str, Any] = {"usage": None}
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Add "Observation" to stop sequences
|
||||
if "Observation" not in stop:
|
||||
stop = stop.copy()
|
||||
stop.append("Observation")
|
||||
|
||||
while react_state and iteration_step <= max_iterations:
|
||||
react_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
|
||||
# Build prompt with/without tools based on iteration
|
||||
include_tools = iteration_step < max_iterations
|
||||
current_messages = self._build_prompt_with_react_format(
|
||||
prompt_messages, agent_scratchpad, include_tools, self.instruction
|
||||
)
|
||||
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, Any] = {"usage": None}
|
||||
|
||||
# Use current messages directly (files are handled by base class if needed)
|
||||
messages_to_use = current_messages
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=self.context.user_id or "",
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log, current_messages
|
||||
)
|
||||
agent_scratchpad.append(scratchpad)
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Check if we have an action to execute
|
||||
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
||||
react_state = True
|
||||
# Execute tool
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
# Add observation to scratchpad for display
|
||||
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
|
||||
else:
|
||||
# Extract final answer
|
||||
if scratchpad.action and scratchpad.action.action_input:
|
||||
final_answer = scratchpad.action.action_input
|
||||
if isinstance(final_answer, dict):
|
||||
final_answer = json.dumps(final_answer, ensure_ascii=False)
|
||||
final_text = str(final_answer)
|
||||
elif scratchpad.thought:
|
||||
# If no action but we have thought, use thought as final answer
|
||||
final_text = scratchpad.thought
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
"observation": scratchpad.observation or None,
|
||||
"final_answer": final_text if not react_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
return AgentResult(
|
||||
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
|
||||
)
|
||||
|
||||
def _build_prompt_with_react_format(
|
||||
self,
|
||||
original_messages: list[PromptMessage],
|
||||
agent_scratchpad: list[AgentScratchpadUnit],
|
||||
include_tools: bool = True,
|
||||
instruction: str = "",
|
||||
) -> list[PromptMessage]:
|
||||
"""Build prompt messages with ReAct format."""
|
||||
# Copy messages to avoid modifying original
|
||||
messages = list(original_messages)
|
||||
|
||||
# Find and update the system prompt that should already exist
|
||||
system_prompt_found = False
|
||||
for i, msg in enumerate(messages):
|
||||
if isinstance(msg, SystemPromptMessage):
|
||||
system_prompt_found = True
|
||||
# The system prompt from frontend already has the template, just replace placeholders
|
||||
|
||||
# Format tools
|
||||
tools_str = ""
|
||||
tool_names = []
|
||||
if include_tools and self.tools:
|
||||
# Convert tools to prompt message tools format
|
||||
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
|
||||
tool_names = [tool.name for tool in prompt_tools]
|
||||
|
||||
# Format tools as JSON for comprehensive information
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2)
|
||||
tool_names_str = ", ".join(f'"{name}"' for name in tool_names)
|
||||
else:
|
||||
tools_str = "No tools available"
|
||||
tool_names_str = ""
|
||||
|
||||
# Replace placeholders in the existing system prompt
|
||||
updated_content = msg.content
|
||||
assert isinstance(updated_content, str)
|
||||
updated_content = updated_content.replace("{{instruction}}", instruction)
|
||||
updated_content = updated_content.replace("{{tools}}", tools_str)
|
||||
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
|
||||
|
||||
# Create new SystemPromptMessage with updated content
|
||||
messages[i] = SystemPromptMessage(content=updated_content)
|
||||
break
|
||||
|
||||
# If no system prompt found, that's unexpected but add scratchpad anyway
|
||||
if not system_prompt_found:
|
||||
# This shouldn't happen if frontend is working correctly
|
||||
pass
|
||||
|
||||
# Format agent scratchpad
|
||||
scratchpad_str = ""
|
||||
if agent_scratchpad:
|
||||
scratchpad_parts: list[str] = []
|
||||
for unit in agent_scratchpad:
|
||||
if unit.thought:
|
||||
scratchpad_parts.append(f"Thought: {unit.thought}")
|
||||
if unit.action_str:
|
||||
scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```")
|
||||
if unit.observation:
|
||||
scratchpad_parts.append(f"Observation: {unit.observation}")
|
||||
scratchpad_str = "\n".join(scratchpad_parts)
|
||||
|
||||
# If there's a scratchpad, append it to the last message
|
||||
if scratchpad_str:
|
||||
messages.append(AssistantPromptMessage(content=scratchpad_str))
|
||||
|
||||
return messages
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, Any],
|
||||
model_log: AgentLog,
|
||||
current_messages: list[PromptMessage],
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[AgentScratchpadUnit, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract action/thought.
|
||||
|
||||
Returns a tuple of (scratchpad_unit, finish_reason).
|
||||
"""
|
||||
usage_dict: dict[str, Any] = {}
|
||||
|
||||
# Convert non-streaming to streaming format if needed
|
||||
if isinstance(chunks, LLMResult):
|
||||
# Create a generator from the LLMResult
|
||||
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
model=chunks.model,
|
||||
prompt_messages=chunks.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=chunks.message,
|
||||
usage=chunks.usage,
|
||||
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
|
||||
),
|
||||
system_fingerprint=chunks.system_fingerprint or "",
|
||||
)
|
||||
|
||||
streaming_chunks = result_to_chunks()
|
||||
else:
|
||||
streaming_chunks = chunks
|
||||
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
|
||||
|
||||
# Initialize scratchpad unit
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
thought="",
|
||||
action_str="",
|
||||
observation="",
|
||||
action=None,
|
||||
)
|
||||
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Process chunks
|
||||
for chunk in react_chunks:
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
# Action detected
|
||||
action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + action_str
|
||||
scratchpad.action_str = action_str
|
||||
scratchpad.action = chunk
|
||||
|
||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||
else:
|
||||
# Text chunk
|
||||
chunk_text = str(chunk)
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
|
||||
scratchpad.thought = (scratchpad.thought or "") + chunk_text
|
||||
|
||||
yield self._create_text_chunk(chunk_text, current_messages)
|
||||
|
||||
# Update usage
|
||||
if usage_dict.get("usage"):
|
||||
if llm_usage.get("usage"):
|
||||
self._accumulate_usage(llm_usage, usage_dict["usage"])
|
||||
else:
|
||||
llm_usage["usage"] = usage_dict["usage"]
|
||||
|
||||
# Clean up thought
|
||||
scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you"
|
||||
|
||||
# Finish model log
|
||||
yield self._finish_log(
|
||||
model_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
|
||||
return scratchpad, finish_reason
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
prompt_messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File]]]:
|
||||
"""Handle tool call and return observation with files."""
|
||||
tool_name = action.action_name
|
||||
tool_args: dict[str, Any] | str = action.action_input
|
||||
|
||||
# Find tool instance first to get metadata
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
tool_metadata = self._get_tool_metadata(tool_instance) if tool_instance else {}
|
||||
|
||||
# Start tool log with tool metadata
|
||||
tool_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata=tool_metadata,
|
||||
)
|
||||
yield tool_log
|
||||
|
||||
if not tool_instance:
|
||||
# Finish tool log with error
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"error": f"Tool {tool_name} not found",
|
||||
},
|
||||
)
|
||||
return f"Tool {tool_name} not found", []
|
||||
|
||||
# Ensure tool_args is a dict
|
||||
tool_args_dict: dict[str, Any]
|
||||
if isinstance(tool_args, str):
|
||||
try:
|
||||
tool_args_dict = json.loads(tool_args)
|
||||
except json.JSONDecodeError:
|
||||
tool_args_dict = {"input": tool_args}
|
||||
elif not isinstance(tool_args, dict):
|
||||
tool_args_dict = {"input": str(tool_args)}
|
||||
else:
|
||||
tool_args_dict = tool_args
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
|
||||
|
||||
# Finish tool log
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
|
||||
return response_content or "Tool executed successfully", tool_files
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_log.error = error_message
|
||||
tool_log.data = {
|
||||
**tool_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_log
|
||||
|
||||
return f"Tool execution failed: {error_message}", []
|
||||
@ -1,107 +0,0 @@
|
||||
"""Strategy factory for creating agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.agent.entities import AgentEntity, ExecutionContext
|
||||
from core.file.models import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class StrategyFactory:
|
||||
"""Factory for creating agent strategies based on model features."""
|
||||
|
||||
# Tool calling related features
|
||||
TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
|
||||
|
||||
@staticmethod
|
||||
def create_strategy(
|
||||
model_features: list[ModelFeature],
|
||||
model_instance: ModelInstance,
|
||||
context: ExecutionContext,
|
||||
tools: list[Tool],
|
||||
files: list[File],
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
agent_strategy: AgentEntity.Strategy | None = None,
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
) -> AgentPattern:
|
||||
"""
|
||||
Create an appropriate strategy based on model features.
|
||||
|
||||
Args:
|
||||
model_features: List of model features/capabilities
|
||||
model_instance: Model instance to use
|
||||
context: Execution context containing trace/audit information
|
||||
tools: Available tools
|
||||
files: Available files
|
||||
max_iterations: Maximum iterations for the strategy
|
||||
workflow_call_depth: Depth of workflow calls
|
||||
agent_strategy: Optional explicit strategy override
|
||||
tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke)
|
||||
instruction: Optional instruction for ReAct strategy
|
||||
|
||||
Returns:
|
||||
AgentStrategy instance
|
||||
"""
|
||||
# If explicit strategy is provided and it's Function Calling, try to use it if supported
|
||||
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
# Fallback to ReAct if FC is requested but not supported
|
||||
|
||||
# If explicit strategy is Chain of Thought (ReAct)
|
||||
if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
# Default auto-selection logic
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
# Model supports native function calling
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
else:
|
||||
# Use ReAct strategy for models without function calling
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
@ -24,7 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari
|
||||
from core.db.session_factory import session_factory
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.variables.variables import Variable
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
@ -149,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=self._workflow.environment_variables,
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
# Based on the definition of `Variable`,
|
||||
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
@ -318,7 +318,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
trace_manager=app_generate_entity.trace_manager,
|
||||
)
|
||||
|
||||
def _initialize_conversation_variables(self) -> list[VariableUnion]:
|
||||
def _initialize_conversation_variables(self) -> list[Variable]:
|
||||
"""
|
||||
Initialize conversation variables for the current conversation.
|
||||
|
||||
@ -343,7 +343,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
conversation_variables = [var.to_variable() for var in existing_variables]
|
||||
|
||||
session.commit()
|
||||
return cast(list[VariableUnion], conversation_variables)
|
||||
return cast(list[Variable], conversation_variables)
|
||||
|
||||
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
|
||||
"""
|
||||
|
||||
@ -4,7 +4,6 @@ import re
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Thread
|
||||
from typing import Any, Union
|
||||
|
||||
@ -20,7 +19,6 @@ from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
ChunkType,
|
||||
MessageQueueMessage,
|
||||
QueueAdvancedChatMessageEndEvent,
|
||||
QueueAgentLogEvent,
|
||||
@ -72,122 +70,13 @@ from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, LLMGenerationDetail, Message, MessageFile
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamEventBuffer:
|
||||
"""
|
||||
Buffer for recording stream events in order to reconstruct the generation sequence.
|
||||
Records the exact order of text chunks, thoughts, and tool calls as they stream.
|
||||
"""
|
||||
|
||||
# Accumulated reasoning content (each thought block is a separate element)
|
||||
reasoning_content: list[str] = field(default_factory=list)
|
||||
# Current reasoning buffer (accumulates until we see a different event type)
|
||||
_current_reasoning: str = ""
|
||||
# Tool calls with their details
|
||||
tool_calls: list[dict] = field(default_factory=list)
|
||||
# Tool call ID to index mapping for updating results
|
||||
_tool_call_id_map: dict[str, int] = field(default_factory=dict)
|
||||
# Sequence of events in stream order
|
||||
sequence: list[dict] = field(default_factory=list)
|
||||
# Current position in answer text
|
||||
_content_position: int = 0
|
||||
# Track last event type to detect transitions
|
||||
_last_event_type: str | None = None
|
||||
|
||||
def _flush_current_reasoning(self) -> None:
|
||||
"""Flush accumulated reasoning to the list and add to sequence."""
|
||||
if self._current_reasoning.strip():
|
||||
self.reasoning_content.append(self._current_reasoning.strip())
|
||||
self.sequence.append({"type": "reasoning", "index": len(self.reasoning_content) - 1})
|
||||
self._current_reasoning = ""
|
||||
|
||||
def record_text_chunk(self, text: str) -> None:
|
||||
"""Record a text chunk event."""
|
||||
if not text:
|
||||
return
|
||||
|
||||
# Flush any pending reasoning first
|
||||
if self._last_event_type == "thought":
|
||||
self._flush_current_reasoning()
|
||||
|
||||
text_len = len(text)
|
||||
start_pos = self._content_position
|
||||
|
||||
# If last event was also content, extend it; otherwise create new
|
||||
if self.sequence and self.sequence[-1].get("type") == "content":
|
||||
self.sequence[-1]["end"] = start_pos + text_len
|
||||
else:
|
||||
self.sequence.append({"type": "content", "start": start_pos, "end": start_pos + text_len})
|
||||
|
||||
self._content_position += text_len
|
||||
self._last_event_type = "content"
|
||||
|
||||
def record_thought_chunk(self, text: str) -> None:
|
||||
"""Record a thought/reasoning chunk event."""
|
||||
if not text:
|
||||
return
|
||||
|
||||
# Accumulate thought content
|
||||
self._current_reasoning += text
|
||||
self._last_event_type = "thought"
|
||||
|
||||
def record_tool_call(self, tool_call_id: str, tool_name: str, tool_arguments: str) -> None:
|
||||
"""Record a tool call event."""
|
||||
if not tool_call_id:
|
||||
return
|
||||
|
||||
# Flush any pending reasoning first
|
||||
if self._last_event_type == "thought":
|
||||
self._flush_current_reasoning()
|
||||
|
||||
# Check if this tool call already exists (we might get multiple chunks)
|
||||
if tool_call_id in self._tool_call_id_map:
|
||||
idx = self._tool_call_id_map[tool_call_id]
|
||||
# Update arguments if provided
|
||||
if tool_arguments:
|
||||
self.tool_calls[idx]["arguments"] = tool_arguments
|
||||
else:
|
||||
# New tool call
|
||||
tool_call = {
|
||||
"id": tool_call_id or "",
|
||||
"name": tool_name or "",
|
||||
"arguments": tool_arguments or "",
|
||||
"result": "",
|
||||
"elapsed_time": None,
|
||||
}
|
||||
self.tool_calls.append(tool_call)
|
||||
idx = len(self.tool_calls) - 1
|
||||
self._tool_call_id_map[tool_call_id] = idx
|
||||
self.sequence.append({"type": "tool_call", "index": idx})
|
||||
|
||||
self._last_event_type = "tool_call"
|
||||
|
||||
def record_tool_result(self, tool_call_id: str, result: str, tool_elapsed_time: float | None = None) -> None:
|
||||
"""Record a tool result event (update existing tool call)."""
|
||||
if not tool_call_id:
|
||||
return
|
||||
if tool_call_id in self._tool_call_id_map:
|
||||
idx = self._tool_call_id_map[tool_call_id]
|
||||
self.tool_calls[idx]["result"] = result
|
||||
self.tool_calls[idx]["elapsed_time"] = tool_elapsed_time
|
||||
|
||||
def finalize(self) -> None:
|
||||
"""Finalize the buffer, flushing any pending data."""
|
||||
if self._last_event_type == "thought":
|
||||
self._flush_current_reasoning()
|
||||
|
||||
def has_data(self) -> bool:
|
||||
"""Check if there's any meaningful data recorded."""
|
||||
return bool(self.reasoning_content or self.tool_calls or self.sequence)
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
"""
|
||||
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
@ -255,8 +144,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self._workflow_run_id: str = ""
|
||||
self._draft_var_saver_factory = draft_var_saver_factory
|
||||
self._graph_runtime_state: GraphRuntimeState | None = None
|
||||
# Stream event buffer for recording generation sequence
|
||||
self._stream_buffer = StreamEventBuffer()
|
||||
self._seed_graph_runtime_state_from_queue_manager()
|
||||
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
@ -496,7 +383,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle text chunk events and record to stream buffer for sequence reconstruction."""
|
||||
"""Handle text chunk events."""
|
||||
delta_text = event.text
|
||||
if delta_text is None:
|
||||
return
|
||||
@ -518,52 +405,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if tts_publisher and queue_message:
|
||||
tts_publisher.publish(queue_message)
|
||||
|
||||
tool_call = event.tool_call
|
||||
tool_result = event.tool_result
|
||||
tool_payload = tool_call or tool_result
|
||||
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else ""
|
||||
tool_name = tool_payload.name if tool_payload and tool_payload.name else ""
|
||||
tool_arguments = tool_call.arguments if tool_call and tool_call.arguments else ""
|
||||
tool_files = tool_result.files if tool_result else []
|
||||
tool_elapsed_time = tool_result.elapsed_time if tool_result else None
|
||||
tool_icon = tool_payload.icon if tool_payload else None
|
||||
tool_icon_dark = tool_payload.icon_dark if tool_payload else None
|
||||
# Record stream event based on chunk type
|
||||
chunk_type = event.chunk_type or ChunkType.TEXT
|
||||
match chunk_type:
|
||||
case ChunkType.TEXT:
|
||||
self._stream_buffer.record_text_chunk(delta_text)
|
||||
self._task_state.answer += delta_text
|
||||
case ChunkType.THOUGHT:
|
||||
# Reasoning should not be part of final answer text
|
||||
self._stream_buffer.record_thought_chunk(delta_text)
|
||||
case ChunkType.TOOL_CALL:
|
||||
self._stream_buffer.record_tool_call(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
tool_arguments=tool_arguments,
|
||||
)
|
||||
case ChunkType.TOOL_RESULT:
|
||||
self._stream_buffer.record_tool_result(
|
||||
tool_call_id=tool_call_id,
|
||||
result=delta_text,
|
||||
tool_elapsed_time=tool_elapsed_time,
|
||||
)
|
||||
self._task_state.answer += delta_text
|
||||
case _:
|
||||
pass
|
||||
self._task_state.answer += delta_text
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=delta_text,
|
||||
message_id=self._message_id,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
chunk_type=event.chunk_type.value if event.chunk_type else None,
|
||||
tool_call_id=tool_call_id or None,
|
||||
tool_name=tool_name or None,
|
||||
tool_arguments=tool_arguments or None,
|
||||
tool_files=tool_files,
|
||||
tool_elapsed_time=tool_elapsed_time,
|
||||
tool_icon=tool_icon,
|
||||
tool_icon_dark=tool_icon_dark,
|
||||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
|
||||
def _handle_iteration_start_event(
|
||||
@ -931,7 +775,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
|
||||
# If there are assistant files, remove markdown image links from answer
|
||||
answer_text = self._task_state.answer
|
||||
answer_text = self._strip_think_blocks(answer_text)
|
||||
if self._recorded_files:
|
||||
# Remove markdown image links since we're storing files separately
|
||||
answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip()
|
||||
@ -983,54 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
]
|
||||
session.add_all(message_files)
|
||||
|
||||
# Save generation detail (reasoning/tool calls/sequence) from stream buffer
|
||||
self._save_generation_detail(session=session, message=message)
|
||||
|
||||
@staticmethod
|
||||
def _strip_think_blocks(text: str) -> str:
|
||||
"""Remove <think>...</think> blocks (including their content) from text."""
|
||||
if not text or "<think" not in text.lower():
|
||||
return text
|
||||
|
||||
clean_text = re.sub(r"<think[^>]*>.*?</think>", "", text, flags=re.IGNORECASE | re.DOTALL)
|
||||
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
|
||||
return clean_text
|
||||
|
||||
def _save_generation_detail(self, *, session: Session, message: Message) -> None:
|
||||
"""
|
||||
Save LLM generation detail for Chatflow using stream event buffer.
|
||||
The buffer records the exact order of events as they streamed,
|
||||
allowing accurate reconstruction of the generation sequence.
|
||||
"""
|
||||
# Finalize the stream buffer to flush any pending data
|
||||
self._stream_buffer.finalize()
|
||||
|
||||
# Only save if there's meaningful data
|
||||
if not self._stream_buffer.has_data():
|
||||
return
|
||||
|
||||
reasoning_content = self._stream_buffer.reasoning_content
|
||||
tool_calls = self._stream_buffer.tool_calls
|
||||
sequence = self._stream_buffer.sequence
|
||||
|
||||
# Check if generation detail already exists for this message
|
||||
existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first()
|
||||
|
||||
if existing:
|
||||
existing.reasoning_content = json.dumps(reasoning_content) if reasoning_content else None
|
||||
existing.tool_calls = json.dumps(tool_calls) if tool_calls else None
|
||||
existing.sequence = json.dumps(sequence) if sequence else None
|
||||
else:
|
||||
generation_detail = LLMGenerationDetail(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
message_id=message.id,
|
||||
reasoning_content=json.dumps(reasoning_content) if reasoning_content else None,
|
||||
tool_calls=json.dumps(tool_calls) if tool_calls else None,
|
||||
sequence=json.dumps(sequence) if sequence else None,
|
||||
)
|
||||
session.add(generation_detail)
|
||||
|
||||
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
|
||||
"""Bootstrap the cached runtime state from the queue manager when present."""
|
||||
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
|
||||
|
||||
@ -3,8 +3,10 @@ from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
|
||||
from core.agent.entities import AgentEntity
|
||||
from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
@ -12,7 +14,8 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.moderation.base import ModerationError
|
||||
from extensions.ext_database import db
|
||||
@ -191,7 +194,22 @@ class AgentChatAppRunner(AppRunner):
|
||||
raise ValueError("Message not found")
|
||||
db.session.close()
|
||||
|
||||
runner = AgentAppRunner(
|
||||
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
|
||||
# start agent runner
|
||||
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
# check LLM mode
|
||||
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
|
||||
runner_cls = CotChatAgentRunner
|
||||
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
|
||||
runner_cls = CotCompletionAgentRunner
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
|
||||
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
runner_cls = FunctionCallAgentRunner
|
||||
else:
|
||||
raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")
|
||||
|
||||
runner = runner_cls(
|
||||
tenant_id=app_config.tenant_id,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation_result,
|
||||
|
||||
@ -671,7 +671,7 @@ class WorkflowResponseConverter:
|
||||
task_id=task_id,
|
||||
data=AgentLogStreamResponse.Data(
|
||||
node_execution_id=event.node_execution_id,
|
||||
message_id=event.id,
|
||||
id=event.id,
|
||||
parent_id=event.parent_id,
|
||||
label=event.label,
|
||||
error=event.error,
|
||||
|
||||
@ -13,7 +13,6 @@ from core.app.apps.common.workflow_response_converter import WorkflowResponseCon
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
ChunkType,
|
||||
MessageQueueMessage,
|
||||
QueueAgentLogEvent,
|
||||
QueueErrorEvent,
|
||||
@ -484,33 +483,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if delta_text is None:
|
||||
return
|
||||
|
||||
tool_call = event.tool_call
|
||||
tool_result = event.tool_result
|
||||
tool_payload = tool_call or tool_result
|
||||
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else None
|
||||
tool_name = tool_payload.name if tool_payload and tool_payload.name else None
|
||||
tool_arguments = tool_call.arguments if tool_call else None
|
||||
tool_elapsed_time = tool_result.elapsed_time if tool_result else None
|
||||
tool_files = tool_result.files if tool_result else []
|
||||
tool_icon = tool_payload.icon if tool_payload else None
|
||||
tool_icon_dark = tool_payload.icon_dark if tool_payload else None
|
||||
|
||||
# only publish tts message at text chunk streaming
|
||||
if tts_publisher and queue_message:
|
||||
tts_publisher.publish(queue_message)
|
||||
|
||||
yield self._text_chunk_to_stream_response(
|
||||
text=delta_text,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
chunk_type=event.chunk_type,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
tool_arguments=tool_arguments,
|
||||
tool_files=tool_files,
|
||||
tool_elapsed_time=tool_elapsed_time,
|
||||
tool_icon=tool_icon,
|
||||
tool_icon_dark=tool_icon_dark,
|
||||
)
|
||||
yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector)
|
||||
|
||||
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle agent log events."""
|
||||
@ -673,61 +650,16 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
session.add(workflow_app_log)
|
||||
|
||||
def _text_chunk_to_stream_response(
|
||||
self,
|
||||
text: str,
|
||||
from_variable_selector: list[str] | None = None,
|
||||
chunk_type: ChunkType | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
tool_arguments: str | None = None,
|
||||
tool_files: list[str] | None = None,
|
||||
tool_error: str | None = None,
|
||||
tool_elapsed_time: float | None = None,
|
||||
tool_icon: str | dict | None = None,
|
||||
tool_icon_dark: str | dict | None = None,
|
||||
self, text: str, from_variable_selector: list[str] | None = None
|
||||
) -> TextChunkStreamResponse:
|
||||
"""
|
||||
Handle completed event.
|
||||
:param text: text
|
||||
:return:
|
||||
"""
|
||||
from core.app.entities.task_entities import ChunkType as ResponseChunkType
|
||||
|
||||
response_chunk_type = ResponseChunkType(chunk_type.value) if chunk_type else ResponseChunkType.TEXT
|
||||
|
||||
data = TextChunkStreamResponse.Data(
|
||||
text=text,
|
||||
from_variable_selector=from_variable_selector,
|
||||
chunk_type=response_chunk_type,
|
||||
)
|
||||
|
||||
if response_chunk_type == ResponseChunkType.TOOL_CALL:
|
||||
data = data.model_copy(
|
||||
update={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_arguments": tool_arguments,
|
||||
"tool_icon": tool_icon,
|
||||
"tool_icon_dark": tool_icon_dark,
|
||||
}
|
||||
)
|
||||
elif response_chunk_type == ResponseChunkType.TOOL_RESULT:
|
||||
data = data.model_copy(
|
||||
update={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_arguments": tool_arguments,
|
||||
"tool_files": tool_files,
|
||||
"tool_error": tool_error,
|
||||
"tool_elapsed_time": tool_elapsed_time,
|
||||
"tool_icon": tool_icon,
|
||||
"tool_icon_dark": tool_icon_dark,
|
||||
}
|
||||
)
|
||||
|
||||
response = TextChunkStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=data,
|
||||
data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@ -463,20 +463,12 @@ class WorkflowBasedAppRunner:
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
from core.app.entities.queue_entities import ChunkType as QueueChunkType
|
||||
|
||||
if event.is_final and not event.chunk:
|
||||
return
|
||||
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
text=event.chunk,
|
||||
from_variable_selector=list(event.selector),
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
chunk_type=QueueChunkType(event.chunk_type.value),
|
||||
tool_call=event.tool_call,
|
||||
tool_result=event.tool_result,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetrieverResourceEvent):
|
||||
|
||||
@ -1,70 +0,0 @@
|
||||
"""
|
||||
LLM Generation Detail entities.
|
||||
|
||||
Defines the structure for storing and transmitting LLM generation details
|
||||
including reasoning content, tool calls, and their sequence.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ContentSegment(BaseModel):
|
||||
"""Represents a content segment in the generation sequence."""
|
||||
|
||||
type: Literal["content"] = "content"
|
||||
start: int = Field(..., description="Start position in the text")
|
||||
end: int = Field(..., description="End position in the text")
|
||||
|
||||
|
||||
class ReasoningSegment(BaseModel):
|
||||
"""Represents a reasoning segment in the generation sequence."""
|
||||
|
||||
type: Literal["reasoning"] = "reasoning"
|
||||
index: int = Field(..., description="Index into reasoning_content array")
|
||||
|
||||
|
||||
class ToolCallSegment(BaseModel):
|
||||
"""Represents a tool call segment in the generation sequence."""
|
||||
|
||||
type: Literal["tool_call"] = "tool_call"
|
||||
index: int = Field(..., description="Index into tool_calls array")
|
||||
|
||||
|
||||
SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment
|
||||
|
||||
|
||||
class ToolCallDetail(BaseModel):
|
||||
"""Represents a tool call with its arguments and result."""
|
||||
|
||||
id: str = Field(default="", description="Unique identifier for the tool call")
|
||||
name: str = Field(..., description="Name of the tool")
|
||||
arguments: str = Field(default="", description="JSON string of tool arguments")
|
||||
result: str = Field(default="", description="Result from the tool execution")
|
||||
elapsed_time: float | None = Field(default=None, description="Elapsed time in seconds")
|
||||
|
||||
|
||||
class LLMGenerationDetailData(BaseModel):
|
||||
"""
|
||||
Domain model for LLM generation detail.
|
||||
|
||||
Contains the structured data for reasoning content, tool calls,
|
||||
and their display sequence.
|
||||
"""
|
||||
|
||||
reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments")
|
||||
tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details")
|
||||
sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments")
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if there's any meaningful generation detail."""
|
||||
return not self.reasoning_content and not self.tool_calls
|
||||
|
||||
def to_response_dict(self) -> dict:
|
||||
"""Convert to dictionary for API response."""
|
||||
return {
|
||||
"reasoning_content": self.reasoning_content,
|
||||
"tool_calls": [tc.model_dump() for tc in self.tool_calls],
|
||||
"sequence": [seg.model_dump() for seg in self.sequence],
|
||||
}
|
||||
@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes import NodeType
|
||||
|
||||
@ -177,17 +177,6 @@ class QueueLoopCompletedEvent(AppQueueEvent):
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ChunkType(StrEnum):
|
||||
"""Stream chunk type for LLM-related events."""
|
||||
|
||||
TEXT = "text" # Normal text streaming
|
||||
TOOL_CALL = "tool_call" # Tool call arguments streaming
|
||||
TOOL_RESULT = "tool_result" # Tool execution result
|
||||
THOUGHT = "thought" # Agent thinking process (ReAct)
|
||||
THOUGHT_START = "thought_start" # Agent thought start
|
||||
THOUGHT_END = "thought_end" # Agent thought end
|
||||
|
||||
|
||||
class QueueTextChunkEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueTextChunkEvent entity
|
||||
@ -202,16 +191,6 @@ class QueueTextChunkEvent(AppQueueEvent):
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
# Extended fields for Agent/Tool streaming
|
||||
chunk_type: ChunkType = ChunkType.TEXT
|
||||
"""type of the chunk"""
|
||||
|
||||
# Tool streaming payloads
|
||||
tool_call: ToolCall | None = None
|
||||
"""structured tool call info"""
|
||||
tool_result: ToolResult | None = None
|
||||
"""structured tool result info"""
|
||||
|
||||
|
||||
class QueueAgentMessageEvent(AppQueueEvent):
|
||||
"""
|
||||
|
||||
@ -113,38 +113,6 @@ class MessageStreamResponse(StreamResponse):
|
||||
answer: str
|
||||
from_variable_selector: list[str] | None = None
|
||||
|
||||
# Extended fields for Agent/Tool streaming (imported at runtime to avoid circular import)
|
||||
chunk_type: str | None = None
|
||||
"""type of the chunk: text, tool_call, tool_result, thought"""
|
||||
|
||||
# Tool call fields (when chunk_type == "tool_call")
|
||||
tool_call_id: str | None = None
|
||||
"""unique identifier for this tool call"""
|
||||
tool_name: str | None = None
|
||||
"""name of the tool being called"""
|
||||
tool_arguments: str | None = None
|
||||
"""accumulated tool arguments JSON"""
|
||||
|
||||
# Tool result fields (when chunk_type == "tool_result")
|
||||
tool_files: list[str] | None = None
|
||||
"""file IDs produced by tool"""
|
||||
tool_error: str | None = None
|
||||
"""error message if tool failed"""
|
||||
tool_elapsed_time: float | None = None
|
||||
"""elapsed time spent executing the tool"""
|
||||
tool_icon: str | dict | None = None
|
||||
"""icon of the tool"""
|
||||
tool_icon_dark: str | dict | None = None
|
||||
"""dark theme icon of the tool"""
|
||||
|
||||
def model_dump(self, *args, **kwargs) -> dict[str, object]:
|
||||
kwargs.setdefault("exclude_none", True)
|
||||
return super().model_dump(*args, **kwargs)
|
||||
|
||||
def model_dump_json(self, *args, **kwargs) -> str:
|
||||
kwargs.setdefault("exclude_none", True)
|
||||
return super().model_dump_json(*args, **kwargs)
|
||||
|
||||
|
||||
class MessageAudioStreamResponse(StreamResponse):
|
||||
"""
|
||||
@ -614,17 +582,6 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class ChunkType(StrEnum):
|
||||
"""Stream chunk type for LLM-related events."""
|
||||
|
||||
TEXT = "text" # Normal text streaming
|
||||
TOOL_CALL = "tool_call" # Tool call arguments streaming
|
||||
TOOL_RESULT = "tool_result" # Tool execution result
|
||||
THOUGHT = "thought" # Agent thinking process (ReAct)
|
||||
THOUGHT_START = "thought_start" # Agent thought start
|
||||
THOUGHT_END = "thought_end" # Agent thought end
|
||||
|
||||
|
||||
class TextChunkStreamResponse(StreamResponse):
|
||||
"""
|
||||
TextChunkStreamResponse entity
|
||||
@ -638,36 +595,6 @@ class TextChunkStreamResponse(StreamResponse):
|
||||
text: str
|
||||
from_variable_selector: list[str] | None = None
|
||||
|
||||
# Extended fields for Agent/Tool streaming
|
||||
chunk_type: ChunkType = ChunkType.TEXT
|
||||
"""type of the chunk"""
|
||||
|
||||
# Tool call fields (when chunk_type == TOOL_CALL)
|
||||
tool_call_id: str | None = None
|
||||
"""unique identifier for this tool call"""
|
||||
tool_name: str | None = None
|
||||
"""name of the tool being called"""
|
||||
tool_arguments: str | None = None
|
||||
"""accumulated tool arguments JSON"""
|
||||
|
||||
# Tool result fields (when chunk_type == TOOL_RESULT)
|
||||
tool_files: list[str] | None = None
|
||||
"""file IDs produced by tool"""
|
||||
tool_error: str | None = None
|
||||
"""error message if tool failed"""
|
||||
|
||||
# Tool elapsed time fields (when chunk_type == TOOL_RESULT)
|
||||
tool_elapsed_time: float | None = None
|
||||
"""elapsed time spent executing the tool"""
|
||||
|
||||
def model_dump(self, *args, **kwargs) -> dict[str, object]:
|
||||
kwargs.setdefault("exclude_none", True)
|
||||
return super().model_dump(*args, **kwargs)
|
||||
|
||||
def model_dump_json(self, *args, **kwargs) -> str:
|
||||
kwargs.setdefault("exclude_none", True)
|
||||
return super().model_dump_json(*args, **kwargs)
|
||||
|
||||
event: StreamEvent = StreamEvent.TEXT_CHUNK
|
||||
data: Data
|
||||
|
||||
@ -816,7 +743,7 @@ class AgentLogStreamResponse(StreamResponse):
|
||||
"""
|
||||
|
||||
node_execution_id: str
|
||||
message_id: str
|
||||
id: str
|
||||
label: str
|
||||
parent_id: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from core.variables import Variable
|
||||
from core.variables import VariableBase
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.enums import NodeType
|
||||
@ -44,7 +44,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
|
||||
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
|
||||
continue
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if not isinstance(variable, Variable):
|
||||
if not isinstance(variable, VariableBase):
|
||||
logger.warning(
|
||||
"Conversation variable not found in variable pool. selector=%s",
|
||||
selector,
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from threading import Thread
|
||||
@ -59,7 +58,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import AppMode, Conversation, LLMGenerationDetail, Message, MessageAgentThought
|
||||
from models.model import AppMode, Conversation, Message, MessageAgentThought
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -69,8 +68,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
_task_state: EasyUITaskState
|
||||
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
|
||||
|
||||
@ -412,136 +409,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
)
|
||||
)
|
||||
|
||||
# Save LLM generation detail if there's reasoning_content
|
||||
self._save_generation_detail(session=session, message=message, llm_result=llm_result)
|
||||
|
||||
message_was_created.send(
|
||||
message,
|
||||
application_generate_entity=self._application_generate_entity,
|
||||
)
|
||||
|
||||
def _save_generation_detail(self, *, session: Session, message: Message, llm_result: LLMResult) -> None:
|
||||
"""
|
||||
Save LLM generation detail for Completion/Chat/Agent-Chat applications.
|
||||
For Agent-Chat, also merges MessageAgentThought records.
|
||||
"""
|
||||
import json
|
||||
|
||||
reasoning_list: list[str] = []
|
||||
tool_calls_list: list[dict] = []
|
||||
sequence: list[dict] = []
|
||||
answer = message.answer or ""
|
||||
|
||||
# Check if this is Agent-Chat mode by looking for agent thoughts
|
||||
agent_thoughts = (
|
||||
session.query(MessageAgentThought)
|
||||
.filter_by(message_id=message.id)
|
||||
.order_by(MessageAgentThought.position.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
if agent_thoughts:
|
||||
# Agent-Chat mode: merge MessageAgentThought records
|
||||
content_pos = 0
|
||||
cleaned_answer_parts: list[str] = []
|
||||
for thought in agent_thoughts:
|
||||
# Add thought/reasoning
|
||||
if thought.thought:
|
||||
reasoning_text = thought.thought
|
||||
if "<think" in reasoning_text.lower():
|
||||
clean_text, extracted_reasoning = self._split_reasoning_from_answer(reasoning_text)
|
||||
if extracted_reasoning:
|
||||
reasoning_text = extracted_reasoning
|
||||
thought.thought = clean_text or extracted_reasoning
|
||||
reasoning_list.append(reasoning_text)
|
||||
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})
|
||||
|
||||
# Add tool calls
|
||||
if thought.tool:
|
||||
tool_calls_list.append(
|
||||
{
|
||||
"name": thought.tool,
|
||||
"arguments": thought.tool_input or "",
|
||||
"result": thought.observation or "",
|
||||
}
|
||||
)
|
||||
sequence.append({"type": "tool_call", "index": len(tool_calls_list) - 1})
|
||||
|
||||
# Add answer content if present
|
||||
if thought.answer:
|
||||
content_text = thought.answer
|
||||
if "<think" in content_text.lower():
|
||||
clean_answer, extracted_reasoning = self._split_reasoning_from_answer(content_text)
|
||||
if extracted_reasoning:
|
||||
reasoning_list.append(extracted_reasoning)
|
||||
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})
|
||||
content_text = clean_answer
|
||||
thought.answer = clean_answer or content_text
|
||||
|
||||
if content_text:
|
||||
start = content_pos
|
||||
end = content_pos + len(content_text)
|
||||
sequence.append({"type": "content", "start": start, "end": end})
|
||||
content_pos = end
|
||||
cleaned_answer_parts.append(content_text)
|
||||
|
||||
if cleaned_answer_parts:
|
||||
merged_answer = "".join(cleaned_answer_parts)
|
||||
message.answer = merged_answer
|
||||
llm_result.message.content = merged_answer
|
||||
else:
|
||||
# Completion/Chat mode: use reasoning_content from llm_result
|
||||
reasoning_content = llm_result.reasoning_content
|
||||
if not reasoning_content and answer:
|
||||
# Extract reasoning from <think> blocks and clean the final answer
|
||||
clean_answer, reasoning_content = self._split_reasoning_from_answer(answer)
|
||||
if reasoning_content:
|
||||
answer = clean_answer
|
||||
llm_result.message.content = clean_answer
|
||||
llm_result.reasoning_content = reasoning_content
|
||||
message.answer = clean_answer
|
||||
if reasoning_content:
|
||||
reasoning_list = [reasoning_content]
|
||||
# Content comes first, then reasoning
|
||||
if answer:
|
||||
sequence.append({"type": "content", "start": 0, "end": len(answer)})
|
||||
sequence.append({"type": "reasoning", "index": 0})
|
||||
|
||||
# Only save if there's meaningful generation detail
|
||||
if not reasoning_list and not tool_calls_list:
|
||||
return
|
||||
|
||||
# Check if generation detail already exists
|
||||
existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first()
|
||||
|
||||
if existing:
|
||||
existing.reasoning_content = json.dumps(reasoning_list) if reasoning_list else None
|
||||
existing.tool_calls = json.dumps(tool_calls_list) if tool_calls_list else None
|
||||
existing.sequence = json.dumps(sequence) if sequence else None
|
||||
else:
|
||||
generation_detail = LLMGenerationDetail(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
message_id=message.id,
|
||||
reasoning_content=json.dumps(reasoning_list) if reasoning_list else None,
|
||||
tool_calls=json.dumps(tool_calls_list) if tool_calls_list else None,
|
||||
sequence=json.dumps(sequence) if sequence else None,
|
||||
)
|
||||
session.add(generation_detail)
|
||||
|
||||
@classmethod
|
||||
def _split_reasoning_from_answer(cls, text: str) -> tuple[str, str]:
|
||||
"""
|
||||
Extract reasoning segments from <think> blocks and return (clean_text, reasoning).
|
||||
"""
|
||||
matches = cls._THINK_PATTERN.findall(text)
|
||||
reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""
|
||||
|
||||
clean_text = cls._THINK_PATTERN.sub("", text)
|
||||
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
|
||||
|
||||
return clean_text, reasoning_content or ""
|
||||
|
||||
def _handle_stop(self, event: QueueStopEvent):
|
||||
"""
|
||||
Handle stop.
|
||||
|
||||
@ -232,31 +232,15 @@ class MessageCycleManager:
|
||||
answer: str,
|
||||
message_id: str,
|
||||
from_variable_selector: list[str] | None = None,
|
||||
chunk_type: str | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
tool_arguments: str | None = None,
|
||||
tool_files: list[str] | None = None,
|
||||
tool_error: str | None = None,
|
||||
tool_elapsed_time: float | None = None,
|
||||
tool_icon: str | dict | None = None,
|
||||
tool_icon_dark: str | dict | None = None,
|
||||
event_type: StreamEvent | None = None,
|
||||
) -> MessageStreamResponse:
|
||||
"""
|
||||
Message to stream response.
|
||||
:param answer: answer
|
||||
:param message_id: message id
|
||||
:param from_variable_selector: from variable selector
|
||||
:param chunk_type: type of the chunk (text, function_call, tool_result, thought)
|
||||
:param tool_call_id: unique identifier for this tool call
|
||||
:param tool_name: name of the tool being called
|
||||
:param tool_arguments: accumulated tool arguments JSON
|
||||
:param tool_files: file IDs produced by tool
|
||||
:param tool_error: error message if tool failed
|
||||
:return:
|
||||
"""
|
||||
response = MessageStreamResponse(
|
||||
return MessageStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_id,
|
||||
answer=answer,
|
||||
@ -264,35 +248,6 @@ class MessageCycleManager:
|
||||
event=event_type or StreamEvent.MESSAGE,
|
||||
)
|
||||
|
||||
if chunk_type:
|
||||
response = response.model_copy(update={"chunk_type": chunk_type})
|
||||
|
||||
if chunk_type == "tool_call":
|
||||
response = response.model_copy(
|
||||
update={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_arguments": tool_arguments,
|
||||
"tool_icon": tool_icon,
|
||||
"tool_icon_dark": tool_icon_dark,
|
||||
}
|
||||
)
|
||||
elif chunk_type == "tool_result":
|
||||
response = response.model_copy(
|
||||
update={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_arguments": tool_arguments,
|
||||
"tool_files": tool_files,
|
||||
"tool_error": tool_error,
|
||||
"tool_elapsed_time": tool_elapsed_time,
|
||||
"tool_icon": tool_icon,
|
||||
"tool_icon_dark": tool_icon_dark,
|
||||
}
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
|
||||
"""
|
||||
Message replace to stream response.
|
||||
|
||||
@ -5,6 +5,7 @@ from sqlalchemy import select
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.models.document import Document
|
||||
@ -89,8 +90,6 @@ class DatasetIndexToolCallbackHandler:
|
||||
# TODO(-LAN-): Improve type check
|
||||
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
|
||||
"""Handle return_retriever_resource_info."""
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
|
||||
self._queue_manager.publish(
|
||||
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
@ -3,7 +3,6 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
class PreviewDetail(BaseModel):
|
||||
content: str
|
||||
summary: str | None = None
|
||||
child_chunks: list[str] | None = None
|
||||
|
||||
|
||||
|
||||
@ -33,6 +33,10 @@ class MaxRetriesExceededError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
request_error = httpx.RequestError
|
||||
max_retries_exceeded_error = MaxRetriesExceededError
|
||||
|
||||
|
||||
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
|
||||
return {
|
||||
"http://": httpx.HTTPTransport(
|
||||
|
||||
@ -311,18 +311,14 @@ class IndexingRunner:
|
||||
qa_preview_texts: list[QAPreviewDetail] = []
|
||||
|
||||
total_segments = 0
|
||||
# doc_form represents the segmentation method (general, parent-child, QA)
|
||||
index_type = doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
# one extract_setting is one source document
|
||||
for extract_setting in extract_settings:
|
||||
# extract
|
||||
processing_rule = DatasetProcessRule(
|
||||
mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
|
||||
)
|
||||
# Extract document content
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
|
||||
# Cleaning and segmentation
|
||||
documents = index_processor.transform(
|
||||
text_docs,
|
||||
current_user=None,
|
||||
@ -365,12 +361,6 @@ class IndexingRunner:
|
||||
|
||||
if doc_form and doc_form == "qa_model":
|
||||
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
|
||||
|
||||
# Generate summary preview
|
||||
summary_index_setting = tmp_processing_rule["summary_index_setting"] if "summary_index_setting" in tmp_processing_rule else None
|
||||
if summary_index_setting and summary_index_setting.get('enable') and preview_texts:
|
||||
preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
|
||||
|
||||
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
|
||||
|
||||
def _extract(
|
||||
|
||||
@ -434,6 +434,3 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
|
||||
You should edit the prompt according to the IDEAL OUTPUT."""
|
||||
|
||||
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
|
||||
|
||||
DEFAULT_GENERATOR_SUMMARY_PROMPT = """
|
||||
You are a helpful assistant that summarizes long pieces of text into concise summaries. Given the following text, generate a brief summary that captures the main points and key information. The summary should be clear, concise, and written in complete sentences. """
|
||||
|
||||
@ -21,6 +21,7 @@ from core.model_runtime.model_providers.__base.speech2text_model import Speech2T
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.workflow.utils.generator_timeout import with_first_token_timeout
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider import ProviderType
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
@ -109,6 +110,7 @@ class ModelInstance:
|
||||
stream: Literal[True] = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
first_token_timeout: float | None = None,
|
||||
) -> Generator: ...
|
||||
|
||||
@overload
|
||||
@ -121,6 +123,7 @@ class ModelInstance:
|
||||
stream: Literal[False] = False,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
first_token_timeout: float | None = None,
|
||||
) -> LLMResult: ...
|
||||
|
||||
@overload
|
||||
@ -133,6 +136,7 @@ class ModelInstance:
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
first_token_timeout: float | None = None,
|
||||
) -> Union[LLMResult, Generator]: ...
|
||||
|
||||
def invoke_llm(
|
||||
@ -144,6 +148,7 @@ class ModelInstance:
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
first_token_timeout: float | None = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
@ -155,26 +160,31 @@ class ModelInstance:
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:param first_token_timeout: timeout in seconds for receiving first token (streaming only)
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
||||
raise Exception("Model type instance is not LargeLanguageModel")
|
||||
return cast(
|
||||
Union[LLMResult, Generator],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
),
|
||||
|
||||
result = self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
# Apply first token timeout wrapper for streaming responses
|
||||
if stream and first_token_timeout and first_token_timeout > 0 and isinstance(result, Generator):
|
||||
result = with_first_token_timeout(result, first_token_timeout)
|
||||
|
||||
return cast(Union[LLMResult, Generator], result)
|
||||
|
||||
def get_llm_num_tokens(
|
||||
self, prompt_messages: Sequence[PromptMessage], tools: Sequence[PromptMessageTool] | None = None
|
||||
) -> int:
|
||||
|
||||
@ -55,7 +55,7 @@ from core.ops.entities.trace_entity import (
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.entities import WorkflowNodeExecution
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
@ -275,7 +275,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=app_id,
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from core.plugin.entities.endpoint import EndpointEntityWithInstance
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.impl.exc import PluginDaemonInternalServerError
|
||||
|
||||
|
||||
class PluginEndpointClient(BasePluginClient):
|
||||
@ -70,18 +71,27 @@ class PluginEndpointClient(BasePluginClient):
|
||||
def delete_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
"""
|
||||
Delete the given endpoint.
|
||||
|
||||
This operation is idempotent: if the endpoint is already deleted (record not found),
|
||||
it will return True instead of raising an error.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/endpoint/remove",
|
||||
bool,
|
||||
data={
|
||||
"endpoint_id": endpoint_id,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
try:
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/endpoint/remove",
|
||||
bool,
|
||||
data={
|
||||
"endpoint_id": endpoint_id,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
except PluginDaemonInternalServerError as e:
|
||||
# Make delete idempotent: if record is not found, consider it a success
|
||||
if "record not found" in str(e.description).lower():
|
||||
return True
|
||||
raise
|
||||
|
||||
def enable_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
"""
|
||||
|
||||
@ -392,69 +392,6 @@ class RetrievalService:
|
||||
records = []
|
||||
include_segment_ids = set()
|
||||
segment_child_map = {}
|
||||
segment_file_map = {}
|
||||
segment_summary_map = {} # Map segment_id to summary content
|
||||
summary_segment_ids = set() # Track segments retrieved via summary
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# Process documents
|
||||
for document in documents:
|
||||
segment_id = None
|
||||
attachment_info = None
|
||||
child_chunk = None
|
||||
document_id = document.metadata.get("document_id")
|
||||
if document_id not in dataset_documents:
|
||||
continue
|
||||
|
||||
dataset_document = dataset_documents[document_id]
|
||||
if not dataset_document:
|
||||
continue
|
||||
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
# Handle parent-child documents
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
attachment_info_dict = cls.get_segment_attachment_info(
|
||||
dataset_document.dataset_id,
|
||||
dataset_document.tenant_id,
|
||||
document.metadata.get("doc_id") or "",
|
||||
session,
|
||||
)
|
||||
if attachment_info_dict:
|
||||
attachment_info = attachment_info_dict["attachment_info"]
|
||||
segment_id = attachment_info_dict["segment_id"]
|
||||
else:
|
||||
# Check if this is a summary document
|
||||
is_summary = document.metadata.get("is_summary", False)
|
||||
if is_summary:
|
||||
# For summary documents, find the original chunk via original_chunk_id
|
||||
original_chunk_id = document.metadata.get("original_chunk_id")
|
||||
if not original_chunk_id:
|
||||
continue
|
||||
segment_id = original_chunk_id
|
||||
# Track that this segment was retrieved via summary
|
||||
summary_segment_ids.add(segment_id)
|
||||
else:
|
||||
# For normal documents, find by child chunk index_node_id
|
||||
child_index_node_id = document.metadata.get("doc_id")
|
||||
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
|
||||
child_chunk = session.scalar(child_chunk_stmt)
|
||||
|
||||
if not child_chunk:
|
||||
continue
|
||||
segment_id = child_chunk.segment_id
|
||||
|
||||
if not segment_id:
|
||||
continue
|
||||
|
||||
segment = (
|
||||
session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id == segment_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
valid_dataset_documents = {}
|
||||
image_doc_ids: list[Any] = []
|
||||
@ -570,47 +507,7 @@ class RetrievalService:
|
||||
max_score = max(
|
||||
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
|
||||
)
|
||||
segment = session.scalar(document_segment_stmt)
|
||||
if segment:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
else:
|
||||
# Check if this is a summary document
|
||||
is_summary = document.metadata.get("is_summary", False)
|
||||
if is_summary:
|
||||
# For summary documents, find the original chunk via original_chunk_id
|
||||
original_chunk_id = document.metadata.get("original_chunk_id")
|
||||
if not original_chunk_id:
|
||||
continue
|
||||
# Track that this segment was retrieved via summary
|
||||
summary_segment_ids.add(original_chunk_id)
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id == original_chunk_id,
|
||||
)
|
||||
segment = session.scalar(document_segment_stmt)
|
||||
else:
|
||||
# For normal documents, find by index_node_id
|
||||
index_node_id = document.metadata.get("doc_id")
|
||||
if not index_node_id:
|
||||
continue
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.index_node_id == index_node_id,
|
||||
)
|
||||
segment = session.scalar(document_segment_stmt)
|
||||
|
||||
if not segment:
|
||||
continue
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": document.metadata.get("score"), # type: ignore
|
||||
}
|
||||
map_detail = {
|
||||
"max_score": max_score,
|
||||
"child_chunks": child_chunk_details,
|
||||
@ -645,23 +542,6 @@ class RetrievalService:
|
||||
if record["segment"].id in attachment_map:
|
||||
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
|
||||
|
||||
# Batch query summaries for segments retrieved via summary (only enabled summaries)
|
||||
if summary_segment_ids:
|
||||
from models.dataset import DocumentSegmentSummary
|
||||
|
||||
summaries = (
|
||||
session.query(DocumentSegmentSummary)
|
||||
.filter(
|
||||
DocumentSegmentSummary.chunk_id.in_(summary_segment_ids),
|
||||
DocumentSegmentSummary.status == "completed",
|
||||
DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries
|
||||
)
|
||||
.all()
|
||||
)
|
||||
for summary in summaries:
|
||||
if summary.summary_content:
|
||||
segment_summary_map[summary.chunk_id] = summary.summary_content
|
||||
|
||||
result: list[RetrievalSegments] = []
|
||||
for record in records:
|
||||
# Extract segment
|
||||
@ -696,16 +576,9 @@ class RetrievalService:
|
||||
else None
|
||||
)
|
||||
|
||||
# Extract summary if this segment was retrieved via summary
|
||||
summary_content = segment_summary_map.get(segment.id)
|
||||
|
||||
# Create RetrievalSegments object
|
||||
retrieval_segment = RetrievalSegments(
|
||||
segment=segment,
|
||||
child_chunks=child_chunks_list,
|
||||
score=score,
|
||||
files=files,
|
||||
summary=summary_content
|
||||
segment=segment, child_chunks=child_chunks_list, score=score, files=files
|
||||
)
|
||||
result.append(retrieval_segment)
|
||||
|
||||
|
||||
@ -20,4 +20,3 @@ class RetrievalSegments(BaseModel):
|
||||
child_chunks: list[RetrievalChildChunk] | None = None
|
||||
score: float | None = None
|
||||
files: list[dict[str, str | int]] | None = None
|
||||
summary: str | None = None # Summary content if retrieved via summary index
|
||||
|
||||
@ -13,7 +13,6 @@ from urllib.parse import unquote, urlparse
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
@ -46,15 +45,6 @@ class BaseIndexProcessor(ABC):
|
||||
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def generate_summary_preview(self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict) -> list[PreviewDetail]:
|
||||
"""
|
||||
For each segment in preview_texts, generate a summary using LLM and attach it to the segment.
|
||||
The summary can be stored in a new attribute, e.g., summary.
|
||||
This method should be implemented by subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load(
|
||||
self,
|
||||
|
||||
@ -1,13 +1,9 @@
|
||||
"""Paragraph index processor."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
@ -21,19 +17,12 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.account_service import AccountService
|
||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
from core.model_runtime.entities.message_entities import UserPromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.model_manager import ModelInstance
|
||||
|
||||
|
||||
class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
@ -119,29 +108,6 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
keyword.add_texts(documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
|
||||
# For disable operations, disable_summaries_for_segments is called directly in the task.
|
||||
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
|
||||
delete_summaries = kwargs.get("delete_summaries", False)
|
||||
if delete_summaries:
|
||||
if node_ids:
|
||||
# Find segments by index_node_id
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(node_ids),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
if segment_ids:
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
|
||||
else:
|
||||
# Delete all summaries for the dataset
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, None)
|
||||
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
if node_ids:
|
||||
@ -261,70 +227,3 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
}
|
||||
else:
|
||||
raise ValueError("Chunks is not a list")
|
||||
|
||||
def generate_summary_preview(self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict) -> list[PreviewDetail]:
|
||||
"""
|
||||
For each segment, concurrently call generate_summary to generate a summary
|
||||
and write it to the summary attribute of PreviewDetail.
|
||||
"""
|
||||
import concurrent.futures
|
||||
from flask import current_app
|
||||
|
||||
# Capture Flask app context for worker threads
|
||||
flask_app = None
|
||||
try:
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
except RuntimeError:
|
||||
logger.warning("No Flask application context available, summary generation may fail")
|
||||
|
||||
def process(preview: PreviewDetail) -> None:
|
||||
"""Generate summary for a single preview item."""
|
||||
try:
|
||||
if flask_app:
|
||||
# Ensure Flask app context in worker thread
|
||||
with flask_app.app_context():
|
||||
summary = self.generate_summary(tenant_id, preview.content, summary_index_setting)
|
||||
preview.summary = summary
|
||||
else:
|
||||
# Fallback: try without app context (may fail)
|
||||
summary = self.generate_summary(tenant_id, preview.content, summary_index_setting)
|
||||
preview.summary = summary
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate summary for preview: {str(e)}")
|
||||
# Don't fail the entire preview if summary generation fails
|
||||
preview.summary = None
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
list(executor.map(process, preview_texts))
|
||||
return preview_texts
|
||||
|
||||
@staticmethod
|
||||
def generate_summary(tenant_id: str, text: str, summary_index_setting: dict = None) -> str:
|
||||
"""
|
||||
Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt.
|
||||
"""
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
raise ValueError("summary_index_setting is required and must be enabled to generate summary.")
|
||||
|
||||
model_name = summary_index_setting.get("model_name")
|
||||
model_provider_name = summary_index_setting.get("model_provider_name")
|
||||
summary_prompt = summary_index_setting.get("summary_prompt")
|
||||
|
||||
# Import default summary prompt
|
||||
if not summary_prompt:
|
||||
summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
|
||||
prompt = f"{summary_prompt}\n{text}"
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(tenant_id, model_provider_name, ModelType.LLM)
|
||||
model_instance = ModelInstance(provider_model_bundle, model_name)
|
||||
prompt_messages = [UserPromptMessage(content=prompt)]
|
||||
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters={},
|
||||
stream=False
|
||||
)
|
||||
|
||||
return getattr(result.message, "content", "")
|
||||
|
||||
@ -25,7 +25,6 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.account_service import AccountService
|
||||
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
|
||||
class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
@ -136,29 +135,6 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
||||
# node_ids is segment's node_ids
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
|
||||
# For disable operations, disable_summaries_for_segments is called directly in the task.
|
||||
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
|
||||
delete_summaries = kwargs.get("delete_summaries", False)
|
||||
if delete_summaries:
|
||||
if node_ids:
|
||||
# Find segments by index_node_id
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(node_ids),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
if segment_ids:
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
|
||||
else:
|
||||
# Delete all summaries for the dataset
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, None)
|
||||
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
delete_child_chunks = kwargs.get("delete_child_chunks") or False
|
||||
precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")
|
||||
|
||||
@ -25,10 +25,9 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from libs import helper
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -145,30 +144,6 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
vector.create_multimodal(multimodal_documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
|
||||
# For disable operations, disable_summaries_for_segments is called directly in the task.
|
||||
# Note: qa_model doesn't generate summaries, but we clean them for completeness
|
||||
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
|
||||
delete_summaries = kwargs.get("delete_summaries", False)
|
||||
if delete_summaries:
|
||||
if node_ids:
|
||||
# Find segments by index_node_id
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(node_ids),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
if segment_ids:
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
|
||||
else:
|
||||
# Delete all summaries for the dataset
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, None)
|
||||
|
||||
vector = Vector(dataset)
|
||||
if node_ids:
|
||||
vector.delete_by_ids(node_ids)
|
||||
|
||||
@ -29,7 +29,6 @@ from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
EndUser,
|
||||
LLMGenerationDetail,
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
)
|
||||
@ -458,113 +457,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
session.merge(db_model)
|
||||
session.flush()
|
||||
|
||||
# Save LLMGenerationDetail for LLM nodes with successful execution
|
||||
if (
|
||||
domain_model.node_type == NodeType.LLM
|
||||
and domain_model.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
and domain_model.outputs is not None
|
||||
):
|
||||
self._save_llm_generation_detail(session, domain_model)
|
||||
|
||||
def _save_llm_generation_detail(self, session, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Save LLM generation detail for LLM nodes.
|
||||
Extracts reasoning_content, tool_calls, and sequence from outputs and metadata.
|
||||
"""
|
||||
outputs = execution.outputs or {}
|
||||
metadata = execution.metadata or {}
|
||||
|
||||
reasoning_list = self._extract_reasoning(outputs)
|
||||
tool_calls_list = self._extract_tool_calls(metadata.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG))
|
||||
|
||||
if not reasoning_list and not tool_calls_list:
|
||||
return
|
||||
|
||||
sequence = self._build_generation_sequence(outputs.get("text", ""), reasoning_list, tool_calls_list)
|
||||
self._upsert_generation_detail(session, execution, reasoning_list, tool_calls_list, sequence)
|
||||
|
||||
def _extract_reasoning(self, outputs: Mapping[str, Any]) -> list[str]:
|
||||
"""Extract reasoning_content as a clean list of non-empty strings."""
|
||||
reasoning_content = outputs.get("reasoning_content")
|
||||
if isinstance(reasoning_content, str):
|
||||
trimmed = reasoning_content.strip()
|
||||
return [trimmed] if trimmed else []
|
||||
if isinstance(reasoning_content, list):
|
||||
return [item.strip() for item in reasoning_content if isinstance(item, str) and item.strip()]
|
||||
return []
|
||||
|
||||
def _extract_tool_calls(self, agent_log: Any) -> list[dict[str, str]]:
|
||||
"""Extract tool call records from agent logs."""
|
||||
if not agent_log or not isinstance(agent_log, list):
|
||||
return []
|
||||
|
||||
tool_calls: list[dict[str, str]] = []
|
||||
for log in agent_log:
|
||||
log_data = log.data if hasattr(log, "data") else (log.get("data", {}) if isinstance(log, dict) else {})
|
||||
tool_name = log_data.get("tool_name")
|
||||
if tool_name and str(tool_name).strip():
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": log_data.get("tool_call_id", ""),
|
||||
"name": tool_name,
|
||||
"arguments": json.dumps(log_data.get("tool_args", {})),
|
||||
"result": str(log_data.get("output", "")),
|
||||
}
|
||||
)
|
||||
return tool_calls
|
||||
|
||||
def _build_generation_sequence(
|
||||
self, text: str, reasoning_list: list[str], tool_calls_list: list[dict[str, str]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build a simple content/reasoning/tool_call sequence."""
|
||||
sequence: list[dict[str, Any]] = []
|
||||
if text:
|
||||
sequence.append({"type": "content", "start": 0, "end": len(text)})
|
||||
for index in range(len(reasoning_list)):
|
||||
sequence.append({"type": "reasoning", "index": index})
|
||||
for index in range(len(tool_calls_list)):
|
||||
sequence.append({"type": "tool_call", "index": index})
|
||||
return sequence
|
||||
|
||||
def _upsert_generation_detail(
|
||||
self,
|
||||
session,
|
||||
execution: WorkflowNodeExecution,
|
||||
reasoning_list: list[str],
|
||||
tool_calls_list: list[dict[str, str]],
|
||||
sequence: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Insert or update LLMGenerationDetail with serialized fields."""
|
||||
existing = (
|
||||
session.query(LLMGenerationDetail)
|
||||
.filter_by(
|
||||
workflow_run_id=execution.workflow_execution_id,
|
||||
node_id=execution.node_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
reasoning_json = json.dumps(reasoning_list) if reasoning_list else None
|
||||
tool_calls_json = json.dumps(tool_calls_list) if tool_calls_list else None
|
||||
sequence_json = json.dumps(sequence) if sequence else None
|
||||
|
||||
if existing:
|
||||
existing.reasoning_content = reasoning_json
|
||||
existing.tool_calls = tool_calls_json
|
||||
existing.sequence = sequence_json
|
||||
return
|
||||
|
||||
generation_detail = LLMGenerationDetail(
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_run_id=execution.workflow_execution_id,
|
||||
node_id=execution.node_id,
|
||||
reasoning_content=reasoning_json,
|
||||
tool_calls=tool_calls_json,
|
||||
sequence=sequence_json,
|
||||
)
|
||||
session.add(generation_detail)
|
||||
|
||||
def get_db_models_by_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
|
||||
@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any
|
||||
if TYPE_CHECKING:
|
||||
from models.model import File
|
||||
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolEntity,
|
||||
@ -155,60 +154,6 @@ class Tool(ABC):
|
||||
|
||||
return parameters
|
||||
|
||||
def to_prompt_message_tool(self) -> PromptMessageTool:
|
||||
message_tool = PromptMessageTool(
|
||||
name=self.entity.identity.name,
|
||||
description=self.entity.description.llm if self.entity.description else "",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
|
||||
parameters = self.get_merged_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||
continue
|
||||
|
||||
parameter_type = parameter.type.as_normal_type()
|
||||
if parameter.type in {
|
||||
ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||
ToolParameter.ToolParameterType.FILE,
|
||||
ToolParameter.ToolParameterType.FILES,
|
||||
}:
|
||||
# Determine the description based on parameter type
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
file_format_desc = " Input the file id with format: [File: file_id]."
|
||||
else:
|
||||
file_format_desc = "Input the file id with format: [Files: file_id1, file_id2, ...]. "
|
||||
|
||||
message_tool.parameters["properties"][parameter.name] = {
|
||||
"type": "string",
|
||||
"description": (parameter.llm_description or "") + file_format_desc,
|
||||
}
|
||||
continue
|
||||
enum = []
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options] if parameter.options else []
|
||||
|
||||
message_tool.parameters["properties"][parameter.name] = (
|
||||
{
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
if parameter.input_schema is None
|
||||
else parameter.input_schema
|
||||
)
|
||||
|
||||
if len(enum) > 0:
|
||||
message_tool.parameters["properties"][parameter.name]["enum"] = enum
|
||||
|
||||
if parameter.required:
|
||||
message_tool.parameters["required"].append(parameter.name)
|
||||
|
||||
return message_tool
|
||||
|
||||
def create_image_message(
|
||||
self,
|
||||
image: str,
|
||||
|
||||
@ -7,8 +7,8 @@ from typing import Any, cast
|
||||
|
||||
from flask import has_request_context
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from core.tools.__base.tool import Tool
|
||||
@ -20,7 +20,6 @@ from core.tools.entities.tool_entities import (
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping
|
||||
from libs.login import current_user
|
||||
from models import Account, Tenant
|
||||
@ -230,30 +229,32 @@ class WorkflowTool(Tool):
|
||||
"""
|
||||
Resolve user from database (worker/Celery context).
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
|
||||
tenant = session.scalar(tenant_stmt)
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
user_stmt = select(Account).where(Account.id == user_id)
|
||||
user = session.scalar(user_stmt)
|
||||
if user:
|
||||
user.current_tenant = tenant
|
||||
session.expunge(user)
|
||||
return user
|
||||
|
||||
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
|
||||
end_user = session.scalar(end_user_stmt)
|
||||
if end_user:
|
||||
session.expunge(end_user)
|
||||
return end_user
|
||||
|
||||
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
|
||||
tenant = db.session.scalar(tenant_stmt)
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
user_stmt = select(Account).where(Account.id == user_id)
|
||||
user = db.session.scalar(user_stmt)
|
||||
if user:
|
||||
user.current_tenant = tenant
|
||||
return user
|
||||
|
||||
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
|
||||
end_user = db.session.scalar(end_user_stmt)
|
||||
if end_user:
|
||||
return end_user
|
||||
|
||||
return None
|
||||
|
||||
def _get_workflow(self, app_id: str, version: str) -> Workflow:
|
||||
"""
|
||||
get the workflow by app id and version
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
if not version:
|
||||
stmt = (
|
||||
select(Workflow)
|
||||
@ -265,22 +266,24 @@ class WorkflowTool(Tool):
|
||||
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
|
||||
workflow = session.scalar(stmt)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError("workflow not found or not published")
|
||||
if not workflow:
|
||||
raise ValueError("workflow not found or not published")
|
||||
|
||||
return workflow
|
||||
session.expunge(workflow)
|
||||
return workflow
|
||||
|
||||
def _get_app(self, app_id: str) -> App:
|
||||
"""
|
||||
get the app by app id
|
||||
"""
|
||||
stmt = select(App).where(App.id == app_id)
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
app = session.scalar(stmt)
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
return app
|
||||
session.expunge(app)
|
||||
return app
|
||||
|
||||
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
|
||||
"""
|
||||
|
||||
@ -30,6 +30,7 @@ from .variables import (
|
||||
SecretVariable,
|
||||
StringVariable,
|
||||
Variable,
|
||||
VariableBase,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -62,4 +63,5 @@ __all__ = [
|
||||
"StringSegment",
|
||||
"StringVariable",
|
||||
"Variable",
|
||||
"VariableBase",
|
||||
]
|
||||
|
||||
@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None:
|
||||
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
|
||||
# - The union must include all non-abstract subclasses of `Segment`, except:
|
||||
# - `SegmentGroup`, which is not added to the variable pool.
|
||||
# - `Variable` and its subclasses, which are handled by `VariableUnion`.
|
||||
# - `VariableBase` and its subclasses, which are handled by `Variable`.
|
||||
SegmentUnion: TypeAlias = Annotated[
|
||||
(
|
||||
Annotated[NoneSegment, Tag(SegmentType.NONE)]
|
||||
|
||||
@ -27,7 +27,7 @@ from .segments import (
|
||||
from .types import SegmentType
|
||||
|
||||
|
||||
class Variable(Segment):
|
||||
class VariableBase(Segment):
|
||||
"""
|
||||
A variable is a segment that has a name.
|
||||
|
||||
@ -45,23 +45,23 @@ class Variable(Segment):
|
||||
selector: Sequence[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StringVariable(StringSegment, Variable):
|
||||
class StringVariable(StringSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
class FloatVariable(FloatSegment, Variable):
|
||||
class FloatVariable(FloatSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
class IntegerVariable(IntegerSegment, Variable):
|
||||
class IntegerVariable(IntegerSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
class ObjectVariable(ObjectSegment, Variable):
|
||||
class ObjectVariable(ObjectSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayVariable(ArraySegment, Variable):
|
||||
class ArrayVariable(ArraySegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
@ -89,16 +89,16 @@ class SecretVariable(StringVariable):
|
||||
return encrypter.obfuscated_token(self.value)
|
||||
|
||||
|
||||
class NoneVariable(NoneSegment, Variable):
|
||||
class NoneVariable(NoneSegment, VariableBase):
|
||||
value_type: SegmentType = SegmentType.NONE
|
||||
value: None = None
|
||||
|
||||
|
||||
class FileVariable(FileSegment, Variable):
|
||||
class FileVariable(FileSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
class BooleanVariable(BooleanSegment, Variable):
|
||||
class BooleanVariable(BooleanSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel):
|
||||
value: Any
|
||||
|
||||
|
||||
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
|
||||
# Use `Variable` for type hinting when serialization is not required.
|
||||
# The `Variable` type is used to enable serialization and deserialization with Pydantic.
|
||||
# Use `VariableBase` for type hinting when serialization is not required.
|
||||
#
|
||||
# Note:
|
||||
# - All variants in `VariableUnion` must inherit from the `Variable` class.
|
||||
# - The union must include all non-abstract subclasses of `Segment`, except:
|
||||
VariableUnion: TypeAlias = Annotated[
|
||||
# - All variants in `Variable` must inherit from the `VariableBase` class.
|
||||
# - The union must include all non-abstract subclasses of `VariableBase`.
|
||||
Variable: TypeAlias = Annotated[
|
||||
(
|
||||
Annotated[NoneVariable, Tag(SegmentType.NONE)]
|
||||
| Annotated[StringVariable, Tag(SegmentType.STRING)]
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import abc
|
||||
from typing import Protocol
|
||||
|
||||
from core.variables import Variable
|
||||
from core.variables import VariableBase
|
||||
|
||||
|
||||
class ConversationVariableUpdater(Protocol):
|
||||
@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, conversation_id: str, variable: "Variable"):
|
||||
def update(self, conversation_id: str, variable: "VariableBase"):
|
||||
"""
|
||||
Updates the value of the specified conversation variable in the underlying storage.
|
||||
|
||||
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
|
||||
:param variable: The `Variable` instance containing the updated value.
|
||||
:param variable: The `VariableBase` instance containing the updated value.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@ -1,16 +1,11 @@
|
||||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .tool_entities import ToolCall, ToolCallResult, ToolResult, ToolResultStatus
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"ToolCall",
|
||||
"ToolCallResult",
|
||||
"ToolResult",
|
||||
"ToolResultStatus",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
]
|
||||
|
||||
@ -1,39 +0,0 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.file import File
|
||||
|
||||
|
||||
class ToolResultStatus(StrEnum):
|
||||
SUCCESS = "success"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
id: str | None = Field(default=None, description="Unique identifier for this tool call")
|
||||
name: str | None = Field(default=None, description="Name of the tool being called")
|
||||
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
|
||||
icon: str | dict | None = Field(default=None, description="Icon of the tool")
|
||||
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to")
|
||||
name: str | None = Field(default=None, description="Name of the tool")
|
||||
output: str | None = Field(default=None, description="Tool output text, error or success message")
|
||||
files: list[str] = Field(default_factory=list, description="File produced by tool")
|
||||
status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
|
||||
elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
|
||||
icon: str | dict | None = Field(default=None, description="Icon of the tool")
|
||||
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
|
||||
|
||||
|
||||
class ToolCallResult(BaseModel):
|
||||
id: str | None = Field(default=None, description="Identifier for the tool call")
|
||||
name: str | None = Field(default=None, description="Name of the tool")
|
||||
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
|
||||
output: str | None = Field(default=None, description="Tool output text, error or success message")
|
||||
files: list[File] = Field(default_factory=list, description="File produced by tool")
|
||||
status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
|
||||
elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
|
||||
@ -251,8 +251,6 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
LLM_CONTENT_SEQUENCE = "llm_content_sequence"
|
||||
LLM_TRACE = "llm_trace"
|
||||
COMPLETED_REASON = "completed_reason" # completed reason for loop node
|
||||
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.variables.variables import Variable
|
||||
|
||||
|
||||
class CommandType(StrEnum):
|
||||
@ -46,7 +46,7 @@ class PauseCommand(GraphEngineCommand):
|
||||
class VariableUpdate(BaseModel):
|
||||
"""Represents a single variable update instruction."""
|
||||
|
||||
value: VariableUnion = Field(description="New variable value")
|
||||
value: Variable = Field(description="New variable value")
|
||||
|
||||
|
||||
class UpdateVariablesCommand(GraphEngineCommand):
|
||||
|
||||
@ -16,13 +16,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
ChunkType,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ToolCall,
|
||||
ToolResult,
|
||||
)
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
|
||||
from core.workflow.nodes.base.template import TextSegment, VariableSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
@ -327,24 +321,11 @@ class ResponseStreamCoordinator:
|
||||
selector: Sequence[str],
|
||||
chunk: str,
|
||||
is_final: bool = False,
|
||||
chunk_type: ChunkType = ChunkType.TEXT,
|
||||
tool_call: ToolCall | None = None,
|
||||
tool_result: ToolResult | None = None,
|
||||
) -> NodeRunStreamChunkEvent:
|
||||
"""Create a stream chunk event with consistent structure.
|
||||
|
||||
For selectors with special prefixes (sys, env, conversation), we use the
|
||||
active response node's information since these are not actual node IDs.
|
||||
|
||||
Args:
|
||||
node_id: The node ID to attribute the event to
|
||||
execution_id: The execution ID for this node
|
||||
selector: The variable selector
|
||||
chunk: The chunk content
|
||||
is_final: Whether this is the final chunk
|
||||
chunk_type: The semantic type of the chunk being streamed
|
||||
tool_call: Structured data for tool_call chunks
|
||||
tool_result: Structured data for tool_result chunks
|
||||
"""
|
||||
# Check if this is a special selector that doesn't correspond to a node
|
||||
if selector and selector[0] not in self._graph.nodes and self._active_session:
|
||||
@ -357,9 +338,6 @@ class ResponseStreamCoordinator:
|
||||
selector=selector,
|
||||
chunk=chunk,
|
||||
is_final=is_final,
|
||||
chunk_type=chunk_type,
|
||||
tool_call=tool_call,
|
||||
tool_result=tool_result,
|
||||
)
|
||||
|
||||
# Standard case: selector refers to an actual node
|
||||
@ -371,9 +349,6 @@ class ResponseStreamCoordinator:
|
||||
selector=selector,
|
||||
chunk=chunk,
|
||||
is_final=is_final,
|
||||
chunk_type=chunk_type,
|
||||
tool_call=tool_call,
|
||||
tool_result=tool_result,
|
||||
)
|
||||
|
||||
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
|
||||
@ -381,8 +356,6 @@ class ResponseStreamCoordinator:
|
||||
|
||||
Handles both regular node selectors and special system selectors (sys, env, conversation).
|
||||
For special selectors, we attribute the output to the active response node.
|
||||
|
||||
For object-type variables, automatically streams all child fields that have stream events.
|
||||
"""
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
source_selector_prefix = segment.selector[0] if segment.selector else ""
|
||||
@ -391,81 +364,60 @@ class ResponseStreamCoordinator:
|
||||
# Determine which node to attribute the output to
|
||||
# For special selectors (sys, env, conversation), use the active response node
|
||||
# For regular selectors, use the source node
|
||||
active_session = self._active_session
|
||||
special_selector = bool(active_session and source_selector_prefix not in self._graph.nodes)
|
||||
output_node_id = active_session.node_id if special_selector and active_session else source_selector_prefix
|
||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
||||
# Special selector - use active response node
|
||||
output_node_id = self._active_session.node_id
|
||||
else:
|
||||
# Regular node selector
|
||||
output_node_id = source_selector_prefix
|
||||
execution_id = self._get_or_create_execution_id(output_node_id)
|
||||
|
||||
# Check if there's a direct stream for this selector
|
||||
has_direct_stream = (
|
||||
tuple(segment.selector) in self._stream_buffers or tuple(segment.selector) in self._closed_streams
|
||||
)
|
||||
|
||||
stream_targets = [segment.selector] if has_direct_stream else sorted(self._find_child_streams(segment.selector))
|
||||
|
||||
if stream_targets:
|
||||
all_complete = True
|
||||
|
||||
for target_selector in stream_targets:
|
||||
while self._has_unread_stream(target_selector):
|
||||
if event := self._pop_stream_chunk(target_selector):
|
||||
events.append(
|
||||
self._rewrite_stream_event(
|
||||
event=event,
|
||||
output_node_id=output_node_id,
|
||||
execution_id=execution_id,
|
||||
special_selector=bool(special_selector),
|
||||
)
|
||||
)
|
||||
|
||||
if not self._is_stream_closed(target_selector):
|
||||
all_complete = False
|
||||
|
||||
is_complete = all_complete
|
||||
|
||||
# Fallback: check if scalar value exists in variable pool
|
||||
if not is_complete and not has_direct_stream:
|
||||
if value := self._variable_pool.get(segment.selector):
|
||||
# Process scalar value
|
||||
is_last_segment = bool(
|
||||
self._active_session
|
||||
and self._active_session.index == len(self._active_session.template.segments) - 1
|
||||
)
|
||||
events.append(
|
||||
self._create_stream_chunk_event(
|
||||
node_id=output_node_id,
|
||||
execution_id=execution_id,
|
||||
selector=segment.selector,
|
||||
chunk=value.markdown,
|
||||
is_final=is_last_segment,
|
||||
# Stream all available chunks
|
||||
while self._has_unread_stream(segment.selector):
|
||||
if event := self._pop_stream_chunk(segment.selector):
|
||||
# For special selectors, we need to update the event to use
|
||||
# the active response node's information
|
||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
||||
response_node = self._graph.nodes[self._active_session.node_id]
|
||||
# Create a new event with the response node's information
|
||||
# but keep the original selector
|
||||
updated_event = NodeRunStreamChunkEvent(
|
||||
id=execution_id,
|
||||
node_id=response_node.id,
|
||||
node_type=response_node.node_type,
|
||||
selector=event.selector, # Keep original selector
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
)
|
||||
events.append(updated_event)
|
||||
else:
|
||||
# Regular node selector - use event as is
|
||||
events.append(event)
|
||||
|
||||
# Check if this is the last chunk by looking ahead
|
||||
stream_closed = self._is_stream_closed(segment.selector)
|
||||
# Check if stream is closed to determine if segment is complete
|
||||
if stream_closed:
|
||||
is_complete = True
|
||||
|
||||
elif value := self._variable_pool.get(segment.selector):
|
||||
# Process scalar value
|
||||
is_last_segment = bool(
|
||||
self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
|
||||
)
|
||||
events.append(
|
||||
self._create_stream_chunk_event(
|
||||
node_id=output_node_id,
|
||||
execution_id=execution_id,
|
||||
selector=segment.selector,
|
||||
chunk=value.markdown,
|
||||
is_final=is_last_segment,
|
||||
)
|
||||
is_complete = True
|
||||
)
|
||||
is_complete = True
|
||||
|
||||
return events, is_complete
|
||||
|
||||
def _rewrite_stream_event(
|
||||
self,
|
||||
event: NodeRunStreamChunkEvent,
|
||||
output_node_id: str,
|
||||
execution_id: str,
|
||||
special_selector: bool,
|
||||
) -> NodeRunStreamChunkEvent:
|
||||
"""Rewrite event to attribute to active response node when selector is special."""
|
||||
if not special_selector:
|
||||
return event
|
||||
|
||||
return self._create_stream_chunk_event(
|
||||
node_id=output_node_id,
|
||||
execution_id=execution_id,
|
||||
selector=event.selector,
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
chunk_type=event.chunk_type,
|
||||
tool_call=event.tool_call,
|
||||
tool_result=event.tool_result,
|
||||
)
|
||||
|
||||
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
"""Process a text segment. Returns (events, is_complete)."""
|
||||
assert self._active_session is not None
|
||||
@ -561,36 +513,6 @@ class ResponseStreamCoordinator:
|
||||
|
||||
# ============= Internal Stream Management Methods =============
|
||||
|
||||
def _find_child_streams(self, parent_selector: Sequence[str]) -> list[tuple[str, ...]]:
|
||||
"""Find all child stream selectors that are descendants of the parent selector.
|
||||
|
||||
For example, if parent_selector is ['llm', 'generation'], this will find:
|
||||
- ['llm', 'generation', 'content']
|
||||
- ['llm', 'generation', 'tool_calls']
|
||||
- ['llm', 'generation', 'tool_results']
|
||||
- ['llm', 'generation', 'thought']
|
||||
|
||||
Args:
|
||||
parent_selector: The parent selector to search for children
|
||||
|
||||
Returns:
|
||||
List of child selector tuples found in stream buffers or closed streams
|
||||
"""
|
||||
parent_key = tuple(parent_selector)
|
||||
parent_len = len(parent_key)
|
||||
child_streams: set[tuple[str, ...]] = set()
|
||||
|
||||
# Search in both active buffers and closed streams
|
||||
all_selectors = set(self._stream_buffers.keys()) | self._closed_streams
|
||||
|
||||
for selector_key in all_selectors:
|
||||
# Check if this selector is a direct child of the parent
|
||||
# Direct child means: len(child) == len(parent) + 1 and child starts with parent
|
||||
if len(selector_key) == parent_len + 1 and selector_key[:parent_len] == parent_key:
|
||||
child_streams.add(selector_key)
|
||||
|
||||
return sorted(child_streams)
|
||||
|
||||
def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
Append a stream chunk to the internal buffer.
|
||||
|
||||
@ -36,7 +36,6 @@ from .loop import (
|
||||
|
||||
# Node events
|
||||
from .node import (
|
||||
ChunkType,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
@ -45,13 +44,10 @@ from .node import (
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ToolCall,
|
||||
ToolResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseGraphEvent",
|
||||
"ChunkType",
|
||||
"GraphEngineEvent",
|
||||
"GraphNodeEventBase",
|
||||
"GraphRunAbortedEvent",
|
||||
@ -77,6 +73,4 @@ __all__ = [
|
||||
"NodeRunStartedEvent",
|
||||
"NodeRunStreamChunkEvent",
|
||||
"NodeRunSucceededEvent",
|
||||
"ToolCall",
|
||||
"ToolResult",
|
||||
]
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
|
||||
from .base import GraphNodeEventBase
|
||||
@ -22,39 +21,13 @@ class NodeRunStartedEvent(GraphNodeEventBase):
|
||||
provider_id: str = ""
|
||||
|
||||
|
||||
class ChunkType(StrEnum):
|
||||
"""Stream chunk type for LLM-related events."""
|
||||
|
||||
TEXT = "text" # Normal text streaming
|
||||
TOOL_CALL = "tool_call" # Tool call arguments streaming
|
||||
TOOL_RESULT = "tool_result" # Tool execution result
|
||||
THOUGHT = "thought" # Agent thinking process (ReAct)
|
||||
THOUGHT_START = "thought_start" # Agent thought start
|
||||
THOUGHT_END = "thought_end" # Agent thought end
|
||||
|
||||
|
||||
class NodeRunStreamChunkEvent(GraphNodeEventBase):
|
||||
"""Stream chunk event for workflow node execution."""
|
||||
|
||||
# Base fields
|
||||
# Spec-compliant fields
|
||||
selector: Sequence[str] = Field(
|
||||
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
|
||||
)
|
||||
chunk: str = Field(..., description="the actual chunk content")
|
||||
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
|
||||
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
|
||||
|
||||
# Tool call fields (when chunk_type == TOOL_CALL)
|
||||
tool_call: ToolCall | None = Field(
|
||||
default=None,
|
||||
description="structured payload for tool_call chunks",
|
||||
)
|
||||
|
||||
# Tool result fields (when chunk_type == TOOL_RESULT)
|
||||
tool_result: ToolResult | None = Field(
|
||||
default=None,
|
||||
description="structured payload for tool_result chunks",
|
||||
)
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
|
||||
|
||||
@ -13,21 +13,16 @@ from .loop import (
|
||||
LoopSucceededEvent,
|
||||
)
|
||||
from .node import (
|
||||
ChunkType,
|
||||
ModelInvokeCompletedEvent,
|
||||
PauseRequestedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
RunRetryEvent,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
ThoughtChunkEvent,
|
||||
ToolCallChunkEvent,
|
||||
ToolResultChunkEvent,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentLogEvent",
|
||||
"ChunkType",
|
||||
"IterationFailedEvent",
|
||||
"IterationNextEvent",
|
||||
"IterationStartedEvent",
|
||||
@ -44,7 +39,4 @@ __all__ = [
|
||||
"RunRetryEvent",
|
||||
"StreamChunkEvent",
|
||||
"StreamCompletedEvent",
|
||||
"ThoughtChunkEvent",
|
||||
"ToolCallChunkEvent",
|
||||
"ToolResultChunkEvent",
|
||||
]
|
||||
|
||||
@ -1,13 +1,11 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.file import File
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import ToolCall, ToolResult
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
@ -34,60 +32,13 @@ class RunRetryEvent(NodeEventBase):
|
||||
start_at: datetime = Field(..., description="Retry start time")
|
||||
|
||||
|
||||
class ChunkType(StrEnum):
|
||||
"""Stream chunk type for LLM-related events."""
|
||||
|
||||
TEXT = "text" # Normal text streaming
|
||||
TOOL_CALL = "tool_call" # Tool call arguments streaming
|
||||
TOOL_RESULT = "tool_result" # Tool execution result
|
||||
THOUGHT = "thought" # Agent thinking process (ReAct)
|
||||
THOUGHT_START = "thought_start" # Agent thought start
|
||||
THOUGHT_END = "thought_end" # Agent thought end
|
||||
|
||||
|
||||
class StreamChunkEvent(NodeEventBase):
|
||||
"""Base stream chunk event - normal text streaming output."""
|
||||
|
||||
# Spec-compliant fields
|
||||
selector: Sequence[str] = Field(
|
||||
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
|
||||
)
|
||||
chunk: str = Field(..., description="the actual chunk content")
|
||||
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
|
||||
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
|
||||
tool_call: ToolCall | None = Field(default=None, description="structured payload for tool_call chunks")
|
||||
tool_result: ToolResult | None = Field(default=None, description="structured payload for tool_result chunks")
|
||||
|
||||
|
||||
class ToolCallChunkEvent(StreamChunkEvent):
|
||||
"""Tool call streaming event - tool call arguments streaming output."""
|
||||
|
||||
chunk_type: ChunkType = Field(default=ChunkType.TOOL_CALL, frozen=True)
|
||||
tool_call: ToolCall | None = Field(default=None, description="structured tool call payload")
|
||||
|
||||
|
||||
class ToolResultChunkEvent(StreamChunkEvent):
|
||||
"""Tool result event - tool execution result."""
|
||||
|
||||
chunk_type: ChunkType = Field(default=ChunkType.TOOL_RESULT, frozen=True)
|
||||
tool_result: ToolResult | None = Field(default=None, description="structured tool result payload")
|
||||
|
||||
|
||||
class ThoughtStartChunkEvent(StreamChunkEvent):
|
||||
"""Agent thought start streaming event - Agent thinking process (ReAct)."""
|
||||
|
||||
chunk_type: ChunkType = Field(default=ChunkType.THOUGHT_START, frozen=True)
|
||||
|
||||
|
||||
class ThoughtEndChunkEvent(StreamChunkEvent):
|
||||
"""Agent thought end streaming event - Agent thinking process (ReAct)."""
|
||||
|
||||
chunk_type: ChunkType = Field(default=ChunkType.THOUGHT_END, frozen=True)
|
||||
|
||||
|
||||
class ThoughtChunkEvent(StreamChunkEvent):
|
||||
"""Agent thought streaming event - Agent thinking process (ReAct)."""
|
||||
|
||||
chunk_type: ChunkType = Field(default=ChunkType.THOUGHT, frozen=True)
|
||||
|
||||
|
||||
class StreamCompletedEvent(NodeEventBase):
|
||||
|
||||
@ -23,10 +23,22 @@ class RetryConfig(BaseModel):
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
# First token timeout for LLM nodes (milliseconds), 0 means no timeout
|
||||
first_token_timeout: int = 0
|
||||
|
||||
@property
|
||||
def first_token_timeout_seconds(self) -> float:
|
||||
return self.first_token_timeout / 1000
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
|
||||
@property
|
||||
def has_first_token_timeout(self) -> bool:
|
||||
"""Check if first token timeout should be applied (retry enabled and timeout > 0)."""
|
||||
return self.retry_enabled and self.first_token_timeout > 0
|
||||
|
||||
|
||||
class VariableSelector(BaseModel):
|
||||
"""
|
||||
|
||||
@ -48,9 +48,6 @@ from core.workflow.node_events import (
|
||||
RunRetrieverResourceEvent,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
ThoughtChunkEvent,
|
||||
ToolCallChunkEvent,
|
||||
ToolResultChunkEvent,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -567,8 +564,6 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
from core.workflow.graph_events import ChunkType
|
||||
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
@ -576,60 +571,6 @@ class Node(Generic[NodeDataT]):
|
||||
selector=event.selector,
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
chunk_type=ChunkType(event.chunk_type.value),
|
||||
tool_call=event.tool_call,
|
||||
tool_result=event.tool_result,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: ToolCallChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
from core.workflow.graph_events import ChunkType
|
||||
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
selector=event.selector,
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
chunk_type=ChunkType.TOOL_CALL,
|
||||
tool_call=event.tool_call,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: ToolResultChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
from core.workflow.entities import ToolResult, ToolResultStatus
|
||||
from core.workflow.graph_events import ChunkType
|
||||
|
||||
tool_result = event.tool_result or ToolResult()
|
||||
status: ToolResultStatus = tool_result.status or ToolResultStatus.SUCCESS
|
||||
tool_result = tool_result.model_copy(
|
||||
update={"status": status, "files": tool_result.files or []},
|
||||
)
|
||||
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
selector=event.selector,
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
chunk_type=ChunkType.TOOL_RESULT,
|
||||
tool_result=tool_result,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: ThoughtChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
from core.workflow.graph_events import ChunkType
|
||||
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
selector=event.selector,
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
chunk_type=ChunkType.THOUGHT,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
|
||||
@ -62,21 +62,6 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
inputs = {"variable_selector": variable_selector}
|
||||
process_data = {"documents": value if isinstance(value, list) else [value]}
|
||||
|
||||
# Ensure storage_key is loaded for File objects
|
||||
files_to_check = value if isinstance(value, list) else [value]
|
||||
files_needing_storage_key = [
|
||||
f for f in files_to_check
|
||||
if isinstance(f, File) and not f.storage_key and f.related_id
|
||||
]
|
||||
if files_needing_storage_key:
|
||||
from factories.file_factory import StorageKeyLoader
|
||||
from extensions.ext_database import db
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
with Session(bind=db.engine) as session:
|
||||
storage_key_loader = StorageKeyLoader(session, tenant_id=self.tenant_id)
|
||||
storage_key_loader.load_storage_keys(files_needing_storage_key)
|
||||
|
||||
try:
|
||||
if isinstance(value, list):
|
||||
extracted_text_list = list(map(_extract_text_from_file, value))
|
||||
@ -430,15 +415,6 @@ def _download_file_content(file: File) -> bytes:
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
else:
|
||||
# Check if storage_key is set
|
||||
if not file.storage_key:
|
||||
raise FileDownloadError(f"File storage_key is missing for file: {file.filename}")
|
||||
|
||||
# Check if file exists before downloading
|
||||
from extensions.ext_storage import storage
|
||||
if not storage.exists(file.storage_key):
|
||||
raise FileDownloadError(f"File not found in storage: {file.storage_key}")
|
||||
|
||||
return file_manager.download(file)
|
||||
except Exception as e:
|
||||
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
|
||||
|
||||
@ -17,6 +17,7 @@ from core.helper import ssrf_proxy
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
from ..protocols import FileManagerProtocol, HttpClientProtocol
|
||||
from .entities import (
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeData,
|
||||
@ -78,6 +79,8 @@ class Executor:
|
||||
timeout: HttpRequestNodeTimeout,
|
||||
variable_pool: VariablePool,
|
||||
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
http_client: HttpClientProtocol = ssrf_proxy,
|
||||
file_manager: FileManagerProtocol = file_manager,
|
||||
):
|
||||
# If authorization API key is present, convert the API key using the variable pool
|
||||
if node_data.authorization.type == "api-key":
|
||||
@ -104,6 +107,8 @@ class Executor:
|
||||
self.data = None
|
||||
self.json = None
|
||||
self.max_retries = max_retries
|
||||
self._http_client = http_client
|
||||
self._file_manager = file_manager
|
||||
|
||||
# init template
|
||||
self.variable_pool = variable_pool
|
||||
@ -200,7 +205,7 @@ class Executor:
|
||||
if file_variable is None:
|
||||
raise FileFetchError(f"cannot fetch file with selector {file_selector}")
|
||||
file = file_variable.value
|
||||
self.content = file_manager.download(file)
|
||||
self.content = self._file_manager.download(file)
|
||||
case "x-www-form-urlencoded":
|
||||
form_data = {
|
||||
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
|
||||
@ -239,7 +244,7 @@ class Executor:
|
||||
):
|
||||
file_tuple = (
|
||||
file.filename,
|
||||
file_manager.download(file),
|
||||
self._file_manager.download(file),
|
||||
file.mime_type or "application/octet-stream",
|
||||
)
|
||||
if key not in files:
|
||||
@ -332,19 +337,18 @@ class Executor:
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
_METHOD_MAP = {
|
||||
"get": ssrf_proxy.get,
|
||||
"head": ssrf_proxy.head,
|
||||
"post": ssrf_proxy.post,
|
||||
"put": ssrf_proxy.put,
|
||||
"delete": ssrf_proxy.delete,
|
||||
"patch": ssrf_proxy.patch,
|
||||
"get": self._http_client.get,
|
||||
"head": self._http_client.head,
|
||||
"post": self._http_client.post,
|
||||
"put": self._http_client.put,
|
||||
"delete": self._http_client.delete,
|
||||
"patch": self._http_client.patch,
|
||||
}
|
||||
method_lc = self.method.lower()
|
||||
if method_lc not in _METHOD_MAP:
|
||||
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
|
||||
|
||||
request_args = {
|
||||
"url": self.url,
|
||||
"data": self.data,
|
||||
"files": self.files,
|
||||
"json": self.json,
|
||||
@ -357,8 +361,12 @@ class Executor:
|
||||
}
|
||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||
try:
|
||||
response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries)
|
||||
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
||||
response: httpx.Response = _METHOD_MAP[method_lc](
|
||||
url=self.url,
|
||||
**request_args,
|
||||
max_retries=self.max_retries,
|
||||
)
|
||||
except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
|
||||
raise HttpRequestNodeError(str(e)) from e
|
||||
# FIXME: fix type ignore, this maybe httpx type issue
|
||||
return response
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.file import File, FileTransferMethod, file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
@ -13,6 +14,7 @@ from core.workflow.nodes.base import variable_template_parser
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
|
||||
from factories import file_factory
|
||||
|
||||
from .entities import (
|
||||
@ -30,10 +32,35 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
node_type = NodeType.HTTP_REQUEST
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
http_client: HttpClientProtocol = ssrf_proxy,
|
||||
tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
|
||||
file_manager: FileManagerProtocol = file_manager,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._http_client = http_client
|
||||
self._tool_file_manager_factory = tool_file_manager_factory
|
||||
self._file_manager = file_manager
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
@ -71,6 +98,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
timeout=self._get_request_timeout(self.node_data),
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
max_retries=0,
|
||||
http_client=self._http_client,
|
||||
file_manager=self._file_manager,
|
||||
)
|
||||
process_data["request"] = http_executor.to_log()
|
||||
|
||||
@ -199,7 +228,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
mime_type = (
|
||||
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
)
|
||||
tool_file_manager = ToolFileManager()
|
||||
tool_file_manager = self._tool_file_manager_factory()
|
||||
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=self.user_id,
|
||||
|
||||
@ -11,7 +11,7 @@ from typing_extensions import TypeIs
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables import IntegerVariable, NoneSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.variables.variables import Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import (
|
||||
NodeExecutionType,
|
||||
@ -240,7 +240,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
datetime,
|
||||
list[GraphNodeEventBase],
|
||||
object | None,
|
||||
dict[str, VariableUnion],
|
||||
dict[str, Variable],
|
||||
LLMUsage,
|
||||
]
|
||||
],
|
||||
@ -308,7 +308,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
item: object,
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
@ -515,11 +515,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
|
||||
return variable_mapping
|
||||
|
||||
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
|
||||
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]:
|
||||
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
|
||||
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
|
||||
|
||||
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
|
||||
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None:
|
||||
parent_pool = self.graph_runtime_state.variable_pool
|
||||
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
|
||||
|
||||
|
||||
@ -158,5 +158,3 @@ class KnowledgeIndexNodeData(BaseNodeData):
|
||||
type: str = "knowledge-index"
|
||||
chunk_structure: str
|
||||
index_chunk_variable_selector: list[str]
|
||||
indexing_technique: str | None = None
|
||||
summary_index_setting: dict | None = None
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
import concurrent.futures
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -18,9 +16,7 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
from .entities import KnowledgeIndexNodeData
|
||||
from .exc import (
|
||||
@ -71,18 +67,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
# index knowledge
|
||||
try:
|
||||
if is_preview:
|
||||
# Preview mode: generate summaries for chunks directly without saving to database
|
||||
# Format preview and generate summaries on-the-fly
|
||||
# Get indexing_technique and summary_index_setting from node_data (workflow graph config)
|
||||
# or fallback to dataset if not available in node_data
|
||||
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
|
||||
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
|
||||
|
||||
outputs = self._get_preview_output_with_summaries(
|
||||
node_data.chunk_structure, chunks, dataset=dataset,
|
||||
indexing_technique=indexing_technique,
|
||||
summary_index_setting=summary_index_setting
|
||||
)
|
||||
outputs = self._get_preview_output(node_data.chunk_structure, chunks)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
@ -178,9 +163,6 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# Generate summary index if enabled
|
||||
self._handle_summary_index_generation(dataset, document, variable_pool)
|
||||
|
||||
return {
|
||||
"dataset_id": ds_id_value,
|
||||
"dataset_name": dataset_name_value,
|
||||
@ -191,269 +173,9 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
"display_status": "completed",
|
||||
}
|
||||
|
||||
def _handle_summary_index_generation(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
document: Document,
|
||||
variable_pool: VariablePool,
|
||||
) -> None:
|
||||
"""
|
||||
Handle summary index generation based on mode (debug/preview or production).
|
||||
|
||||
Args:
|
||||
dataset: Dataset containing the document
|
||||
document: Document to generate summaries for
|
||||
variable_pool: Variable pool to check invoke_from
|
||||
"""
|
||||
# Only generate summary index for high_quality indexing technique
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
return
|
||||
|
||||
# Check if summary index is enabled
|
||||
summary_index_setting = dataset.summary_index_setting
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
return
|
||||
|
||||
# Skip qa_model documents
|
||||
if document.doc_form == "qa_model":
|
||||
return
|
||||
|
||||
# Determine if in preview/debug mode
|
||||
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
|
||||
is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER
|
||||
|
||||
# Determine if only parent chunks should be processed
|
||||
only_parent_chunks = dataset.chunk_structure == "parent_child_index"
|
||||
|
||||
if is_preview:
|
||||
try:
|
||||
# Query segments that need summary generation
|
||||
query = db.session.query(DocumentSegment).filter_by(
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
status="completed",
|
||||
enabled=True,
|
||||
)
|
||||
segments = query.all()
|
||||
|
||||
if not segments:
|
||||
logger.info(f"No segments found for document {document.id}")
|
||||
return
|
||||
|
||||
# Filter segments based on mode
|
||||
segments_to_process = []
|
||||
for segment in segments:
|
||||
# Skip if summary already exists
|
||||
existing_summary = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed")
|
||||
.first()
|
||||
)
|
||||
if existing_summary:
|
||||
continue
|
||||
|
||||
# For parent-child mode, all segments are parent chunks, so process all
|
||||
segments_to_process.append(segment)
|
||||
|
||||
if not segments_to_process:
|
||||
logger.info(f"No segments need summary generation for document {document.id}")
|
||||
return
|
||||
|
||||
# Use ThreadPoolExecutor for concurrent generation
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
max_workers = min(10, len(segments_to_process)) # Limit to 10 workers
|
||||
|
||||
def process_segment(segment: DocumentSegment) -> None:
|
||||
"""Process a single segment in a thread with Flask app context."""
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
SummaryIndexService.generate_and_vectorize_summary(
|
||||
segment, dataset, summary_index_setting
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate summary for segment {segment.id}: {str(e)}")
|
||||
# Continue processing other segments
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = [
|
||||
executor.submit(process_segment, segment) for segment in segments_to_process
|
||||
]
|
||||
# Wait for all tasks to complete
|
||||
concurrent.futures.wait(futures, timeout=300)
|
||||
|
||||
logger.info(
|
||||
f"Successfully generated summary index for {len(segments_to_process)} segments "
|
||||
f"in document {document.id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to generate summary index for document {document.id}: {str(e)}")
|
||||
# Don't fail the entire indexing process if summary generation fails
|
||||
else:
|
||||
# Production mode: asynchronous generation
|
||||
logger.info(f"Queuing summary index generation task for document {document.id} (production mode)")
|
||||
try:
|
||||
generate_summary_index_task.delay(dataset.id, document.id, None)
|
||||
logger.info(f"Summary index generation task queued for document {document.id}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to queue summary index generation task for document {document.id}: {str(e)}")
|
||||
# Don't fail the entire indexing process if task queuing fails
|
||||
|
||||
def _get_preview_output_with_summaries(
|
||||
self, chunk_structure: str, chunks: Any, dataset: Dataset,
|
||||
indexing_technique: str | None = None,
|
||||
summary_index_setting: dict | None = None
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Generate preview output with summaries for chunks in preview mode.
|
||||
This method generates summaries on-the-fly without saving to database.
|
||||
|
||||
Args:
|
||||
chunk_structure: Chunk structure type
|
||||
chunks: Chunks to generate preview for
|
||||
dataset: Dataset object (for tenant_id)
|
||||
indexing_technique: Indexing technique from node config or dataset
|
||||
summary_index_setting: Summary index setting from node config or dataset
|
||||
"""
|
||||
def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
preview_output = index_processor.format_preview(chunks)
|
||||
|
||||
# Check if summary index is enabled
|
||||
if indexing_technique != "high_quality":
|
||||
return preview_output
|
||||
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
return preview_output
|
||||
|
||||
# Generate summaries for chunks
|
||||
if "preview" in preview_output and isinstance(preview_output["preview"], list):
|
||||
chunk_count = len(preview_output["preview"])
|
||||
logger.info(
|
||||
f"Generating summaries for {chunk_count} chunks in preview mode "
|
||||
f"(dataset: {dataset.id})"
|
||||
)
|
||||
# Use ParagraphIndexProcessor's generate_summary method
|
||||
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
|
||||
# Get Flask app for application context in worker threads
|
||||
flask_app = None
|
||||
try:
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
except RuntimeError:
|
||||
logger.warning("No Flask application context available, summary generation may fail")
|
||||
|
||||
def generate_summary_for_chunk(preview_item: dict) -> None:
|
||||
"""Generate summary for a single chunk."""
|
||||
if "content" in preview_item:
|
||||
try:
|
||||
# Set Flask application context in worker thread
|
||||
if flask_app:
|
||||
with flask_app.app_context():
|
||||
summary = ParagraphIndexProcessor.generate_summary(
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=preview_item["content"],
|
||||
summary_index_setting=summary_index_setting,
|
||||
)
|
||||
if summary:
|
||||
preview_item["summary"] = summary
|
||||
else:
|
||||
# Fallback: try without app context (may fail)
|
||||
summary = ParagraphIndexProcessor.generate_summary(
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=preview_item["content"],
|
||||
summary_index_setting=summary_index_setting,
|
||||
)
|
||||
if summary:
|
||||
preview_item["summary"] = summary
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate summary for chunk: {str(e)}")
|
||||
# Don't fail the entire preview if summary generation fails
|
||||
|
||||
# Generate summaries concurrently using ThreadPoolExecutor
|
||||
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
|
||||
timeout_seconds = min(300, 60 * len(preview_output["preview"]))
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor:
|
||||
futures = [
|
||||
executor.submit(generate_summary_for_chunk, preview_item)
|
||||
for preview_item in preview_output["preview"]
|
||||
]
|
||||
# Wait for all tasks to complete with timeout
|
||||
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
|
||||
|
||||
# Cancel tasks that didn't complete in time
|
||||
if not_done:
|
||||
logger.warning(
|
||||
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s. "
|
||||
"Cancelling remaining tasks..."
|
||||
)
|
||||
for future in not_done:
|
||||
future.cancel()
|
||||
# Wait a bit for cancellation to take effect
|
||||
concurrent.futures.wait(not_done, timeout=5)
|
||||
|
||||
completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None)
|
||||
logger.info(
|
||||
f"Completed summary generation for preview chunks: {completed_count}/{len(preview_output['preview'])} succeeded"
|
||||
)
|
||||
|
||||
return preview_output
|
||||
|
||||
def _get_preview_output(
|
||||
self, chunk_structure: str, chunks: Any, dataset: Dataset | None = None, variable_pool: VariablePool | None = None
|
||||
) -> Mapping[str, Any]:
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
preview_output = index_processor.format_preview(chunks)
|
||||
|
||||
# If dataset is provided, try to enrich preview with summaries
|
||||
if dataset and variable_pool:
|
||||
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||
if document_id:
|
||||
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||
if document:
|
||||
# Query summaries for this document
|
||||
summaries = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.filter_by(
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
status="completed",
|
||||
enabled=True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if summaries:
|
||||
# Create a map of segment content to summary for matching
|
||||
# Use content matching as chunks in preview might not be indexed yet
|
||||
summary_by_content = {}
|
||||
for summary in summaries:
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter_by(id=summary.chunk_id, dataset_id=dataset.id)
|
||||
.first()
|
||||
)
|
||||
if segment:
|
||||
# Normalize content for matching (strip whitespace)
|
||||
normalized_content = segment.content.strip()
|
||||
summary_by_content[normalized_content] = summary.summary_content
|
||||
|
||||
# Enrich preview with summaries by content matching
|
||||
if "preview" in preview_output and isinstance(preview_output["preview"], list):
|
||||
matched_count = 0
|
||||
for preview_item in preview_output["preview"]:
|
||||
if "content" in preview_item:
|
||||
# Normalize content for matching
|
||||
normalized_chunk_content = preview_item["content"].strip()
|
||||
if normalized_chunk_content in summary_by_content:
|
||||
preview_item["summary"] = summary_by_content[normalized_chunk_content]
|
||||
matched_count += 1
|
||||
|
||||
if matched_count > 0:
|
||||
logger.info(
|
||||
f"Enriched preview with {matched_count} existing summaries "
|
||||
f"(dataset: {dataset.id}, document: {document.id})"
|
||||
)
|
||||
|
||||
return preview_output
|
||||
return index_processor.format_preview(chunks)
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
|
||||
@ -3,7 +3,6 @@ from .entities import (
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
ToolMetadata,
|
||||
VisionConfig,
|
||||
)
|
||||
from .node import LLMNode
|
||||
@ -14,6 +13,5 @@ __all__ = [
|
||||
"LLMNodeCompletionModelPromptTemplate",
|
||||
"LLMNodeData",
|
||||
"ModelConfig",
|
||||
"ToolMetadata",
|
||||
"VisionConfig",
|
||||
]
|
||||
|
||||
@ -1,17 +1,10 @@
|
||||
import re
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.entities import ToolCall, ToolCallResult
|
||||
from core.workflow.node_events import AgentLogEvent
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
@ -65,268 +58,6 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||
jinja2_text: str | None = None
|
||||
|
||||
|
||||
class ToolMetadata(BaseModel):
|
||||
"""
|
||||
Tool metadata for LLM node with tool support.
|
||||
|
||||
Defines the essential fields needed for tool configuration,
|
||||
particularly the 'type' field to identify tool provider type.
|
||||
"""
|
||||
|
||||
# Core fields
|
||||
enabled: bool = True
|
||||
type: ToolProviderType = Field(..., description="Tool provider type: builtin, api, mcp, workflow")
|
||||
provider_name: str = Field(..., description="Tool provider name/identifier")
|
||||
tool_name: str = Field(..., description="Tool name")
|
||||
|
||||
# Optional fields
|
||||
plugin_unique_identifier: str | None = Field(None, description="Plugin unique identifier for plugin tools")
|
||||
credential_id: str | None = Field(None, description="Credential ID for tools requiring authentication")
|
||||
|
||||
# Configuration fields
|
||||
parameters: dict[str, Any] = Field(default_factory=dict, description="Tool parameters")
|
||||
settings: dict[str, Any] = Field(default_factory=dict, description="Tool settings configuration")
|
||||
extra: dict[str, Any] = Field(default_factory=dict, description="Extra tool configuration like custom description")
|
||||
|
||||
|
||||
class ModelTraceSegment(BaseModel):
|
||||
"""Model invocation trace segment with token usage and output."""
|
||||
|
||||
text: str | None = Field(None, description="Model output text content")
|
||||
reasoning: str | None = Field(None, description="Reasoning/thought content from model")
|
||||
tool_calls: list[ToolCall] = Field(default_factory=list, description="Tool calls made by the model")
|
||||
|
||||
|
||||
class ToolTraceSegment(BaseModel):
|
||||
"""Tool invocation trace segment with call details and result."""
|
||||
|
||||
id: str | None = Field(default=None, description="Unique identifier for this tool call")
|
||||
name: str | None = Field(default=None, description="Name of the tool being called")
|
||||
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
|
||||
output: str | None = Field(default=None, description="Tool call result")
|
||||
|
||||
|
||||
class LLMTraceSegment(BaseModel):
|
||||
"""
|
||||
Streaming trace segment for LLM tool-enabled runs.
|
||||
|
||||
Represents alternating model and tool invocations in sequence:
|
||||
model -> tool -> model -> tool -> ...
|
||||
|
||||
Each segment records its execution duration.
|
||||
"""
|
||||
|
||||
type: Literal["model", "tool"]
|
||||
duration: float = Field(..., description="Execution duration in seconds")
|
||||
usage: LLMUsage | None = Field(default=None, description="Token usage statistics for this model call")
|
||||
output: ModelTraceSegment | ToolTraceSegment = Field(..., description="Output of the segment")
|
||||
|
||||
# Common metadata for both model and tool segments
|
||||
provider: str | None = Field(default=None, description="Model or tool provider identifier")
|
||||
name: str | None = Field(default=None, description="Name of the model or tool")
|
||||
icon: str | None = Field(default=None, description="Icon for the provider")
|
||||
icon_dark: str | None = Field(default=None, description="Dark theme icon for the provider")
|
||||
error: str | None = Field(default=None, description="Error message if segment failed")
|
||||
status: Literal["success", "error"] | None = Field(default=None, description="Tool execution status")
|
||||
|
||||
|
||||
class LLMGenerationData(BaseModel):
|
||||
"""Generation data from LLM invocation with tools.
|
||||
|
||||
For multi-turn tool calls like: thought1 -> text1 -> tool_call1 -> thought2 -> text2 -> tool_call2
|
||||
- reasoning_contents: [thought1, thought2, ...] - one element per turn
|
||||
- tool_calls: [{id, name, arguments, result}, ...] - all tool calls with results
|
||||
"""
|
||||
|
||||
text: str = Field(..., description="Accumulated text content from all turns")
|
||||
reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn")
|
||||
tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results")
|
||||
sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering")
|
||||
usage: LLMUsage = Field(..., description="LLM usage statistics")
|
||||
finish_reason: str | None = Field(None, description="Finish reason from LLM")
|
||||
files: list[File] = Field(default_factory=list, description="Generated files")
|
||||
trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")
|
||||
|
||||
|
||||
class ThinkTagStreamParser:
|
||||
"""Lightweight state machine to split streaming chunks by <think> tags."""
|
||||
|
||||
_START_PATTERN = re.compile(r"<think(?:\s[^>]*)?>", re.IGNORECASE)
|
||||
_END_PATTERN = re.compile(r"</think>", re.IGNORECASE)
|
||||
_START_PREFIX = "<think"
|
||||
_END_PREFIX = "</think"
|
||||
|
||||
def __init__(self):
|
||||
self._buffer = ""
|
||||
self._in_think = False
|
||||
|
||||
@staticmethod
|
||||
def _suffix_prefix_len(text: str, prefix: str) -> int:
|
||||
"""Return length of the longest suffix of `text` that is a prefix of `prefix`."""
|
||||
max_len = min(len(text), len(prefix) - 1)
|
||||
for i in range(max_len, 0, -1):
|
||||
if text[-i:].lower() == prefix[:i].lower():
|
||||
return i
|
||||
return 0
|
||||
|
||||
def process(self, chunk: str) -> list[tuple[str, str]]:
|
||||
"""
|
||||
Split incoming chunk into ('thought' | 'text', content) tuples.
|
||||
Content excludes the <think> tags themselves and handles split tags across chunks.
|
||||
"""
|
||||
parts: list[tuple[str, str]] = []
|
||||
self._buffer += chunk
|
||||
|
||||
while self._buffer:
|
||||
if self._in_think:
|
||||
end_match = self._END_PATTERN.search(self._buffer)
|
||||
if end_match:
|
||||
thought_text = self._buffer[: end_match.start()]
|
||||
if thought_text:
|
||||
parts.append(("thought", thought_text))
|
||||
parts.append(("thought_end", ""))
|
||||
self._buffer = self._buffer[end_match.end() :]
|
||||
self._in_think = False
|
||||
continue
|
||||
|
||||
hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX)
|
||||
emit = self._buffer[: len(self._buffer) - hold_len]
|
||||
if emit:
|
||||
parts.append(("thought", emit))
|
||||
self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
|
||||
break
|
||||
|
||||
start_match = self._START_PATTERN.search(self._buffer)
|
||||
if start_match:
|
||||
prefix = self._buffer[: start_match.start()]
|
||||
if prefix:
|
||||
parts.append(("text", prefix))
|
||||
self._buffer = self._buffer[start_match.end() :]
|
||||
parts.append(("thought_start", ""))
|
||||
self._in_think = True
|
||||
continue
|
||||
|
||||
hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX)
|
||||
emit = self._buffer[: len(self._buffer) - hold_len]
|
||||
if emit:
|
||||
parts.append(("text", emit))
|
||||
self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
|
||||
break
|
||||
|
||||
cleaned_parts: list[tuple[str, str]] = []
|
||||
for kind, content in parts:
|
||||
# Extra safeguard: strip any stray tags that slipped through.
|
||||
content = self._START_PATTERN.sub("", content)
|
||||
content = self._END_PATTERN.sub("", content)
|
||||
if content or kind in {"thought_start", "thought_end"}:
|
||||
cleaned_parts.append((kind, content))
|
||||
|
||||
return cleaned_parts
|
||||
|
||||
def flush(self) -> list[tuple[str, str]]:
|
||||
"""Flush remaining buffer when the stream ends."""
|
||||
if not self._buffer:
|
||||
return []
|
||||
kind = "thought" if self._in_think else "text"
|
||||
content = self._buffer
|
||||
# Drop dangling partial tags instead of emitting them
|
||||
if content.lower().startswith(self._START_PREFIX) or content.lower().startswith(self._END_PREFIX):
|
||||
content = ""
|
||||
self._buffer = ""
|
||||
if not content and not self._in_think:
|
||||
return []
|
||||
# Strip any complete tags that might still be present.
|
||||
content = self._START_PATTERN.sub("", content)
|
||||
content = self._END_PATTERN.sub("", content)
|
||||
|
||||
result: list[tuple[str, str]] = []
|
||||
if content:
|
||||
result.append((kind, content))
|
||||
if self._in_think:
|
||||
result.append(("thought_end", ""))
|
||||
self._in_think = False
|
||||
return result
|
||||
|
||||
|
||||
class StreamBuffers(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
think_parser: ThinkTagStreamParser = Field(default_factory=ThinkTagStreamParser)
|
||||
pending_thought: list[str] = Field(default_factory=list)
|
||||
pending_content: list[str] = Field(default_factory=list)
|
||||
pending_tool_calls: list[ToolCall] = Field(default_factory=list)
|
||||
current_turn_reasoning: list[str] = Field(default_factory=list)
|
||||
reasoning_per_turn: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TraceState(BaseModel):
|
||||
trace_segments: list[LLMTraceSegment] = Field(default_factory=list)
|
||||
tool_trace_map: dict[str, LLMTraceSegment] = Field(default_factory=dict)
|
||||
tool_call_index_map: dict[str, int] = Field(default_factory=dict)
|
||||
model_segment_start_time: float | None = Field(default=None, description="Start time for current model segment")
|
||||
pending_usage: LLMUsage | None = Field(default=None, description="Pending usage for current model segment")
|
||||
|
||||
|
||||
class AggregatedResult(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
text: str = ""
|
||||
files: list[File] = Field(default_factory=list)
|
||||
usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
|
||||
finish_reason: str | None = None
|
||||
|
||||
|
||||
class AgentContext(BaseModel):
|
||||
agent_logs: list[AgentLogEvent] = Field(default_factory=list)
|
||||
agent_result: AgentResult | None = None
|
||||
|
||||
|
||||
class ToolOutputState(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
stream: StreamBuffers = Field(default_factory=StreamBuffers)
|
||||
trace: TraceState = Field(default_factory=TraceState)
|
||||
aggregate: AggregatedResult = Field(default_factory=AggregatedResult)
|
||||
agent: AgentContext = Field(default_factory=AgentContext)
|
||||
|
||||
|
||||
class ToolLogPayload(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
tool_name: str = ""
|
||||
tool_call_id: str = ""
|
||||
tool_args: dict[str, Any] = Field(default_factory=dict)
|
||||
tool_output: Any = None
|
||||
tool_error: Any = None
|
||||
files: list[Any] = Field(default_factory=list)
|
||||
meta: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_log(cls, log: AgentLog) -> "ToolLogPayload":
|
||||
data = log.data or {}
|
||||
return cls(
|
||||
tool_name=data.get("tool_name", ""),
|
||||
tool_call_id=data.get("tool_call_id", ""),
|
||||
tool_args=data.get("tool_args") or {},
|
||||
tool_output=data.get("output"),
|
||||
tool_error=data.get("error"),
|
||||
files=data.get("files") or [],
|
||||
meta=data.get("meta") or {},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_mapping(cls, data: Mapping[str, Any]) -> "ToolLogPayload":
|
||||
return cls(
|
||||
tool_name=data.get("tool_name", ""),
|
||||
tool_call_id=data.get("tool_call_id", ""),
|
||||
tool_args=data.get("tool_args") or {},
|
||||
tool_output=data.get("output"),
|
||||
tool_error=data.get("error"),
|
||||
files=data.get("files") or [],
|
||||
meta=data.get("meta") or {},
|
||||
)
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
model: ModelConfig
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
@ -355,10 +86,6 @@ class LLMNodeData(BaseNodeData):
|
||||
),
|
||||
)
|
||||
|
||||
# Tool support
|
||||
tools: Sequence[ToolMetadata] = Field(default_factory=list)
|
||||
max_iterations: int | None = Field(default=None, description="Maximum number of iterations for the LLM node")
|
||||
|
||||
@field_validator("prompt_config", mode="before")
|
||||
@classmethod
|
||||
def convert_none_prompt_config(cls, v: Any):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,16 +1,21 @@
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.helper.code_executor.code_executor import CodeExecutor
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.http_request.node import HttpRequestNode
|
||||
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
Jinja2TemplateRenderer,
|
||||
@ -43,6 +48,9 @@ class DifyNodeFactory(NodeFactory):
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits | None = None,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
http_request_http_client: HttpClientProtocol = ssrf_proxy,
|
||||
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
|
||||
http_request_file_manager: FileManagerProtocol = file_manager,
|
||||
) -> None:
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
@ -61,6 +69,9 @@ class DifyNodeFactory(NodeFactory):
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
self._http_request_http_client = http_request_http_client
|
||||
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
|
||||
self._http_request_file_manager = http_request_file_manager
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
@ -113,6 +124,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
code_providers=self._code_providers,
|
||||
code_limits=self._code_limits,
|
||||
)
|
||||
|
||||
if node_type == NodeType.TEMPLATE_TRANSFORM:
|
||||
return TemplateTransformNode(
|
||||
id=node_id,
|
||||
@ -122,6 +134,17 @@ class DifyNodeFactory(NodeFactory):
|
||||
template_renderer=self._template_renderer,
|
||||
)
|
||||
|
||||
if node_type == NodeType.HTTP_REQUEST:
|
||||
return HttpRequestNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
http_client=self._http_request_http_client,
|
||||
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
|
||||
file_manager=self._http_request_file_manager,
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
|
||||
29
api/core/workflow/nodes/protocols.py
Normal file
29
api/core/workflow/nodes/protocols.py
Normal file
@ -0,0 +1,29 @@
|
||||
from typing import Protocol
|
||||
|
||||
import httpx
|
||||
|
||||
from core.file import File
|
||||
|
||||
|
||||
class HttpClientProtocol(Protocol):
|
||||
@property
|
||||
def max_retries_exceeded_error(self) -> type[Exception]: ...
|
||||
|
||||
@property
|
||||
def request_error(self) -> type[Exception]: ...
|
||||
|
||||
def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
|
||||
def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
|
||||
def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
|
||||
def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
|
||||
def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
|
||||
def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
|
||||
|
||||
class FileManagerProtocol(Protocol):
|
||||
def download(self, f: File, /) -> bytes: ...
|
||||
@ -1,7 +1,7 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables import SegmentType, VariableBase
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
@ -73,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
assigned_variable_selector = self.node_data.assigned_variable_selector
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
if not isinstance(original_variable, VariableBase):
|
||||
raise VariableOperatorNodeError("assigned variable not found")
|
||||
|
||||
match self.node_data.write_mode:
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables import SegmentType, VariableBase
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
@ -118,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
# ==================== Validation Part
|
||||
|
||||
# Check if variable exists
|
||||
if not isinstance(variable, Variable):
|
||||
if not isinstance(variable, VariableBase):
|
||||
raise VariableNotFoundError(variable_selector=item.variable_selector)
|
||||
|
||||
# Check if operation is supported
|
||||
@ -192,7 +192,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
|
||||
for selector in updated_variable_selectors:
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if not isinstance(variable, Variable):
|
||||
if not isinstance(variable, VariableBase):
|
||||
raise VariableNotFoundError(variable_selector=selector)
|
||||
process_data[variable.name] = variable.value
|
||||
|
||||
@ -213,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
def _handle_item(
|
||||
self,
|
||||
*,
|
||||
variable: Variable,
|
||||
variable: VariableBase,
|
||||
operation: Operation,
|
||||
value: Any,
|
||||
):
|
||||
|
||||
@ -9,10 +9,10 @@ from typing import Annotated, Any, Union, cast
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.file import File, FileAttribute, file_manager
|
||||
from core.variables import Segment, SegmentGroup, Variable
|
||||
from core.variables import Segment, SegmentGroup, VariableBase
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.variables.segments import FileSegment, ObjectSegment
|
||||
from core.variables.variables import RAGPipelineVariableInput, VariableUnion
|
||||
from core.variables.variables import RAGPipelineVariableInput, Variable
|
||||
from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
ENVIRONMENT_VARIABLE_NODE_ID,
|
||||
@ -32,7 +32,7 @@ class VariablePool(BaseModel):
|
||||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||
# elements of the selector except the first one.
|
||||
variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
|
||||
variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
|
||||
description="Variables mapping",
|
||||
default=defaultdict(dict),
|
||||
)
|
||||
@ -46,13 +46,13 @@ class VariablePool(BaseModel):
|
||||
description="System variables",
|
||||
default_factory=SystemVariable.empty,
|
||||
)
|
||||
environment_variables: Sequence[VariableUnion] = Field(
|
||||
environment_variables: Sequence[Variable] = Field(
|
||||
description="Environment variables.",
|
||||
default_factory=list[VariableUnion],
|
||||
default_factory=list[Variable],
|
||||
)
|
||||
conversation_variables: Sequence[VariableUnion] = Field(
|
||||
conversation_variables: Sequence[Variable] = Field(
|
||||
description="Conversation variables.",
|
||||
default_factory=list[VariableUnion],
|
||||
default_factory=list[Variable],
|
||||
)
|
||||
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
|
||||
description="RAG pipeline variables.",
|
||||
@ -105,7 +105,7 @@ class VariablePool(BaseModel):
|
||||
f"got {len(selector)} elements"
|
||||
)
|
||||
|
||||
if isinstance(value, Variable):
|
||||
if isinstance(value, VariableBase):
|
||||
variable = value
|
||||
elif isinstance(value, Segment):
|
||||
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
|
||||
@ -114,9 +114,9 @@ class VariablePool(BaseModel):
|
||||
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
|
||||
|
||||
node_id, name = self._selector_to_keys(selector)
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
|
||||
# Based on the definition of `Variable`,
|
||||
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
|
||||
self.variable_dictionary[node_id][name] = cast(Variable, variable)
|
||||
|
||||
@classmethod
|
||||
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
|
||||
|
||||
56
api/core/workflow/utils/generator_timeout.py
Normal file
56
api/core/workflow/utils/generator_timeout.py
Normal file
@ -0,0 +1,56 @@
|
||||
"""
|
||||
Generator timeout utilities for workflow nodes.
|
||||
|
||||
Provides timeout wrappers for streaming generators, primarily used for
|
||||
LLM response streaming where we need to enforce time-to-first-token limits.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class FirstTokenTimeoutError(Exception):
|
||||
"""Raised when a generator fails to yield its first item within the configured timeout."""
|
||||
|
||||
def __init__(self, timeout_ms: int):
|
||||
self.timeout_ms = timeout_ms
|
||||
super().__init__(f"Generator timed out after {timeout_ms}ms without yielding first item")
|
||||
|
||||
|
||||
def with_first_token_timeout(
|
||||
generator: Generator[T, None, None],
|
||||
timeout_seconds: float,
|
||||
) -> Generator[T, None, None]:
|
||||
"""
|
||||
Wrap a generator with first token timeout monitoring.
|
||||
|
||||
Only monitors the time until the FIRST item is yielded.
|
||||
Once the first item arrives, timeout monitoring stops and
|
||||
subsequent items are yielded without timeout checks.
|
||||
|
||||
Args:
|
||||
generator: The source generator to wrap
|
||||
timeout_seconds: Maximum time to wait for first item (in seconds)
|
||||
|
||||
Yields:
|
||||
Items from the source generator
|
||||
|
||||
Raises:
|
||||
FirstTokenTimeoutError: If first item doesn't arrive within timeout
|
||||
"""
|
||||
start_time = time.monotonic()
|
||||
|
||||
# Handle first item separately to check timeout only once
|
||||
try:
|
||||
first_item = next(generator)
|
||||
if time.monotonic() - start_time > timeout_seconds:
|
||||
raise FirstTokenTimeoutError(int(timeout_seconds * 1000))
|
||||
yield first_item
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
# Yield remaining items without timeout checks
|
||||
yield from generator
|
||||
@ -2,7 +2,7 @@ import abc
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.variables import Variable
|
||||
from core.variables import VariableBase
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
@ -26,7 +26,7 @@ class VariableLoader(Protocol):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
|
||||
"""Load variables based on the provided selectors. If the selectors are empty,
|
||||
this method should return an empty list.
|
||||
|
||||
@ -36,7 +36,7 @@ class VariableLoader(Protocol):
|
||||
:param: selectors: a list of string list, each inner list should have at least two elements:
|
||||
- the first element is the node ID,
|
||||
- the second element is the variable name.
|
||||
:return: a list of Variable objects that match the provided selectors.
|
||||
:return: a list of VariableBase objects that match the provided selectors.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader):
|
||||
Serves as a placeholder when no variable loading is needed.
|
||||
"""
|
||||
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
|
||||
return []
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user