mirror of
https://github.com/langgenius/dify.git
synced 2026-01-20 03:59:30 +08:00
Compare commits
424 Commits
8c92-updat
...
fix/reset-
| Author | SHA1 | Date | |
|---|---|---|---|
| 967c9081cc | |||
| a22cc5bc5e | |||
| 1fbdf6b465 | |||
| 491e1fd6a4 | |||
| 0e33dfb5c2 | |||
| ea708e7a32 | |||
| c09e29c3f8 | |||
| 2d53ba8671 | |||
| 9be863fefa | |||
| 8f43629cd8 | |||
| 9ee71902c1 | |||
| 3f2a461b22 | |||
| 51e2e4a728 | |||
| a012c87445 | |||
| 450578d4c0 | |||
| 837237aa6d | |||
| 495ad848be | |||
| 16fa798f21 | |||
| b32c93df6f | |||
| 76da8b4ff3 | |||
| 25bfc1cc3b | |||
| b63dfbf654 | |||
| 51ea87ab85 | |||
| 00698e41b7 | |||
| df938a4543 | |||
| b34d09649b | |||
| a92df530da | |||
| 9161936f41 | |||
| f9a21b56ab | |||
| 220e1df847 | |||
| 8cfdde594c | |||
| 31a8fd810c | |||
| 939a1b91a0 | |||
| e8a1c99626 | |||
| 898393a0f8 | |||
| 431936beb9 | |||
| 163540bf4a | |||
| 221130b448 | |||
| 9fad97ec9b | |||
| 0c2729d9b3 | |||
| 049925c5d8 | |||
| a2e03b811e | |||
| 1e10bf525c | |||
| 8b1af36d94 | |||
| c56d9e3f69 | |||
| b1eb265fa5 | |||
| 0711dd4159 | |||
| ae0a26f5b6 | |||
| c2a0950660 | |||
| bfe98009fd | |||
| ea1704d211 | |||
| a49ab7258d | |||
| 425a0f9095 | |||
| 3d050f449c | |||
| 905a5b348d | |||
| 1a1e825685 | |||
| 758b289b6f | |||
| 11be198fc8 | |||
| 3e082e6976 | |||
| cf990cdace | |||
| ce309bd008 | |||
| 3ed0937734 | |||
| 04912fa775 | |||
| 0ff1b61232 | |||
| 4d3d8b35d9 | |||
| c323028179 | |||
| 977406f703 | |||
| 80535ed20e | |||
| f4f391ada0 | |||
| 6d9d2a1079 | |||
| 4db40646cb | |||
| 71bd20cf73 | |||
| dbc8ffccbd | |||
| c3d0c80ca1 | |||
| 75ccba6e52 | |||
| 8a38489e53 | |||
| c301052789 | |||
| 38f7c77fa4 | |||
| 44762a38c2 | |||
| 7938a2bc5b | |||
| cfaae1b8a1 | |||
| cbc4348727 | |||
| c888e3c46e | |||
| 4fa7bf0128 | |||
| eec57e84e4 | |||
| dac9c1953a | |||
| 70149ea05e | |||
| 1d93f41fcf | |||
| 1ecf581ba1 | |||
| 1584a78fc9 | |||
| 5a5ee0db26 | |||
| cef7fd484b | |||
| 3ef37b24bb | |||
| 88c3286bd4 | |||
| 04f40303fd | |||
| ececc5ec2c | |||
| a29da52562 | |||
| 3533ed0fdd | |||
| e34513cbc4 | |||
| c976c1cb2d | |||
| 9358786c7f | |||
| 1629e22d98 | |||
| 8fb50d52dc | |||
| dc8a618b6a | |||
| d9a0d6caa8 | |||
| f3e7fea628 | |||
| 7c78d42627 | |||
| 13346f2874 | |||
| d330aaf57a | |||
| 9758e36f7c | |||
| 5ad435ef32 | |||
| 6ed630db15 | |||
| 256c5231f2 | |||
| 344e0a3318 | |||
| 59772c2493 | |||
| f828c0e754 | |||
| 960b0707c8 | |||
| 03938be789 | |||
| a6b94f11e5 | |||
| e83635ee5a | |||
| d79372a46d | |||
| bbd11c9e89 | |||
| 152fd52cd7 | |||
| ccabdbc83b | |||
| 56c8221b3f | |||
| d132abcdb4 | |||
| d60348572e | |||
| 3db700db34 | |||
| 0982cf6018 | |||
| 9aee14f4f8 | |||
| 10749726eb | |||
| f55faae31b | |||
| cd63bd48bd | |||
| 9a5a06bb40 | |||
| 0cff94d90e | |||
| 8408975036 | |||
| e01fae9151 | |||
| 46a8f06b6c | |||
| 0fb97ee9e9 | |||
| b033002d9f | |||
| 130163ca65 | |||
| 9d3eaefcdd | |||
| 5c51d61049 | |||
| ba0a59f998 | |||
| 7fc25cafb2 | |||
| a0fde9f012 | |||
| a7859de625 | |||
| d1b4bb247a | |||
| 91e23027f1 | |||
| 3b3b16eb00 | |||
| 8cd69899b4 | |||
| 710c17ed59 | |||
| 8d884aad29 | |||
| 7d7cce04eb | |||
| 88969a609b | |||
| 00af19c7f2 | |||
| 79d7dcaad2 | |||
| 1a2f37e7c7 | |||
| bffd67f3a4 | |||
| 0bbf8f72b1 | |||
| 047ea8c143 | |||
| f54b9b12b0 | |||
| cb99b8f04d | |||
| 7c03bcba2b | |||
| 92fa7271ed | |||
| 1fcf6e4943 | |||
| d3486cab31 | |||
| f4a7efde3d | |||
| 38d4f0fd96 | |||
| ec4f885dad | |||
| 3781c2a025 | |||
| 3782f17dc7 | |||
| 966b586200 | |||
| 29698aeed2 | |||
| 15ff8efb15 | |||
| 407e1c8276 | |||
| 03cc3868ef | |||
| fdd21ac815 | |||
| 44524760ee | |||
| e368825c21 | |||
| dd0a870969 | |||
| 0c4c268003 | |||
| 5d40064f12 | |||
| 8dad6b6a6d | |||
| fd9f5c0fc8 | |||
| 2f54965a72 | |||
| e5359ba136 | |||
| a1a3fa0283 | |||
| ff7344f3d3 | |||
| bcd33be22a | |||
| c1184789e2 | |||
| ff57848268 | |||
| d223fee9b9 | |||
| ad18d084f3 | |||
| 9941d1f160 | |||
| dab30117da | |||
| 3f7d46358c | |||
| 890b5d222f | |||
| fefcb1e959 | |||
| da6fa55eed | |||
| ad438740c4 | |||
| 231ecc1bfe | |||
| 790ed0845e | |||
| 309875650d | |||
| fc260fab97 | |||
| 92fa87e729 | |||
| 13fa56b5b1 | |||
| f6accd8ae2 | |||
| e66ef9145b | |||
| 64f5e34096 | |||
| 3f3b9beeff | |||
| 9ce48b4dc4 | |||
| da9a28b9e2 | |||
| cd09d27f11 | |||
| 912ca2bcfe | |||
| dd949a23e1 | |||
| dcd95632cb | |||
| 75113163a2 | |||
| 66713091e2 | |||
| 00cc50c659 | |||
| b459af07ae | |||
| c93f7f6f2c | |||
| a740fe2292 | |||
| ba2486d961 | |||
| e471884dcc | |||
| ff762e9d84 | |||
| e09174d7d7 | |||
| 016663bf44 | |||
| 824618c2ef | |||
| 918c187c50 | |||
| 0133843164 | |||
| c49127540f | |||
| fef5d88f59 | |||
| ba887993aa | |||
| 42b6e32574 | |||
| 1f91a971a8 | |||
| e487e1f622 | |||
| b3010e35cb | |||
| 5891731ab2 | |||
| 383971d56f | |||
| 840d45d407 | |||
| ff40efdc26 | |||
| c6ae13f67b | |||
| 55dd5f71d7 | |||
| c05d9bd813 | |||
| 0404007982 | |||
| 71c20ef3c8 | |||
| 7171eaf0b5 | |||
| b16f87c9b6 | |||
| 5bbc626b5e | |||
| abb2b860f2 | |||
| f7a9aadc98 | |||
| bacc9a7970 | |||
| be94274fbd | |||
| 4bc230eb13 | |||
| 4d596ad231 | |||
| de9fae63b8 | |||
| 67d8e5a7f0 | |||
| 997ff45e56 | |||
| 88508b8631 | |||
| eb525b697f | |||
| 8bf9eee91e | |||
| bac0513b8b | |||
| e95b7b57c9 | |||
| d4a90c43a4 | |||
| 0d2dda7e77 | |||
| a44a03696e | |||
| 33b1aaf182 | |||
| 930c36e757 | |||
| 91dad285fa | |||
| 2d2ce5df85 | |||
| 545a34fbaf | |||
| 61a6c6dbcf | |||
| 2b23c43434 | |||
| 0fb339ca4f | |||
| c1871e67aa | |||
| f711f9a317 | |||
| 9ff3310cb6 | |||
| b6bdcc7052 | |||
| 67b0771081 | |||
| 9a07488da9 | |||
| ef043c6906 | |||
| ab814e3eac | |||
| a0e1eeb3f1 | |||
| b1ebeb67a7 | |||
| 082179f70f | |||
| 8786ebdbca | |||
| b49a4eab62 | |||
| 0a7b59f500 | |||
| c264d9152f | |||
| 3bf9d898c0 | |||
| a7f2849e74 | |||
| 0957ece92f | |||
| 949bf38d3c | |||
| 7bafb7f959 | |||
| 9735f55ca4 | |||
| 4c1f9b949b | |||
| 0af0c94dde | |||
| 8e4f0640cc | |||
| 1f513e3b43 | |||
| aa0841e2a8 | |||
| b6a1562357 | |||
| bee0797401 | |||
| e085f39c13 | |||
| 344844d3e0 | |||
| 6e9f82491d | |||
| 372b1c3db8 | |||
| 58d305dbed | |||
| 0360a0416b | |||
| 72282b6e8f | |||
| 8391884c4e | |||
| b018f2b0a0 | |||
| 754f1a3cfa | |||
| b22c28b099 | |||
| ab56b4a818 | |||
| cd9e28dbf4 | |||
| 04f9637b6f | |||
| b8a29bfb35 | |||
| 5e2b0d7b39 | |||
| b483d5fad5 | |||
| 04196288f8 | |||
| cc349e70b1 | |||
| 50bdbfae69 | |||
| 2f45673694 | |||
| b5fb55069b | |||
| 7ba9d30775 | |||
| e69b588bad | |||
| aadac22ce4 | |||
| d12015c722 | |||
| 2641326432 | |||
| 20109553b9 | |||
| 0e1444d17c | |||
| 65d376bdae | |||
| e3c1310afa | |||
| 38da19a729 | |||
| 91110499dd | |||
| 4dca9a12a8 | |||
| 3e448f0102 | |||
| ca75a1c9a3 | |||
| 61ebc756aa | |||
| 4bea38042a | |||
| 337abc536b | |||
| cc02b78aca | |||
| 18f2d24f8e | |||
| 0c7b9a462f | |||
| 4dd5580854 | |||
| 440bd825d8 | |||
| d2379c38bd | |||
| cbc55c577b | |||
| 8e962d15d1 | |||
| b07c766551 | |||
| 9e3dd69277 | |||
| db9e5665c2 | |||
| cad77ce0bf | |||
| 6f4518ebf7 | |||
| a8f5748dee | |||
| 738d3001be | |||
| df4e32aaa0 | |||
| a25e37a96d | |||
| f156b46705 | |||
| 3b64e118d0 | |||
| 566cd20849 | |||
| db5c51ffc5 | |||
| df76527f29 | |||
| dac13ae604 | |||
| 53a80a5dbe | |||
| 1507792a0c | |||
| 00b9bbff75 | |||
| e1f8b4b387 | |||
| 1539d86f7d | |||
| 67bb14d3ee | |||
| 5653309080 | |||
| 0f52b34b61 | |||
| 75e35857c1 | |||
| 4f81be70e3 | |||
| 1d4d627d05 | |||
| 2357234f39 | |||
| a3f7d8f996 | |||
| 56f12e70c1 | |||
| b14afda160 | |||
| 44b4948972 | |||
| 487eac3b91 | |||
| 84b2913cd9 | |||
| 176d810c8d | |||
| 8b9a9d0574 | |||
| 9e66564526 | |||
| 781a9a56cd | |||
| a49321775c | |||
| 0c83f62848 | |||
| 1ddd4bc549 | |||
| 1d16528dff | |||
| c0ed353c10 | |||
| 93be1219eb | |||
| 3276d6429d | |||
| 50072a63ae | |||
| 1ab7e1cba8 | |||
| b0aef35c63 | |||
| ac351b700c | |||
| d1e5d30ea9 | |||
| c73e84d992 | |||
| 09998612e7 | |||
| f71ad55d58 | |||
| 5b81397054 | |||
| e056e0835a | |||
| e1819fb7e5 | |||
| 0360f0b33b | |||
| 560fe8a0f6 | |||
| da27d261b0 | |||
| c3e3a18ab4 | |||
| ab34cea714 | |||
| db0780cfa8 | |||
| 2b51fc23d9 | |||
| 5f0bd5119a | |||
| 8353352bda | |||
| 73845cbec5 | |||
| c2f94e9e8a | |||
| e54efda36f | |||
| d4bd19f6d8 | |||
| 4decbbbf18 | |||
| b15867f92e | |||
| a5e5fbc6e0 | |||
| 1b1471b6d8 | |||
| 5280bffde2 | |||
| db0fc94b39 |
@ -5,5 +5,18 @@
|
||||
"typescript-lsp@claude-plugins-official": true,
|
||||
"pyright-lsp@claude-plugins-official": true,
|
||||
"ralph-loop@claude-plugins-official": true
|
||||
},
|
||||
"hooks": {
|
||||
"PreToolUse": [
|
||||
{
|
||||
"matcher": "Bash",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "npx -y block-no-verify@1.1.1"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
6
.github/workflows/api-tests.yml
vendored
6
.github/workflows/api-tests.yml
vendored
@ -39,12 +39,6 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Run pyrefly check
|
||||
run: |
|
||||
cd api
|
||||
uv add --dev pyrefly
|
||||
uv run pyrefly check || true
|
||||
|
||||
- name: Run dify config tests
|
||||
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||
|
||||
|
||||
29
.github/workflows/deploy-hitl.yml
vendored
Normal file
29
.github/workflows/deploy-hitl.yml
vendored
Normal file
@ -0,0 +1,29 @@
|
||||
name: Deploy HITL
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "feat/hitl-frontend"
|
||||
- "feat/hitl-backend"
|
||||
types:
|
||||
- completed
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
(
|
||||
github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
|
||||
github.event.workflow_run.head_branch == 'feat/hitl-backend'
|
||||
)
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.HITL_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
|
||||
@ -589,6 +589,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
|
||||
ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
|
||||
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
|
||||
ENABLE_CLEAN_MESSAGES=false
|
||||
ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
|
||||
ENABLE_DATASETS_QUEUE_MONITOR=false
|
||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
@ -34,7 +35,7 @@ from libs.rsa import generate_key_pair
|
||||
from models import Tenant
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
|
||||
from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from models.provider import Provider, ProviderModel
|
||||
from models.provider_ids import DatasourceProviderID, ToolProviderID
|
||||
@ -45,6 +46,7 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
|
||||
from services.plugin.data_migration import PluginDataMigration
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
from services.plugin.plugin_service import PluginService
|
||||
from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
|
||||
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -62,8 +64,10 @@ def reset_password(email, new_password, password_confirm):
|
||||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
click.echo(click.style("Passwords do not match.", fg="red"))
|
||||
return
|
||||
normalized_email = email.strip().lower()
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
@ -84,7 +88,7 @@ def reset_password(email, new_password, password_confirm):
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
|
||||
|
||||
@ -100,20 +104,22 @@ def reset_email(email, new_email, email_confirm):
|
||||
if str(new_email).strip() != str(email_confirm).strip():
|
||||
click.echo(click.style("New emails do not match.", fg="red"))
|
||||
return
|
||||
normalized_new_email = new_email.strip().lower()
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
email_validate(new_email)
|
||||
email_validate(normalized_new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
account.email = new_email
|
||||
account.email = normalized_new_email
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
|
||||
|
||||
@ -658,7 +664,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
|
||||
return
|
||||
|
||||
# Create account
|
||||
email = email.strip()
|
||||
email = email.strip().lower()
|
||||
|
||||
if "@" not in email:
|
||||
click.echo(click.style("Invalid email address.", fg="red"))
|
||||
@ -852,6 +858,61 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
|
||||
click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.")
|
||||
@click.option("--days", default=30, show_default=True, help="Delete workflow runs created before N days ago.")
|
||||
@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.")
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
|
||||
)
|
||||
@click.option(
|
||||
"--dry-run",
|
||||
is_flag=True,
|
||||
help="Preview cleanup results without deleting any workflow run data.",
|
||||
)
|
||||
def clean_workflow_runs(
|
||||
days: int,
|
||||
batch_size: int,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
dry_run: bool,
|
||||
):
|
||||
"""
|
||||
Clean workflow runs and related workflow data for free tenants.
|
||||
"""
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
|
||||
start_time = datetime.datetime.now(datetime.UTC)
|
||||
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
|
||||
|
||||
WorkflowRunCleanup(
|
||||
days=days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
dry_run=dry_run,
|
||||
).run()
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Workflow run cleanup completed. start={start_time.isoformat()} "
|
||||
f"end={end_time.isoformat()} duration={elapsed}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
|
||||
@click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
|
||||
def clear_orphaned_file_records(force: bool):
|
||||
|
||||
@ -959,6 +959,16 @@ class MailConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
ENABLE_TRIAL_APP: bool = Field(
|
||||
description="Enable trial app",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ENABLE_EXPLORE_BANNER: bool = Field(
|
||||
description="Enable explore banner",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class RagEtlConfig(BaseSettings):
|
||||
"""
|
||||
@ -1101,6 +1111,10 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable clean messages task",
|
||||
default=False,
|
||||
)
|
||||
ENABLE_WORKFLOW_RUN_CLEANUP_TASK: bool = Field(
|
||||
description="Enable scheduled workflow run cleanup task",
|
||||
default=False,
|
||||
)
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
|
||||
description="Enable mail clean document notify task",
|
||||
default=False,
|
||||
|
||||
@ -107,10 +107,12 @@ from .datasets.rag_pipeline import (
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
banner,
|
||||
installed_app,
|
||||
parameter,
|
||||
recommended_app,
|
||||
saved_message,
|
||||
trial,
|
||||
)
|
||||
|
||||
# Import tag controllers
|
||||
@ -145,6 +147,7 @@ __all__ = [
|
||||
"apikey",
|
||||
"app",
|
||||
"audio",
|
||||
"banner",
|
||||
"billing",
|
||||
"bp",
|
||||
"completion",
|
||||
@ -198,6 +201,7 @@ __all__ = [
|
||||
"statistic",
|
||||
"tags",
|
||||
"tool_providers",
|
||||
"trial",
|
||||
"trigger_providers",
|
||||
"version",
|
||||
"website",
|
||||
|
||||
@ -15,7 +15,7 @@ from controllers.console.wraps import only_edition_cloud
|
||||
from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@ -32,6 +32,8 @@ class InsertExploreAppPayload(BaseModel):
|
||||
language: str = Field(...)
|
||||
category: str = Field(...)
|
||||
position: int = Field(...)
|
||||
can_trial: bool = Field(default=False)
|
||||
trial_limit: int = Field(default=0)
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
@ -39,11 +41,33 @@ class InsertExploreAppPayload(BaseModel):
|
||||
return supported_language(value)
|
||||
|
||||
|
||||
class InsertExploreBannerPayload(BaseModel):
|
||||
category: str = Field(...)
|
||||
title: str = Field(...)
|
||||
description: str = Field(...)
|
||||
img_src: str = Field(..., alias="img-src")
|
||||
language: str = Field(default="en-US")
|
||||
link: str = Field(...)
|
||||
sort: int = Field(...)
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
def validate_language(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
InsertExploreAppPayload.__name__,
|
||||
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
InsertExploreBannerPayload.__name__,
|
||||
InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def admin_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
@ -109,6 +133,20 @@ class InsertExploreAppListApi(Resource):
|
||||
)
|
||||
|
||||
db.session.add(recommended_app)
|
||||
if payload.can_trial:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=payload.app_id,
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=payload.trial_limit,
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = payload.trial_limit
|
||||
|
||||
app.is_public = True
|
||||
db.session.commit()
|
||||
@ -123,6 +161,20 @@ class InsertExploreAppListApi(Resource):
|
||||
recommended_app.category = payload.category
|
||||
recommended_app.position = payload.position
|
||||
|
||||
if payload.can_trial:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=payload.app_id,
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=payload.trial_limit,
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = payload.trial_limit
|
||||
app.is_public = True
|
||||
|
||||
db.session.commit()
|
||||
@ -168,7 +220,62 @@ class InsertExploreAppApi(Resource):
|
||||
for installed_app in installed_apps:
|
||||
session.delete(installed_app)
|
||||
|
||||
trial_app = session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
|
||||
).scalar_one_or_none()
|
||||
if trial_app:
|
||||
session.delete(trial_app)
|
||||
|
||||
db.session.delete(recommended_app)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-banner")
|
||||
class InsertExploreBannerApi(Resource):
|
||||
@console_ns.doc("insert_explore_banner")
|
||||
@console_ns.doc(description="Insert an explore banner")
|
||||
@console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__])
|
||||
@console_ns.response(201, "Banner inserted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
|
||||
|
||||
content = {
|
||||
"category": payload.category,
|
||||
"title": payload.title,
|
||||
"description": payload.description,
|
||||
"img-src": payload.img_src,
|
||||
}
|
||||
|
||||
banner = ExporleBanner(
|
||||
content=content,
|
||||
link=payload.link,
|
||||
sort=payload.sort,
|
||||
language=payload.language,
|
||||
)
|
||||
db.session.add(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-banner/<uuid:banner_id>")
|
||||
class DeleteExploreBannerApi(Resource):
|
||||
@console_ns.doc("delete_explore_banner")
|
||||
@console_ns.doc(description="Delete an explore banner")
|
||||
@console_ns.doc(params={"banner_id": "Banner ID to delete"})
|
||||
@console_ns.response(204, "Banner deleted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, banner_id):
|
||||
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
|
||||
if not banner:
|
||||
raise NotFound(f"Banner '{banner_id}' is not found")
|
||||
|
||||
db.session.delete(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@ -272,6 +272,7 @@ class AnnotationExportApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||
response_data = {"data": marshal(annotation_list, annotation_fields)}
|
||||
@ -359,7 +360,6 @@ class AnnotationBatchImportApi(Resource):
|
||||
file.seek(0, 2) # Seek to end of file
|
||||
file_size = file.tell()
|
||||
file.seek(0) # Reset to beginning
|
||||
|
||||
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
|
||||
if file_size > max_size_bytes:
|
||||
abort(
|
||||
|
||||
@ -115,3 +115,9 @@ class InvokeRateLimitError(BaseHTTPException):
|
||||
error_code = "rate_limit_error"
|
||||
description = "Rate Limit Error"
|
||||
code = 429
|
||||
|
||||
|
||||
class NeedAddIdsError(BaseHTTPException):
|
||||
error_code = "need_add_ids"
|
||||
description = "Need to add ids."
|
||||
code = 400
|
||||
|
||||
@ -202,6 +202,7 @@ message_detail_model = console_ns.model(
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"generation_detail": fields.Raw,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -23,6 +23,11 @@ def _load_app_model(app_id: str) -> App | None:
|
||||
return app_model
|
||||
|
||||
|
||||
def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
|
||||
return app_model
|
||||
|
||||
|
||||
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P1, R1]):
|
||||
@wraps(view_func)
|
||||
@ -62,3 +67,44 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
|
||||
app_id = kwargs.get("app_id")
|
||||
app_id = str(app_id)
|
||||
|
||||
del kwargs["app_id"]
|
||||
|
||||
app_model = _load_app_model_with_trial(app_id)
|
||||
|
||||
if not app_model:
|
||||
raise AppNotFoundError()
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
||||
if mode is not None:
|
||||
if isinstance(mode, list):
|
||||
modes = mode
|
||||
else:
|
||||
modes = [mode]
|
||||
|
||||
if app_mode not in modes:
|
||||
mode_values = {m.value for m in modes}
|
||||
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
|
||||
|
||||
kwargs["app_model"] = app_model
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
|
||||
@ -63,10 +63,9 @@ class ActivateCheckApi(Resource):
|
||||
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workspaceId = args.workspace_id
|
||||
reg_email = args.email
|
||||
token = args.token
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
||||
invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
|
||||
if invitation:
|
||||
data = invitation.get("data", {})
|
||||
tenant = invitation.get("tenant", None)
|
||||
@ -100,11 +99,12 @@ class ActivateApi(Resource):
|
||||
def post(self):
|
||||
args = ActivatePayload.model_validate(console_ns.payload)
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
|
||||
normalized_request_email = args.email.lower() if args.email else None
|
||||
invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
|
||||
if invitation is None:
|
||||
raise AlreadyActivateError()
|
||||
|
||||
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
|
||||
RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)
|
||||
|
||||
account = invitation["account"]
|
||||
account.name = args.name
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
@ -62,6 +61,7 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
@ -70,13 +70,12 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
if args.language in languages:
|
||||
language = args.language
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
token = None
|
||||
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@ -88,9 +87,9 @@ class EmailRegisterCheckApi(Resource):
|
||||
def post(self):
|
||||
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args.email
|
||||
user_email = args.email.lower()
|
||||
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email)
|
||||
if is_email_register_error_rate_limit:
|
||||
raise EmailRegisterLimitError()
|
||||
|
||||
@ -98,11 +97,14 @@ class EmailRegisterCheckApi(Resource):
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_email_register_error_rate_limit(args.email)
|
||||
AccountService.add_email_register_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
@ -113,8 +115,8 @@ class EmailRegisterCheckApi(Resource):
|
||||
user_email, code=args.code, additional_data={"phase": "register"}
|
||||
)
|
||||
|
||||
AccountService.reset_email_register_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_email_register_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/email-register")
|
||||
@ -141,22 +143,23 @@ class EmailRegisterResetApi(Resource):
|
||||
AccountService.revoke_email_register_token(args.token)
|
||||
|
||||
email = register_data.get("email", "")
|
||||
normalized_email = email.lower()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(email, args.password_confirm)
|
||||
account = self._create_new_account(normalized_email, args.password_confirm)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
def _create_new_account(self, email, password) -> Account | None:
|
||||
def _create_new_account(self, email: str, password: str) -> Account | None:
|
||||
# Create new account if allowed
|
||||
account = None
|
||||
try:
|
||||
|
||||
@ -4,7 +4,6 @@ import secrets
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models import Account
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@ -76,6 +74,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
@ -87,11 +86,11 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
account=account,
|
||||
email=args.email,
|
||||
email=normalized_email,
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
)
|
||||
@ -122,9 +121,9 @@ class ForgotPasswordCheckApi(Resource):
|
||||
def post(self):
|
||||
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args.email
|
||||
user_email = args.email.lower()
|
||||
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
|
||||
if is_forgot_password_error_rate_limit:
|
||||
raise EmailPasswordResetLimitError()
|
||||
|
||||
@ -132,11 +131,16 @@ class ForgotPasswordCheckApi(Resource):
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
if not isinstance(token_email, str):
|
||||
raise InvalidEmailError()
|
||||
normalized_token_email = token_email.lower()
|
||||
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(args.email)
|
||||
AccountService.add_forgot_password_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
@ -144,11 +148,11 @@ class ForgotPasswordCheckApi(Resource):
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_reset_password_token(
|
||||
user_email, code=args.code, additional_data={"phase": "reset"}
|
||||
token_email, code=args.code, additional_data={"phase": "reset"}
|
||||
)
|
||||
|
||||
AccountService.reset_forgot_password_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_forgot_password_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/forgot-password/resets")
|
||||
@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource):
|
||||
password_hashed = hash_password(args.new_password, salt)
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt, session)
|
||||
|
||||
@ -90,32 +90,38 @@ class LoginApi(Resource):
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
args = LoginPayload.model_validate(console_ns.payload)
|
||||
request_email = args.email
|
||||
normalized_email = request_email.lower()
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
|
||||
if is_login_error_rate_limit:
|
||||
raise EmailPasswordLoginLimitError()
|
||||
|
||||
invite_token = args.invite_token
|
||||
invitation_data: dict[str, Any] | None = None
|
||||
if args.invite_token:
|
||||
invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token)
|
||||
if invite_token:
|
||||
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
|
||||
if invitation_data is None:
|
||||
invite_token = None
|
||||
|
||||
try:
|
||||
if invitation_data:
|
||||
data = invitation_data.get("data", {})
|
||||
invitee_email = data.get("email") if data else None
|
||||
if invitee_email != args.email:
|
||||
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
|
||||
if invitee_email_normalized != normalized_email:
|
||||
raise InvalidEmailError()
|
||||
account = AccountService.authenticate(args.email, args.password, args.invite_token)
|
||||
else:
|
||||
account = AccountService.authenticate(args.email, args.password)
|
||||
account = _authenticate_account_with_case_fallback(
|
||||
request_email, normalized_email, args.password, invite_token
|
||||
)
|
||||
except services.errors.account.AccountLoginError:
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
AccountService.add_login_error_rate_limit(args.email)
|
||||
raise AuthenticationFailedError()
|
||||
except services.errors.account.AccountPasswordError as exc:
|
||||
AccountService.add_login_error_rate_limit(normalized_email)
|
||||
raise AuthenticationFailedError() from exc
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
if len(tenants) == 0:
|
||||
@ -130,7 +136,7 @@ class LoginApi(Resource):
|
||||
}
|
||||
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args.email)
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
@ -170,18 +176,19 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
def post(self):
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args.email)
|
||||
account = _get_account_with_case_fallback(args.email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
email=args.email,
|
||||
email=normalized_email,
|
||||
account=account,
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
@ -196,6 +203,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
def post(self):
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
@ -206,13 +214,13 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args.email)
|
||||
account = _get_account_with_case_fallback(args.email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_email_code_login_email(email=args.email, language=language)
|
||||
token = AccountService.send_email_code_login_email(email=normalized_email, language=language)
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
else:
|
||||
@ -229,14 +237,17 @@ class EmailCodeLoginApi(Resource):
|
||||
def post(self):
|
||||
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args.email
|
||||
original_email = args.email
|
||||
user_email = original_email.lower()
|
||||
language = args.language
|
||||
|
||||
token_data = AccountService.get_email_code_login_data(args.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if token_data["email"] != args.email:
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if normalized_token_email != user_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args.code:
|
||||
@ -244,7 +255,7 @@ class EmailCodeLoginApi(Resource):
|
||||
|
||||
AccountService.revoke_email_code_login_token(args.token)
|
||||
try:
|
||||
account = AccountService.get_user_through_email(user_email)
|
||||
account = _get_account_with_case_fallback(original_email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
if account:
|
||||
@ -275,7 +286,7 @@ class EmailCodeLoginApi(Resource):
|
||||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args.email)
|
||||
AccountService.reset_login_error_rate_limit(user_email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
@ -309,3 +320,22 @@ class RefreshTokenApi(Resource):
|
||||
return response
|
||||
except Exception as e:
|
||||
return {"result": "fail", "message": str(e)}, 401
|
||||
|
||||
|
||||
def _get_account_with_case_fallback(email: str):
|
||||
account = AccountService.get_user_through_email(email)
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
|
||||
return AccountService.get_user_through_email(email.lower())
|
||||
|
||||
|
||||
def _authenticate_account_with_case_fallback(
|
||||
original_email: str, normalized_email: str, password: str, invite_token: str | None
|
||||
):
|
||||
try:
|
||||
return AccountService.authenticate(original_email, password, invite_token)
|
||||
except services.errors.account.AccountPasswordError:
|
||||
if original_email == normalized_email:
|
||||
raise
|
||||
return AccountService.authenticate(normalized_email, password, invite_token)
|
||||
|
||||
@ -3,7 +3,6 @@ import logging
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
@ -118,7 +117,10 @@ class OAuthCallback(Resource):
|
||||
invitation = RegisterService.get_invitation_by_token(token=invite_token)
|
||||
if invitation:
|
||||
invitation_email = invitation.get("email", None)
|
||||
if invitation_email != user_info.email:
|
||||
invitation_email_normalized = (
|
||||
invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
|
||||
)
|
||||
if invitation_email_normalized != user_info.email.lower():
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
|
||||
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
|
||||
@ -159,10 +161,7 @@ class OAuthCallback(Resource):
|
||||
ip_address=extract_remote_ip(request),
|
||||
)
|
||||
|
||||
base_url = dify_config.CONSOLE_WEB_URL
|
||||
query_char = "&" if "?" in base_url else "?"
|
||||
target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
|
||||
response = redirect(target_url)
|
||||
response = redirect(f"{dify_config.CONSOLE_WEB_URL}?oauth_new_user={str(oauth_new_user).lower()}")
|
||||
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
@ -175,7 +174,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
||||
|
||||
if not account:
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
|
||||
|
||||
return account
|
||||
|
||||
@ -197,9 +196,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
|
||||
tenant_was_created.send(new_tenant)
|
||||
|
||||
if not account:
|
||||
normalized_email = user_info.email.lower()
|
||||
oauth_new_user = True
|
||||
if not FeatureService.get_system_features().is_allow_register:
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountRegisterError(
|
||||
description=(
|
||||
"This email account has been deleted within the past "
|
||||
@ -210,7 +210,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
|
||||
raise AccountRegisterError(description=("Invalid email or password"))
|
||||
account_name = user_info.name or "Dify"
|
||||
account = RegisterService.register(
|
||||
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
||||
email=normalized_email,
|
||||
name=account_name,
|
||||
password=None,
|
||||
open_id=user_info.id,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
# Set interface language
|
||||
|
||||
@ -146,6 +146,7 @@ class DatasetUpdatePayload(BaseModel):
|
||||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
retrieval_model: dict[str, Any] | None = None
|
||||
summary_index_setting: dict[str, Any] | None = None
|
||||
partial_member_list: list[dict[str, str]] | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
external_knowledge_id: str | None = None
|
||||
|
||||
@ -39,9 +39,10 @@ from fields.document_fields import (
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from models.dataset import DocumentPipelineExecutionLog, DocumentSegmentSummary
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
|
||||
from ..app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
@ -104,6 +105,10 @@ class DocumentRenamePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class GenerateSummaryPayload(BaseModel):
|
||||
document_list: list[str]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
KnowledgeConfig,
|
||||
@ -111,6 +116,7 @@ register_schema_models(
|
||||
RetrievalModel,
|
||||
DocumentRetryPayload,
|
||||
DocumentRenamePayload,
|
||||
GenerateSummaryPayload,
|
||||
)
|
||||
|
||||
|
||||
@ -295,6 +301,97 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
documents = paginated_documents.items
|
||||
|
||||
# Check if dataset has summary index enabled
|
||||
has_summary_index = (
|
||||
dataset.summary_index_setting
|
||||
and dataset.summary_index_setting.get("enable") is True
|
||||
)
|
||||
|
||||
# Filter documents that need summary calculation
|
||||
documents_need_summary = [doc for doc in documents if doc.need_summary is True]
|
||||
document_ids_need_summary = [str(doc.id) for doc in documents_need_summary]
|
||||
|
||||
# Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled)
|
||||
summary_status_map = {}
|
||||
if has_summary_index and document_ids_need_summary:
|
||||
# Get all segments for these documents (excluding qa_model and re_segment)
|
||||
segments = (
|
||||
db.session.query(DocumentSegment.id, DocumentSegment.document_id)
|
||||
.where(
|
||||
DocumentSegment.document_id.in_(document_ids_need_summary),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.tenant_id == current_tenant_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Group segments by document_id
|
||||
document_segments_map = {}
|
||||
for segment in segments:
|
||||
doc_id = str(segment.document_id)
|
||||
if doc_id not in document_segments_map:
|
||||
document_segments_map[doc_id] = []
|
||||
document_segments_map[doc_id].append(segment.id)
|
||||
|
||||
# Get all summary records for these segments
|
||||
all_segment_ids = [seg.id for seg in segments]
|
||||
summaries = {}
|
||||
if all_segment_ids:
|
||||
summary_records = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id.in_(all_segment_ids),
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.enabled == True, # Only count enabled summaries
|
||||
)
|
||||
.all()
|
||||
)
|
||||
summaries = {summary.chunk_id: summary.status for summary in summary_records}
|
||||
|
||||
# Calculate summary_index_status for each document
|
||||
for doc_id in document_ids_need_summary:
|
||||
segment_ids = document_segments_map.get(doc_id, [])
|
||||
if not segment_ids:
|
||||
# No segments, status is "GENERATING" (waiting to generate)
|
||||
summary_status_map[doc_id] = "GENERATING"
|
||||
continue
|
||||
|
||||
# Count summary statuses for this document's segments
|
||||
status_counts = {"completed": 0, "generating": 0, "error": 0, "not_started": 0}
|
||||
for segment_id in segment_ids:
|
||||
status = summaries.get(segment_id, "not_started")
|
||||
if status in status_counts:
|
||||
status_counts[status] += 1
|
||||
else:
|
||||
status_counts["not_started"] += 1
|
||||
|
||||
total_segments = len(segment_ids)
|
||||
completed_count = status_counts["completed"]
|
||||
generating_count = status_counts["generating"]
|
||||
error_count = status_counts["error"]
|
||||
|
||||
# Determine overall status (only three states: GENERATING, COMPLETED, ERROR)
|
||||
if completed_count == total_segments:
|
||||
summary_status_map[doc_id] = "COMPLETED"
|
||||
elif error_count > 0:
|
||||
# Has errors (even if some are completed or generating)
|
||||
summary_status_map[doc_id] = "ERROR"
|
||||
elif generating_count > 0 or status_counts["not_started"] > 0:
|
||||
# Still generating or not started
|
||||
summary_status_map[doc_id] = "GENERATING"
|
||||
else:
|
||||
# Default to generating
|
||||
summary_status_map[doc_id] = "GENERATING"
|
||||
|
||||
# Add summary_index_status to each document
|
||||
for document in documents:
|
||||
if has_summary_index and document.need_summary is True:
|
||||
document.summary_index_status = summary_status_map.get(str(document.id), "GENERATING")
|
||||
else:
|
||||
# Return null if summary index is not enabled or document doesn't need summary
|
||||
document.summary_index_status = None
|
||||
|
||||
if fetch:
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
@ -393,6 +490,7 @@ class DatasetDocumentListApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
|
||||
@console_ns.route("/datasets/init")
|
||||
class DatasetInitApi(Resource):
|
||||
@console_ns.doc("init_dataset")
|
||||
@ -780,6 +878,7 @@ class DocumentApi(DocumentResource):
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
"need_summary": document.need_summary if document.need_summary is not None else False,
|
||||
}
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
@ -815,6 +914,7 @@ class DocumentApi(DocumentResource):
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
"need_summary": document.need_summary if document.need_summary is not None else False,
|
||||
}
|
||||
|
||||
return response, 200
|
||||
@ -1182,3 +1282,211 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
"input_data": log.input_data,
|
||||
"datasource_node_id": log.datasource_node_id,
|
||||
}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/generate-summary")
|
||||
class DocumentGenerateSummaryApi(Resource):
|
||||
@console_ns.doc("generate_summary_for_documents")
|
||||
@console_ns.doc(description="Generate summary index for documents")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
|
||||
@console_ns.response(200, "Summary generation started successfully")
|
||||
@console_ns.response(400, "Invalid request or dataset configuration")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id):
|
||||
"""
|
||||
Generate summary index for specified documents.
|
||||
|
||||
This endpoint checks if the dataset configuration supports summary generation
|
||||
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
|
||||
then asynchronously generates summary indexes for the provided documents.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
|
||||
# Get dataset
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# Check permissions
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
# Validate request payload
|
||||
payload = GenerateSummaryPayload.model_validate(console_ns.payload or {})
|
||||
document_list = payload.document_list
|
||||
|
||||
if not document_list:
|
||||
raise ValueError("document_list cannot be empty.")
|
||||
|
||||
# Check if dataset configuration supports summary generation
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
raise ValueError(
|
||||
f"Summary generation is only available for 'high_quality' indexing technique. "
|
||||
f"Current indexing technique: {dataset.indexing_technique}"
|
||||
)
|
||||
|
||||
summary_index_setting = dataset.summary_index_setting
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
raise ValueError(
|
||||
"Summary index is not enabled for this dataset. "
|
||||
"Please enable it in the dataset settings."
|
||||
)
|
||||
|
||||
# Verify all documents exist and belong to the dataset
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.filter(
|
||||
Document.id.in_(document_list),
|
||||
Document.dataset_id == dataset_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(documents) != len(document_list):
|
||||
found_ids = {doc.id for doc in documents}
|
||||
missing_ids = set(document_list) - found_ids
|
||||
raise NotFound(f"Some documents not found: {list(missing_ids)}")
|
||||
|
||||
# Dispatch async tasks for each document
|
||||
for document in documents:
|
||||
# Skip qa_model documents as they don't generate summaries
|
||||
if document.doc_form == "qa_model":
|
||||
logger.info(
|
||||
f"Skipping summary generation for qa_model document {document.id}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Dispatch async task
|
||||
generate_summary_index_task(dataset_id, document.id)
|
||||
logger.info(
|
||||
f"Dispatched summary generation task for document {document.id} in dataset {dataset_id}"
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/summary-status")
|
||||
class DocumentSummaryStatusApi(DocumentResource):
|
||||
@console_ns.doc("get_document_summary_status")
|
||||
@console_ns.doc(description="Get summary index generation status for a document")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.response(200, "Summary status retrieved successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
"""
|
||||
Get summary index generation status for a document.
|
||||
|
||||
Returns:
|
||||
- total_segments: Total number of segments in the document
|
||||
- summary_status: Dictionary with status counts
|
||||
- completed: Number of summaries completed
|
||||
- generating: Number of summaries being generated
|
||||
- error: Number of summaries with errors
|
||||
- not_started: Number of segments without summary records
|
||||
- summaries: List of summary records with status and content preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
# Get document
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
# Get dataset
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# Check permissions
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
# Get all segments for this document
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
total_segments = len(segments)
|
||||
|
||||
# Get all summary records for these segments
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
summaries = []
|
||||
if segment_ids:
|
||||
summaries = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.filter(
|
||||
DocumentSegmentSummary.document_id == document_id,
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.chunk_id.in_(segment_ids),
|
||||
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Create a mapping of chunk_id to summary
|
||||
summary_map = {summary.chunk_id: summary for summary in summaries}
|
||||
|
||||
# Count statuses
|
||||
status_counts = {
|
||||
"completed": 0,
|
||||
"generating": 0,
|
||||
"error": 0,
|
||||
"not_started": 0,
|
||||
}
|
||||
|
||||
summary_list = []
|
||||
for segment in segments:
|
||||
summary = summary_map.get(segment.id)
|
||||
if summary:
|
||||
status = summary.status
|
||||
status_counts[status] = status_counts.get(status, 0) + 1
|
||||
summary_list.append({
|
||||
"segment_id": segment.id,
|
||||
"segment_position": segment.position,
|
||||
"status": summary.status,
|
||||
"summary_preview": summary.summary_content[:100] + "..." if summary.summary_content and len(summary.summary_content) > 100 else summary.summary_content,
|
||||
"error": summary.error,
|
||||
"created_at": int(summary.created_at.timestamp()) if summary.created_at else None,
|
||||
"updated_at": int(summary.updated_at.timestamp()) if summary.updated_at else None,
|
||||
})
|
||||
else:
|
||||
status_counts["not_started"] += 1
|
||||
summary_list.append({
|
||||
"segment_id": segment.id,
|
||||
"segment_position": segment.position,
|
||||
"status": "not_started",
|
||||
"summary_preview": None,
|
||||
"error": None,
|
||||
"created_at": None,
|
||||
"updated_at": None,
|
||||
})
|
||||
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"summary_status": status_counts,
|
||||
"summaries": summary_list,
|
||||
}, 200
|
||||
|
||||
@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from libs.helper import escape_like_pattern
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.dataset import ChildChunk, DocumentSegment, DocumentSegmentSummary
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
|
||||
@ -41,6 +41,23 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
|
||||
|
||||
def _get_segment_with_summary(segment, dataset_id):
|
||||
"""Helper function to marshal segment and add summary information."""
|
||||
segment_dict = marshal(segment, segment_fields)
|
||||
# Query summary for this segment (only enabled summaries)
|
||||
summary = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
|
||||
)
|
||||
.first()
|
||||
)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
|
||||
|
||||
|
||||
class SegmentListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
status: list[str] = Field(default_factory=list)
|
||||
@ -63,6 +80,7 @@ class SegmentUpdatePayload(BaseModel):
|
||||
keywords: list[str] | None = None
|
||||
regenerate_child_chunks: bool = False
|
||||
attachment_ids: list[str] | None = None
|
||||
summary: str | None = None # Summary content for summary index
|
||||
|
||||
|
||||
class BatchImportPayload(BaseModel):
|
||||
@ -180,8 +198,34 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
|
||||
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
|
||||
# Query summaries for all segments in this page (batch query for efficiency)
|
||||
segment_ids = [segment.id for segment in segments.items]
|
||||
summaries = {}
|
||||
if segment_ids:
|
||||
summary_records = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id.in_(segment_ids),
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
# Only include enabled summaries
|
||||
summaries = {
|
||||
summary.chunk_id: summary.summary_content
|
||||
for summary in summary_records
|
||||
if summary.enabled is True
|
||||
}
|
||||
|
||||
# Add summary to each segment
|
||||
segments_with_summary = []
|
||||
for segment in segments.items:
|
||||
segment_dict = marshal(segment, segment_fields)
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
segments_with_summary.append(segment_dict)
|
||||
|
||||
response = {
|
||||
"data": marshal(segments.items, segment_fields),
|
||||
"data": segments_with_summary,
|
||||
"limit": limit,
|
||||
"total": segments.total,
|
||||
"total_pages": segments.pages,
|
||||
@ -327,7 +371,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
segment = SegmentService.create_segment(payload_dict, document, dataset)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
|
||||
@ -389,10 +433,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
|
||||
# Update segment (summary update with change detection is handled in SegmentService.update_segment)
|
||||
segment = SegmentService.update_segment(
|
||||
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
|
||||
)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from flask_restx import Resource
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from libs.login import login_required
|
||||
@ -10,17 +10,56 @@ from ..wraps import (
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
setup_required,
|
||||
)
|
||||
from fields.hit_testing_fields import (
|
||||
child_chunk_fields,
|
||||
document_fields,
|
||||
files_fields,
|
||||
hit_testing_record_fields,
|
||||
segment_fields,
|
||||
)
|
||||
|
||||
register_schema_model(console_ns, HitTestingPayload)
|
||||
|
||||
|
||||
def _get_or_create_model(model_name: str, field_def):
|
||||
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
|
||||
existing = console_ns.models.get(model_name)
|
||||
if existing is None:
|
||||
existing = console_ns.model(model_name, field_def)
|
||||
return existing
|
||||
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
document_model = _get_or_create_model("HitTestingDocument", document_fields)
|
||||
|
||||
segment_fields_copy = segment_fields.copy()
|
||||
segment_fields_copy["document"] = fields.Nested(document_model)
|
||||
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
|
||||
|
||||
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
|
||||
files_model = _get_or_create_model("HitTestingFile", files_fields)
|
||||
|
||||
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
|
||||
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
|
||||
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
|
||||
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
|
||||
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
|
||||
|
||||
# Response model for hit testing API
|
||||
hit_testing_response_fields = {
|
||||
"query": fields.String,
|
||||
"records": fields.List(fields.Nested(hit_testing_record_model)),
|
||||
}
|
||||
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@console_ns.doc("test_dataset_retrieval")
|
||||
@console_ns.doc(description="Test dataset knowledge retrieval")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
|
||||
@console_ns.response(200, "Hit testing completed successfully")
|
||||
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@setup_required
|
||||
|
||||
43
api/controllers/console/explore/banner.py
Normal file
43
api/controllers/console/explore/banner.py
Normal file
@ -0,0 +1,43 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import explore_banner_enabled
|
||||
from extensions.ext_database import db
|
||||
from models.model import ExporleBanner
|
||||
|
||||
|
||||
class BannerApi(Resource):
|
||||
"""Resource for banner list."""
|
||||
|
||||
@explore_banner_enabled
|
||||
def get(self):
|
||||
"""Get banner list."""
|
||||
language = request.args.get("language", "en-US")
|
||||
|
||||
# Build base query for enabled banners
|
||||
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
|
||||
|
||||
# Try to get banners in the requested language
|
||||
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
|
||||
|
||||
# Fallback to en-US if no banners found and language is not en-US
|
||||
if not banners and language != "en-US":
|
||||
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
|
||||
# Convert banners to serializable format
|
||||
result = []
|
||||
for banner in banners:
|
||||
banner_data = {
|
||||
"id": banner.id,
|
||||
"content": banner.content, # Already parsed as JSON by SQLAlchemy
|
||||
"link": banner.link,
|
||||
"sort": banner.sort,
|
||||
"status": banner.status,
|
||||
"created_at": banner.created_at.isoformat() if banner.created_at else None,
|
||||
}
|
||||
result.append(banner_data)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
api.add_resource(BannerApi, "/explore/banners")
|
||||
@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException):
|
||||
error_code = "access_denied"
|
||||
description = "App access denied."
|
||||
code = 403
|
||||
|
||||
|
||||
class TrialAppNotAllowed(BaseHTTPException):
|
||||
"""*403* `Trial App Not Allowed`
|
||||
|
||||
Raise if the user has reached the trial app limit.
|
||||
"""
|
||||
|
||||
error_code = "trial_app_not_allowed"
|
||||
code = 403
|
||||
description = "the app is not allowed to be trial."
|
||||
|
||||
|
||||
class TrialAppLimitExceeded(BaseHTTPException):
|
||||
"""*403* `Trial App Limit Exceeded`
|
||||
|
||||
Raise if the user has exceeded the trial app limit.
|
||||
"""
|
||||
|
||||
error_code = "trial_app_limit_exceeded"
|
||||
code = 403
|
||||
description = "The user has exceeded the trial app limit."
|
||||
|
||||
@ -29,6 +29,7 @@ recommended_app_fields = {
|
||||
"category": fields.String,
|
||||
"position": fields.Integer,
|
||||
"is_listed": fields.Boolean,
|
||||
"can_trial": fields.Boolean,
|
||||
}
|
||||
|
||||
recommended_app_list_fields = {
|
||||
|
||||
512
api/controllers/console/explore/trial.py
Normal file
512
api/controllers/console/explore/trial.py
Normal file
@ -0,0 +1,512 @@
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import Parameters as ParametersResponse
|
||||
from controllers.common.fields import Site as SiteResponse
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
NeedAddIdsError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model_with_trial
|
||||
from controllers.console.explore.error import (
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
NotWorkflowAppError,
|
||||
)
|
||||
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_fields_with_site
|
||||
from fields.dataset_fields import dataset_fields
|
||||
from fields.workflow_fields import workflow_fields
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.account import TenantStatus
|
||||
from models.model import AppMode, Site
|
||||
from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_service import AppService
|
||||
from services.audio_service import AudioService
|
||||
from services.dataset_service import DatasetService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.errors.message import (
|
||||
MessageNotExistsError,
|
||||
SuggestedQuestionsAfterAnswerDisabledError,
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
def post(self, trial_app):
|
||||
"""
|
||||
Run workflow
|
||||
"""
|
||||
app_model = trial_app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
assert current_user is not None
|
||||
try:
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
def post(self, trial_app, task_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
"""
|
||||
app_model = trial_app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
assert current_user is not None
|
||||
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class TrialChatApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialMessageSuggestedQuestionApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def get(self, trial_app, message_id):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
|
||||
|
||||
class TrialChatAudioApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialChatTextApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialCompletionApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
||||
)
|
||||
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialSitApi(Resource):
|
||||
"""Resource for trial app sites."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
def get(self, app_model):
|
||||
"""Retrieve app site info.
|
||||
|
||||
Returns the site configuration for the application including theme, icons, and text.
|
||||
"""
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
||||
assert app_model.tenant
|
||||
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
return SiteResponse.model_validate(site).model_dump(mode="json")
|
||||
|
||||
|
||||
class TrialAppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
def get(self, app_model):
|
||||
"""Retrieve app parameters."""
|
||||
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return ParametersResponse.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
class AppApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.get_app(app_model)
|
||||
|
||||
return app_model
|
||||
|
||||
|
||||
class AppWorkflowApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
@marshal_with(workflow_fields)
|
||||
def get(self, app_model):
|
||||
"""Get workflow detail"""
|
||||
if not app_model.workflow_id:
|
||||
raise AppUnavailableError()
|
||||
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.id == app_model.workflow_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return workflow
|
||||
|
||||
|
||||
class DatasetListApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
def get(self, app_model):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
ids = request.args.getlist("ids")
|
||||
|
||||
tenant_id = app_model.tenant_id
|
||||
if ids:
|
||||
datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
|
||||
else:
|
||||
raise NeedAddIdsError()
|
||||
|
||||
data = cast(list[dict[str, Any]], marshal(datasets, dataset_fields))
|
||||
|
||||
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
|
||||
|
||||
api.add_resource(
|
||||
TrialMessageSuggestedQuestionApi,
|
||||
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||
endpoint="trial_app_suggested_question",
|
||||
)
|
||||
|
||||
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
|
||||
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
|
||||
|
||||
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
|
||||
|
||||
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
|
||||
|
||||
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
|
||||
|
||||
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
|
||||
|
||||
api.add_resource(TrialAppWorkflowRunApi, "/trial-apps/<uuid:app_id>/workflows/run", endpoint="trial_app_workflow_run")
|
||||
api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
|
||||
|
||||
api.add_resource(AppWorkflowApi, "/trial-apps/<uuid:app_id>/workflows", endpoint="trial_app_workflow")
|
||||
api.add_resource(DatasetListApi, "/trial-apps/<uuid:app_id>/datasets", endpoint="trial_app_datasets")
|
||||
@ -2,14 +2,15 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import abort
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.explore.error import AppAccessDeniedError
|
||||
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import InstalledApp
|
||||
from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@ -71,6 +72,61 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[App, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
|
||||
|
||||
if trial_app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
app = trial_app.app
|
||||
|
||||
if app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
|
||||
account_trial_app_record = (
|
||||
db.session.query(AccountTrialAppRecord)
|
||||
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
|
||||
.first()
|
||||
)
|
||||
if account_trial_app_record:
|
||||
if account_trial_app_record.count >= trial_app.trial_limit:
|
||||
raise TrialAppLimitExceeded()
|
||||
|
||||
return view(app, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.enable_trial_app:
|
||||
abort(403, "Trial app feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.enable_explore_banner:
|
||||
abort(403, "Explore banner feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
class InstalledAppResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
|
||||
@ -80,3 +136,13 @@ class InstalledAppResource(Resource):
|
||||
account_initialization_required,
|
||||
login_required,
|
||||
]
|
||||
|
||||
|
||||
class TrialAppResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
|
||||
method_decorators = [
|
||||
trial_app_required,
|
||||
account_initialization_required,
|
||||
login_required,
|
||||
]
|
||||
|
||||
@ -84,10 +84,11 @@ class SetupApi(Resource):
|
||||
raise NotInitValidateError()
|
||||
|
||||
args = SetupRequestPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
# setup
|
||||
RegisterService.setup(
|
||||
email=args.email,
|
||||
email=normalized_email,
|
||||
name=args.name,
|
||||
password=args.password,
|
||||
ip_address=extract_remote_ip(request),
|
||||
|
||||
@ -41,7 +41,7 @@ from fields.member_fields import account_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Account, AccountIntegrate, InvitationCode
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
@ -536,7 +536,8 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
account = None
|
||||
user_email = args.email
|
||||
user_email = None
|
||||
email_for_sending = args.email.lower()
|
||||
if args.phase is not None and args.phase == "new_email":
|
||||
if args.token is None:
|
||||
raise InvalidTokenError()
|
||||
@ -546,16 +547,24 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
raise InvalidTokenError()
|
||||
user_email = reset_data.get("email", "")
|
||||
|
||||
if user_email != current_user.email:
|
||||
if user_email.lower() != current_user.email.lower():
|
||||
raise InvalidEmailError()
|
||||
|
||||
user_email = current_user.email
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
if account is None:
|
||||
raise AccountNotFound()
|
||||
email_for_sending = account.email
|
||||
user_email = account.email
|
||||
|
||||
token = AccountService.send_change_email_email(
|
||||
account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
|
||||
account=account,
|
||||
email=email_for_sending,
|
||||
old_email=user_email,
|
||||
language=language,
|
||||
phase=args.phase,
|
||||
)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
@ -571,9 +580,9 @@ class ChangeEmailCheckApi(Resource):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailValidityPayload.model_validate(payload)
|
||||
|
||||
user_email = args.email
|
||||
user_email = args.email.lower()
|
||||
|
||||
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
|
||||
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email)
|
||||
if is_change_email_error_rate_limit:
|
||||
raise EmailChangeLimitError()
|
||||
|
||||
@ -581,11 +590,13 @@ class ChangeEmailCheckApi(Resource):
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_change_email_error_rate_limit(args.email)
|
||||
AccountService.add_change_email_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
@ -596,8 +607,8 @@ class ChangeEmailCheckApi(Resource):
|
||||
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
|
||||
)
|
||||
|
||||
AccountService.reset_change_email_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_change_email_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/reset")
|
||||
@ -611,11 +622,12 @@ class ChangeEmailResetApi(Resource):
|
||||
def post(self):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailResetPayload.model_validate(payload)
|
||||
normalized_new_email = args.new_email.lower()
|
||||
|
||||
if AccountService.is_account_in_freeze(args.new_email):
|
||||
if AccountService.is_account_in_freeze(normalized_new_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if not AccountService.check_email_unique(args.new_email):
|
||||
if not AccountService.check_email_unique(normalized_new_email):
|
||||
raise EmailAlreadyInUseError()
|
||||
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
@ -626,13 +638,13 @@ class ChangeEmailResetApi(Resource):
|
||||
|
||||
old_email = reset_data.get("old_email", "")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.email != old_email:
|
||||
if current_user.email.lower() != old_email.lower():
|
||||
raise AccountNotFound()
|
||||
|
||||
updated_account = AccountService.update_account_email(current_user, email=args.new_email)
|
||||
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
|
||||
|
||||
AccountService.send_change_email_completed_notify_email(
|
||||
email=args.new_email,
|
||||
email=normalized_new_email,
|
||||
)
|
||||
|
||||
return updated_account
|
||||
@ -645,8 +657,9 @@ class CheckEmailUnique(Resource):
|
||||
def post(self):
|
||||
payload = console_ns.payload or {}
|
||||
args = CheckEmailUniquePayload.model_validate(payload)
|
||||
if AccountService.is_account_in_freeze(args.email):
|
||||
normalized_email = args.email.lower()
|
||||
if AccountService.is_account_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
if not AccountService.check_email_unique(args.email):
|
||||
if not AccountService.check_email_unique(normalized_email):
|
||||
raise EmailAlreadyInUseError()
|
||||
return {"result": "success"}
|
||||
|
||||
@ -116,26 +116,31 @@ class MemberInviteEmailApi(Resource):
|
||||
raise WorkspaceMembersLimitExceeded()
|
||||
|
||||
for invitee_email in invitee_emails:
|
||||
normalized_invitee_email = invitee_email.lower()
|
||||
try:
|
||||
if not inviter.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
token = RegisterService.invite_new_member(
|
||||
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
|
||||
tenant=inviter.current_tenant,
|
||||
email=invitee_email,
|
||||
language=interface_language,
|
||||
role=invitee_role,
|
||||
inviter=inviter,
|
||||
)
|
||||
encoded_invitee_email = parse.quote(invitee_email)
|
||||
encoded_invitee_email = parse.quote(normalized_invitee_email)
|
||||
invitation_results.append(
|
||||
{
|
||||
"status": "success",
|
||||
"email": invitee_email,
|
||||
"email": normalized_invitee_email,
|
||||
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
|
||||
}
|
||||
)
|
||||
except AccountAlreadyInTenantError:
|
||||
invitation_results.append(
|
||||
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
|
||||
{"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
|
||||
)
|
||||
except Exception as e:
|
||||
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
|
||||
invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
|
||||
@ -358,14 +358,12 @@ def annotation_import_rate_limit(view: Callable[P, R]):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
current_time = int(time.time() * 1000)
|
||||
|
||||
# Check per-minute rate limit
|
||||
minute_key = f"annotation_import_rate_limit:{current_tenant_id}:1min"
|
||||
redis_client.zadd(minute_key, {current_time: current_time})
|
||||
redis_client.zremrangebyscore(minute_key, 0, current_time - 60000)
|
||||
minute_count = redis_client.zcard(minute_key)
|
||||
redis_client.expire(minute_key, 120) # 2 minutes TTL
|
||||
|
||||
if minute_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE:
|
||||
abort(
|
||||
429,
|
||||
@ -379,7 +377,6 @@ def annotation_import_rate_limit(view: Callable[P, R]):
|
||||
redis_client.zremrangebyscore(hour_key, 0, current_time - 3600000)
|
||||
hour_count = redis_client.zcard(hour_key)
|
||||
redis_client.expire(hour_key, 7200) # 2 hours TTL
|
||||
|
||||
if hour_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:
|
||||
abort(
|
||||
429,
|
||||
|
||||
@ -4,7 +4,6 @@ import secrets
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@ -22,7 +21,7 @@ from controllers.web import web_ns
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models import Account
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
@ -70,6 +69,9 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
def post(self):
|
||||
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
request_email = payload.email
|
||||
normalized_email = request_email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
@ -80,12 +82,12 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
|
||||
token = None
|
||||
if account is None:
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
|
||||
token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
@ -104,9 +106,9 @@ class ForgotPasswordCheckApi(Resource):
|
||||
def post(self):
|
||||
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
user_email = payload.email
|
||||
user_email = payload.email.lower()
|
||||
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
|
||||
if is_forgot_password_error_rate_limit:
|
||||
raise EmailPasswordResetLimitError()
|
||||
|
||||
@ -114,11 +116,16 @@ class ForgotPasswordCheckApi(Resource):
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
if not isinstance(token_email, str):
|
||||
raise InvalidEmailError()
|
||||
normalized_token_email = token_email.lower()
|
||||
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if payload.code != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(payload.email)
|
||||
AccountService.add_forgot_password_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
@ -126,11 +133,11 @@ class ForgotPasswordCheckApi(Resource):
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_reset_password_token(
|
||||
user_email, code=payload.code, additional_data={"phase": "reset"}
|
||||
token_email, code=payload.code, additional_data={"phase": "reset"}
|
||||
)
|
||||
|
||||
AccountService.reset_forgot_password_error_rate_limit(payload.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_forgot_password_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@web_ns.route("/forgot-password/resets")
|
||||
@ -174,7 +181,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt, session)
|
||||
|
||||
@ -197,25 +197,29 @@ class EmailCodeLoginApi(Resource):
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
user_email = args["email"].lower()
|
||||
|
||||
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if token_data["email"] != args["email"]:
|
||||
token_email = token_data.get("email")
|
||||
if not isinstance(token_email, str):
|
||||
raise InvalidEmailError()
|
||||
normalized_token_email = token_email.lower()
|
||||
if normalized_token_email != user_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args["code"]:
|
||||
raise EmailCodeError()
|
||||
|
||||
WebAppAuthService.revoke_email_code_login_token(args["token"])
|
||||
account = WebAppAuthService.get_user_through_email(user_email)
|
||||
account = WebAppAuthService.get_user_through_email(token_email)
|
||||
if not account:
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
AccountService.reset_login_error_rate_limit(user_email)
|
||||
response = make_response({"result": "success", "data": {"access_token": token}})
|
||||
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
|
||||
return response
|
||||
|
||||
380
api/core/agent/agent_app_runner.py
Normal file
380
api/core/agent/agent_app_runner.py
Normal file
@ -0,0 +1,380 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentResult
|
||||
from core.agent.patterns.strategy_factory import StrategyFactory
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.file import file_manager
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentAppRunner(BaseAgentRunner):
|
||||
def _create_tool_invoke_hook(self, message: Message):
|
||||
"""
|
||||
Create a tool invoke hook that uses ToolEngine.agent_invoke.
|
||||
This hook handles file creation and returns proper meta information.
|
||||
"""
|
||||
# Get trace manager from app generate entity
|
||||
trace_manager = self.application_generate_entity.trace_manager
|
||||
|
||||
def tool_invoke_hook(
|
||||
tool: Tool, tool_args: dict[str, Any], tool_name: str
|
||||
) -> tuple[str, list[str], ToolInvokeMeta]:
|
||||
"""Hook that uses agent_invoke for proper file and meta handling."""
|
||||
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
|
||||
tool=tool,
|
||||
tool_parameters=tool_args,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
message=message,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
agent_tool_callback=self.agent_callback,
|
||||
trace_manager=trace_manager,
|
||||
app_id=self.application_generate_entity.app_config.app_id,
|
||||
message_id=message.id,
|
||||
conversation_id=self.conversation.id,
|
||||
)
|
||||
|
||||
# Publish files and track IDs
|
||||
for message_file_id in message_files:
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
self._current_message_file_ids.append(message_file_id)
|
||||
|
||||
return tool_invoke_response, message_files, tool_invoke_meta
|
||||
|
||||
return tool_invoke_hook
|
||||
|
||||
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Run Agent application
|
||||
"""
|
||||
self.query = query
|
||||
app_generate_entity = self.application_generate_entity
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config is not None, "app_config is required"
|
||||
assert app_config.agent is not None, "app_config.agent is required"
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, _ = self._init_prompt_tools()
|
||||
|
||||
assert app_config.agent
|
||||
|
||||
# Create tool invoke hook for agent_invoke
|
||||
tool_invoke_hook = self._create_tool_invoke_hook(message)
|
||||
|
||||
# Get instruction for ReAct strategy
|
||||
instruction = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
|
||||
# Use factory to create appropriate strategy
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=self.model_features,
|
||||
model_instance=self.model_instance,
|
||||
tools=list(tool_instances.values()),
|
||||
files=list(self.files),
|
||||
max_iterations=app_config.agent.max_iteration,
|
||||
context=self.build_execution_context(),
|
||||
agent_strategy=self.config.strategy,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
# Initialize state variables
|
||||
current_agent_thought_id = None
|
||||
has_published_thought = False
|
||||
current_tool_name: str | None = None
|
||||
self._current_message_file_ids: list[str] = []
|
||||
|
||||
# organize prompt messages
|
||||
prompt_messages = self._organize_prompt_messages()
|
||||
|
||||
# Run strategy
|
||||
generator = strategy.run(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
stop=app_generate_entity.model_conf.stop,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Consume generator and collect result
|
||||
result: AgentResult | None = None
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
output = next(generator)
|
||||
except StopIteration as e:
|
||||
# Generator finished, get the return value
|
||||
result = e.value
|
||||
break
|
||||
|
||||
if isinstance(output, LLMResultChunk):
|
||||
# Handle LLM chunk
|
||||
if current_agent_thought_id and not has_published_thought:
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
has_published_thought = True
|
||||
|
||||
yield output
|
||||
|
||||
elif isinstance(output, AgentLog):
|
||||
# Handle Agent Log using log_type for type-safe dispatch
|
||||
if output.status == AgentLog.LogStatus.START:
|
||||
if output.log_type == AgentLog.LogType.ROUND:
|
||||
# Start of a new round
|
||||
message_file_ids: list[str] = []
|
||||
current_agent_thought_id = self.create_agent_thought(
|
||||
message_id=message.id,
|
||||
message="",
|
||||
tool_name="",
|
||||
tool_input="",
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
has_published_thought = False
|
||||
|
||||
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Tool call start - extract data from structured fields
|
||||
current_tool_name = output.data.get("tool_name", "")
|
||||
tool_input = output.data.get("tool_args", {})
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=current_tool_name,
|
||||
tool_input=tool_input,
|
||||
thought=None,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=None,
|
||||
messages_ids=[],
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.status == AgentLog.LogStatus.SUCCESS:
|
||||
if output.log_type == AgentLog.LogType.THOUGHT:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
thought_text = output.data.get("thought")
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=thought_text,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=None,
|
||||
messages_ids=[],
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Tool call finished
|
||||
tool_output = output.data.get("output")
|
||||
# Get meta from strategy output (now properly populated)
|
||||
tool_meta = output.data.get("meta")
|
||||
|
||||
# Wrap tool_meta with tool_name as key (required by agent_service)
|
||||
if tool_meta and current_tool_name:
|
||||
tool_meta = {current_tool_name: tool_meta}
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=None,
|
||||
observation=tool_output,
|
||||
tool_invoke_meta=tool_meta,
|
||||
answer=None,
|
||||
messages_ids=self._current_message_file_ids,
|
||||
)
|
||||
# Clear message file ids after saving
|
||||
self._current_message_file_ids = []
|
||||
current_tool_name = None
|
||||
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.log_type == AgentLog.LogType.ROUND:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Round finished - save LLM usage and answer
|
||||
llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE)
|
||||
llm_result = output.data.get("llm_result")
|
||||
final_answer = output.data.get("final_answer")
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=llm_result,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=final_answer,
|
||||
messages_ids=[],
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# Re-raise any other exceptions
|
||||
raise
|
||||
|
||||
# Process final result
|
||||
if isinstance(result, AgentResult):
|
||||
final_answer = result.text
|
||||
usage = result.usage or LLMUsage.empty_usage()
|
||||
|
||||
# Publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=self.model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=usage,
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
"""
|
||||
if not prompt_template:
|
||||
return prompt_messages or []
|
||||
|
||||
prompt_messages = prompt_messages or []
|
||||
|
||||
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
|
||||
prompt_messages[0] = SystemPromptMessage(content=prompt_template)
|
||||
return prompt_messages
|
||||
|
||||
if not prompt_messages:
|
||||
return [SystemPromptMessage(content=prompt_template)]
|
||||
|
||||
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
||||
return prompt_messages
|
||||
|
||||
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
# get image detail config
|
||||
image_detail_config = (
|
||||
self.application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
self.application_generate_entity.file_upload_config
|
||||
and self.application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in self.files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
As for now, gpt supports both fc and vision at the first iteration.
|
||||
We need to remove the image messages from the prompt messages at the first iteration.
|
||||
"""
|
||||
prompt_messages = deepcopy(prompt_messages)
|
||||
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message.content = "\n".join(
|
||||
[
|
||||
content.data
|
||||
if content.type == PromptMessageContentType.TEXT
|
||||
else "[image]"
|
||||
if content.type == PromptMessageContentType.IMAGE
|
||||
else "[file]"
|
||||
for content in prompt_message.content
|
||||
]
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_prompt_messages(self):
|
||||
# For ReAct strategy, use the agent prompt template
|
||||
if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt:
|
||||
prompt_template = self.config.prompt.first_prompt
|
||||
else:
|
||||
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
|
||||
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
|
||||
query_prompt_messages = self._organize_user_query(self.query or "", [])
|
||||
|
||||
self.history_prompt_messages = AgentHistoryPromptTransform(
|
||||
model_config=self.model_config,
|
||||
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
|
||||
history_messages=self.history_prompt_messages,
|
||||
memory=self.memory,
|
||||
).get_prompt()
|
||||
|
||||
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
|
||||
if len(self._current_thoughts) != 0:
|
||||
# clear messages after the first iteration
|
||||
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
|
||||
return prompt_messages
|
||||
@ -1,11 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
from typing import Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -41,6 +42,7 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import Conversation, Message, MessageAgentThought, MessageFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -114,9 +116,20 @@ class BaseAgentRunner(AppRunner):
|
||||
features = model_schema.features if model_schema and model_schema.features else []
|
||||
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
|
||||
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
|
||||
self.model_features = features
|
||||
self.query: str | None = ""
|
||||
self._current_thoughts: list[PromptMessage] = []
|
||||
|
||||
def build_execution_context(self) -> ExecutionContext:
|
||||
"""Build execution context."""
|
||||
return ExecutionContext(
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_config.app_id,
|
||||
conversation_id=self.conversation.id,
|
||||
message_id=self.message.id,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
|
||||
def _repack_app_generate_entity(
|
||||
self, app_generate_entity: AgentChatAppGenerateEntity
|
||||
) -> AgentChatAppGenerateEntity:
|
||||
@ -289,6 +302,7 @@ class BaseAgentRunner(AppRunner):
|
||||
thought = MessageAgentThought(
|
||||
message_id=message_id,
|
||||
message_chain_id=None,
|
||||
tool_process_data=None,
|
||||
thought="",
|
||||
tool=tool_name,
|
||||
tool_labels_str="{}",
|
||||
@ -296,20 +310,20 @@ class BaseAgentRunner(AppRunner):
|
||||
tool_input=tool_input,
|
||||
message=message,
|
||||
message_token=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
message_unit_price=Decimal(0),
|
||||
message_price_unit=Decimal("0.001"),
|
||||
message_files=json.dumps(messages_ids) if messages_ids else "",
|
||||
answer="",
|
||||
observation="",
|
||||
answer_token=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
answer_unit_price=Decimal(0),
|
||||
answer_price_unit=Decimal("0.001"),
|
||||
tokens=0,
|
||||
total_price=0,
|
||||
total_price=Decimal(0),
|
||||
position=self.agent_thought_count + 1,
|
||||
currency="USD",
|
||||
latency=0,
|
||||
created_by_role="account",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=self.user_id,
|
||||
)
|
||||
|
||||
@ -342,7 +356,8 @@ class BaseAgentRunner(AppRunner):
|
||||
raise ValueError("agent thought not found")
|
||||
|
||||
if thought:
|
||||
agent_thought.thought += thought
|
||||
existing_thought = agent_thought.thought or ""
|
||||
agent_thought.thought = f"{existing_thought}{thought}"
|
||||
|
||||
if tool_name:
|
||||
agent_thought.tool = tool_name
|
||||
@ -440,21 +455,30 @@ class BaseAgentRunner(AppRunner):
|
||||
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
|
||||
if agent_thoughts:
|
||||
for agent_thought in agent_thoughts:
|
||||
tools = agent_thought.tool
|
||||
if tools:
|
||||
tools = tools.split(";")
|
||||
tool_names_raw = agent_thought.tool
|
||||
if tool_names_raw:
|
||||
tool_names = tool_names_raw.split(";")
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
tool_call_response: list[ToolPromptMessage] = []
|
||||
try:
|
||||
tool_inputs = json.loads(agent_thought.tool_input)
|
||||
except Exception:
|
||||
tool_inputs = {tool: {} for tool in tools}
|
||||
try:
|
||||
tool_responses = json.loads(agent_thought.observation)
|
||||
except Exception:
|
||||
tool_responses = dict.fromkeys(tools, agent_thought.observation)
|
||||
tool_input_payload = agent_thought.tool_input
|
||||
if tool_input_payload:
|
||||
try:
|
||||
tool_inputs = json.loads(tool_input_payload)
|
||||
except Exception:
|
||||
tool_inputs = {tool: {} for tool in tool_names}
|
||||
else:
|
||||
tool_inputs = {tool: {} for tool in tool_names}
|
||||
|
||||
for tool in tools:
|
||||
observation_payload = agent_thought.observation
|
||||
if observation_payload:
|
||||
try:
|
||||
tool_responses = json.loads(observation_payload)
|
||||
except Exception:
|
||||
tool_responses = dict.fromkeys(tool_names, observation_payload)
|
||||
else:
|
||||
tool_responses = dict.fromkeys(tool_names, observation_payload)
|
||||
|
||||
for tool in tool_names:
|
||||
# generate a uuid for tool call
|
||||
tool_call_id = str(uuid.uuid4())
|
||||
tool_calls.append(
|
||||
@ -484,7 +508,7 @@ class BaseAgentRunner(AppRunner):
|
||||
*tool_call_response,
|
||||
]
|
||||
)
|
||||
if not tools:
|
||||
if not tool_names_raw:
|
||||
result.append(AssistantPromptMessage(content=agent_thought.thought))
|
||||
else:
|
||||
if message.answer:
|
||||
|
||||
@ -1,437 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.workflow.nodes.agent.exc import AgentMaxIterationError
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
_is_first_iteration = True
|
||||
_ignore_observation_providers = ["wenxin"]
|
||||
_historic_prompt_messages: list[PromptMessage]
|
||||
_agent_scratchpad: list[AgentScratchpadUnit]
|
||||
_instruction: str
|
||||
_query: str
|
||||
_prompt_messages_tools: Sequence[PromptMessageTool]
|
||||
|
||||
def run(
|
||||
self,
|
||||
message: Message,
|
||||
query: str,
|
||||
inputs: Mapping[str, str],
|
||||
) -> Generator:
|
||||
"""
|
||||
Run Cot agent application
|
||||
"""
|
||||
|
||||
app_generate_entity = self.application_generate_entity
|
||||
self._repack_app_generate_entity(app_generate_entity)
|
||||
self._init_react_state(query)
|
||||
|
||||
trace_manager = app_generate_entity.trace_manager
|
||||
|
||||
# check model mode
|
||||
if "Observation" not in app_generate_entity.model_conf.stop:
|
||||
if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
|
||||
app_generate_entity.model_conf.stop.append("Observation")
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config.agent
|
||||
|
||||
# init instruction
|
||||
inputs = inputs or {}
|
||||
instruction = app_config.prompt_template.simple_prompt_template or ""
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
self._prompt_messages_tools = prompt_messages_tools
|
||||
|
||||
function_call_state = True
|
||||
llm_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
final_answer = ""
|
||||
prompt_messages: list = [] # Initialize prompt_messages
|
||||
agent_thought_id = "" # Initialize agent_thought_id
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
llm_usage = final_llm_usage_dict["usage"]
|
||||
llm_usage.prompt_tokens += usage.prompt_tokens
|
||||
llm_usage.completion_tokens += usage.completion_tokens
|
||||
llm_usage.total_tokens += usage.total_tokens
|
||||
llm_usage.prompt_price += usage.prompt_price
|
||||
llm_usage.completion_price += usage.completion_price
|
||||
llm_usage.total_price += usage.total_price
|
||||
|
||||
model_instance = self.model_instance
|
||||
|
||||
while function_call_state and iteration_step <= max_iteration_steps:
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = False
|
||||
|
||||
if iteration_step == max_iteration_steps:
|
||||
# the last iteration, remove all tools
|
||||
self._prompt_messages_tools = []
|
||||
|
||||
message_file_ids: list[str] = []
|
||||
|
||||
agent_thought_id = self.create_agent_thought(
|
||||
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
|
||||
)
|
||||
|
||||
if iteration_step > 1:
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
# recalc llm max tokens
|
||||
prompt_messages = self._organize_prompt_messages()
|
||||
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
|
||||
# invoke model
|
||||
chunks = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
tools=[],
|
||||
stop=app_generate_entity.model_conf.stop,
|
||||
stream=True,
|
||||
user=self.user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
usage_dict: dict[str, LLMUsage | None] = {}
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
thought="",
|
||||
action_str="",
|
||||
observation="",
|
||||
action=None,
|
||||
)
|
||||
|
||||
# publish agent thought if it's first iteration
|
||||
if iteration_step == 1:
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
for chunk in react_chunks:
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
action = chunk
|
||||
# detect action
|
||||
assert scratchpad.agent_response is not None
|
||||
scratchpad.agent_response += json.dumps(chunk.model_dump())
|
||||
scratchpad.action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.action = action
|
||||
else:
|
||||
assert scratchpad.agent_response is not None
|
||||
scratchpad.agent_response += chunk
|
||||
assert scratchpad.thought is not None
|
||||
scratchpad.thought += chunk
|
||||
yield LLMResultChunk(
|
||||
model=self.model_config.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint="",
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
|
||||
)
|
||||
|
||||
assert scratchpad.thought is not None
|
||||
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
|
||||
self._agent_scratchpad.append(scratchpad)
|
||||
|
||||
# Check if max iteration is reached and model still wants to call tools
|
||||
if iteration_step == max_iteration_steps and scratchpad.action:
|
||||
if scratchpad.action.action_name.lower() != "final answer":
|
||||
raise AgentMaxIterationError(app_config.agent.max_iteration)
|
||||
|
||||
# get llm usage
|
||||
if "usage" in usage_dict:
|
||||
if usage_dict["usage"] is not None:
|
||||
increase_usage(llm_usage, usage_dict["usage"])
|
||||
else:
|
||||
usage_dict["usage"] = LLMUsage.empty_usage()
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=agent_thought_id,
|
||||
tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
|
||||
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
|
||||
tool_invoke_meta={},
|
||||
thought=scratchpad.thought or "",
|
||||
observation="",
|
||||
answer=scratchpad.agent_response or "",
|
||||
messages_ids=[],
|
||||
llm_usage=usage_dict["usage"],
|
||||
)
|
||||
|
||||
if not scratchpad.is_final():
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
if not scratchpad.action:
|
||||
# failed to extract action, return final answer directly
|
||||
final_answer = ""
|
||||
else:
|
||||
if scratchpad.action.action_name.lower() == "final answer":
|
||||
# action is final answer, return final answer directly
|
||||
try:
|
||||
if isinstance(scratchpad.action.action_input, dict):
|
||||
final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
|
||||
elif isinstance(scratchpad.action.action_input, str):
|
||||
final_answer = scratchpad.action.action_input
|
||||
else:
|
||||
final_answer = f"{scratchpad.action.action_input}"
|
||||
except TypeError:
|
||||
final_answer = f"{scratchpad.action.action_input}"
|
||||
else:
|
||||
function_call_state = True
|
||||
# action is tool call, invoke tool
|
||||
tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
|
||||
action=scratchpad.action,
|
||||
tool_instances=tool_instances,
|
||||
message_file_ids=message_file_ids,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
scratchpad.observation = tool_invoke_response
|
||||
scratchpad.agent_response = tool_invoke_response
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=agent_thought_id,
|
||||
tool_name=scratchpad.action.action_name,
|
||||
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
|
||||
thought=scratchpad.thought or "",
|
||||
observation={scratchpad.action.action_name: tool_invoke_response},
|
||||
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
|
||||
answer=scratchpad.agent_response,
|
||||
messages_ids=message_file_ids,
|
||||
llm_usage=usage_dict["usage"],
|
||||
)
|
||||
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
# update prompt tool message
|
||||
for prompt_tool in self._prompt_messages_tools:
|
||||
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
|
||||
),
|
||||
system_fingerprint="",
|
||||
)
|
||||
|
||||
# save agent thought
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=agent_thought_id,
|
||||
tool_name="",
|
||||
tool_input={},
|
||||
tool_invoke_meta={},
|
||||
thought=final_answer,
|
||||
observation={},
|
||||
answer=final_answer,
|
||||
messages_ids=[],
|
||||
)
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _handle_invoke_action(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
tool_instances: Mapping[str, Tool],
|
||||
message_file_ids: list[str],
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
) -> tuple[str, ToolInvokeMeta]:
|
||||
"""
|
||||
handle invoke action
|
||||
:param action: action
|
||||
:param tool_instances: tool instances
|
||||
:param message_file_ids: message file ids
|
||||
:param trace_manager: trace manager
|
||||
:return: observation, meta
|
||||
"""
|
||||
# action is tool call, invoke tool
|
||||
tool_call_name = action.action_name
|
||||
tool_call_args = action.action_input
|
||||
tool_instance = tool_instances.get(tool_call_name)
|
||||
|
||||
if not tool_instance:
|
||||
answer = f"there is not a tool named {tool_call_name}"
|
||||
return answer, ToolInvokeMeta.error_instance(answer)
|
||||
|
||||
if isinstance(tool_call_args, str):
|
||||
try:
|
||||
tool_call_args = json.loads(tool_call_args)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# invoke tool
|
||||
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_call_args,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
message=self.message,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
agent_tool_callback=self.agent_callback,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# publish files
|
||||
for message_file_id in message_files:
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file_id)
|
||||
|
||||
return tool_invoke_response, tool_invoke_meta
|
||||
|
||||
def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
|
||||
"""
|
||||
convert dict to action
|
||||
"""
|
||||
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
|
||||
|
||||
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
fill in inputs from external data tools
|
||||
"""
|
||||
for key, value in inputs.items():
|
||||
try:
|
||||
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return instruction
|
||||
|
||||
def _init_react_state(self, query):
|
||||
"""
|
||||
init agent scratchpad
|
||||
"""
|
||||
self._query = query
|
||||
self._agent_scratchpad = []
|
||||
self._historic_prompt_messages = self._organize_historic_prompt_messages()
|
||||
|
||||
@abstractmethod
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
organize prompt messages
|
||||
"""
|
||||
|
||||
def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
|
||||
"""
|
||||
format assistant message
|
||||
"""
|
||||
message = ""
|
||||
for scratchpad in agent_scratchpad:
|
||||
if scratchpad.is_final():
|
||||
message += f"Final Answer: {scratchpad.agent_response}"
|
||||
else:
|
||||
message += f"Thought: {scratchpad.thought}\n\n"
|
||||
if scratchpad.action_str:
|
||||
message += f"Action: {scratchpad.action_str}\n\n"
|
||||
if scratchpad.observation:
|
||||
message += f"Observation: {scratchpad.observation}\n\n"
|
||||
|
||||
return message
|
||||
|
||||
def _organize_historic_prompt_messages(
|
||||
self, current_session_messages: list[PromptMessage] | None = None
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
organize historic prompt messages
|
||||
"""
|
||||
result: list[PromptMessage] = []
|
||||
scratchpads: list[AgentScratchpadUnit] = []
|
||||
current_scratchpad: AgentScratchpadUnit | None = None
|
||||
|
||||
for message in self.history_prompt_messages:
|
||||
if isinstance(message, AssistantPromptMessage):
|
||||
if not current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content or "I am thinking about how to help you",
|
||||
action_str="",
|
||||
action=None,
|
||||
observation=None,
|
||||
)
|
||||
scratchpads.append(current_scratchpad)
|
||||
if message.tool_calls:
|
||||
try:
|
||||
current_scratchpad.action = AgentScratchpadUnit.Action(
|
||||
action_name=message.tool_calls[0].function.name,
|
||||
action_input=json.loads(message.tool_calls[0].function.arguments),
|
||||
)
|
||||
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
|
||||
except Exception:
|
||||
logger.exception("Failed to parse tool call from assistant message")
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
if current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad.observation = message.content
|
||||
else:
|
||||
raise NotImplementedError("expected str type")
|
||||
elif isinstance(message, UserPromptMessage):
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
scratchpads = []
|
||||
current_scratchpad = None
|
||||
|
||||
result.append(message)
|
||||
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
|
||||
historic_prompts = AgentHistoryPromptTransform(
|
||||
model_config=self.model_config,
|
||||
prompt_messages=current_session_messages or [],
|
||||
history_messages=result,
|
||||
memory=self.memory,
|
||||
).get_prompt()
|
||||
return historic_prompts
|
||||
@ -1,118 +0,0 @@
|
||||
import json
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.file import file_manager
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
|
||||
class CotChatAgentRunner(CotAgentRunner):
|
||||
def _organize_system_prompt(self) -> SystemPromptMessage:
|
||||
"""
|
||||
Organize system prompt
|
||||
"""
|
||||
assert self.app_config.agent
|
||||
assert self.app_config.agent.prompt
|
||||
|
||||
prompt_entity = self.app_config.agent.prompt
|
||||
if not prompt_entity:
|
||||
raise ValueError("Agent prompt configuration is not set")
|
||||
first_prompt = prompt_entity.first_prompt
|
||||
|
||||
system_prompt = (
|
||||
first_prompt.replace("{{instruction}}", self._instruction)
|
||||
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
|
||||
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
|
||||
)
|
||||
|
||||
return SystemPromptMessage(content=system_prompt)
|
||||
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
# get image detail config
|
||||
image_detail_config = (
|
||||
self.application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
self.application_generate_entity.file_upload_config
|
||||
and self.application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in self.files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize
|
||||
"""
|
||||
# organize system prompt
|
||||
system_message = self._organize_system_prompt()
|
||||
|
||||
# organize current assistant messages
|
||||
agent_scratchpad = self._agent_scratchpad
|
||||
if not agent_scratchpad:
|
||||
assistant_messages = []
|
||||
else:
|
||||
assistant_message = AssistantPromptMessage(content="")
|
||||
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
|
||||
for unit in agent_scratchpad:
|
||||
if unit.is_final():
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Final Answer: {unit.agent_response}"
|
||||
else:
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Thought: {unit.thought}\n\n"
|
||||
if unit.action_str:
|
||||
assistant_message.content += f"Action: {unit.action_str}\n\n"
|
||||
if unit.observation:
|
||||
assistant_message.content += f"Observation: {unit.observation}\n\n"
|
||||
|
||||
assistant_messages = [assistant_message]
|
||||
|
||||
# query messages
|
||||
query_messages = self._organize_user_query(self._query, [])
|
||||
|
||||
if assistant_messages:
|
||||
# organize historic prompt messages
|
||||
historic_messages = self._organize_historic_prompt_messages(
|
||||
[system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
|
||||
)
|
||||
messages = [
|
||||
system_message,
|
||||
*historic_messages,
|
||||
*query_messages,
|
||||
*assistant_messages,
|
||||
UserPromptMessage(content="continue"),
|
||||
]
|
||||
else:
|
||||
# organize historic prompt messages
|
||||
historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
|
||||
messages = [system_message, *historic_messages, *query_messages]
|
||||
|
||||
# join all messages
|
||||
return messages
|
||||
@ -1,87 +0,0 @@
|
||||
import json
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
|
||||
class CotCompletionAgentRunner(CotAgentRunner):
|
||||
def _organize_instruction_prompt(self) -> str:
|
||||
"""
|
||||
Organize instruction prompt
|
||||
"""
|
||||
if self.app_config.agent is None:
|
||||
raise ValueError("Agent configuration is not set")
|
||||
prompt_entity = self.app_config.agent.prompt
|
||||
if prompt_entity is None:
|
||||
raise ValueError("prompt entity is not set")
|
||||
first_prompt = prompt_entity.first_prompt
|
||||
|
||||
system_prompt = (
|
||||
first_prompt.replace("{{instruction}}", self._instruction)
|
||||
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
|
||||
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
|
||||
)
|
||||
|
||||
return system_prompt
|
||||
|
||||
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] | None = None) -> str:
|
||||
"""
|
||||
Organize historic prompt
|
||||
"""
|
||||
historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
|
||||
historic_prompt = ""
|
||||
|
||||
for message in historic_prompt_messages:
|
||||
if isinstance(message, UserPromptMessage):
|
||||
historic_prompt += f"Question: {message.content}\n\n"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
if isinstance(message.content, str):
|
||||
historic_prompt += message.content + "\n\n"
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if not isinstance(content, TextPromptMessageContent):
|
||||
continue
|
||||
historic_prompt += content.data
|
||||
|
||||
return historic_prompt
|
||||
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
"""
|
||||
# organize system prompt
|
||||
system_prompt = self._organize_instruction_prompt()
|
||||
|
||||
# organize historic prompt messages
|
||||
historic_prompt = self._organize_historic_prompt()
|
||||
|
||||
# organize current assistant messages
|
||||
agent_scratchpad = self._agent_scratchpad
|
||||
assistant_prompt = ""
|
||||
for unit in agent_scratchpad or []:
|
||||
if unit.is_final():
|
||||
assistant_prompt += f"Final Answer: {unit.agent_response}"
|
||||
else:
|
||||
assistant_prompt += f"Thought: {unit.thought}\n\n"
|
||||
if unit.action_str:
|
||||
assistant_prompt += f"Action: {unit.action_str}\n\n"
|
||||
if unit.observation:
|
||||
assistant_prompt += f"Observation: {unit.observation}\n\n"
|
||||
|
||||
# query messages
|
||||
query_prompt = f"Question: {self._query}"
|
||||
|
||||
# join all messages
|
||||
prompt = (
|
||||
system_prompt.replace("{{historic_messages}}", historic_prompt)
|
||||
.replace("{{agent_scratchpad}}", assistant_prompt)
|
||||
.replace("{{query}}", query_prompt)
|
||||
)
|
||||
|
||||
return [UserPromptMessage(content=prompt)]
|
||||
@ -1,3 +1,5 @@
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
@ -92,3 +94,96 @@ class AgentInvokeMessage(ToolInvokeMessage):
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ExecutionContext(BaseModel):
|
||||
"""Execution context containing trace and audit information.
|
||||
|
||||
This context carries all the IDs and metadata that are not part of
|
||||
the core business logic but needed for tracing, auditing, and
|
||||
correlation purposes.
|
||||
"""
|
||||
|
||||
user_id: str | None = None
|
||||
app_id: str | None = None
|
||||
conversation_id: str | None = None
|
||||
message_id: str | None = None
|
||||
tenant_id: str | None = None
|
||||
|
||||
@classmethod
|
||||
def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext":
|
||||
"""Create a minimal context with only essential fields."""
|
||||
return cls(user_id=user_id)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for passing to legacy code."""
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"app_id": self.app_id,
|
||||
"conversation_id": self.conversation_id,
|
||||
"message_id": self.message_id,
|
||||
"tenant_id": self.tenant_id,
|
||||
}
|
||||
|
||||
def with_updates(self, **kwargs) -> "ExecutionContext":
|
||||
"""Create a new context with updated fields."""
|
||||
data = self.to_dict()
|
||||
data.update(kwargs)
|
||||
|
||||
return ExecutionContext(
|
||||
user_id=data.get("user_id"),
|
||||
app_id=data.get("app_id"),
|
||||
conversation_id=data.get("conversation_id"),
|
||||
message_id=data.get("message_id"),
|
||||
tenant_id=data.get("tenant_id"),
|
||||
)
|
||||
|
||||
|
||||
class AgentLog(BaseModel):
|
||||
"""
|
||||
Agent Log.
|
||||
"""
|
||||
|
||||
class LogType(StrEnum):
|
||||
"""Type of agent log entry."""
|
||||
|
||||
ROUND = "round" # A complete iteration round
|
||||
THOUGHT = "thought" # LLM thinking/reasoning
|
||||
TOOL_CALL = "tool_call" # Tool invocation
|
||||
|
||||
class LogMetadata(StrEnum):
|
||||
STARTED_AT = "started_at"
|
||||
FINISHED_AT = "finished_at"
|
||||
ELAPSED_TIME = "elapsed_time"
|
||||
TOTAL_PRICE = "total_price"
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
PROVIDER = "provider"
|
||||
CURRENCY = "currency"
|
||||
LLM_USAGE = "llm_usage"
|
||||
ICON = "icon"
|
||||
ICON_DARK = "icon_dark"
|
||||
|
||||
class LogStatus(StrEnum):
|
||||
START = "start"
|
||||
ERROR = "error"
|
||||
SUCCESS = "success"
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="The id of the log")
|
||||
label: str = Field(..., description="The label of the log")
|
||||
log_type: LogType = Field(..., description="The type of the log")
|
||||
parent_id: str | None = Field(default=None, description="Leave empty for root log")
|
||||
error: str | None = Field(default=None, description="The error message")
|
||||
status: LogStatus = Field(..., description="The status of the log")
|
||||
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
||||
metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log")
|
||||
|
||||
|
||||
class AgentResult(BaseModel):
|
||||
"""
|
||||
Agent execution result.
|
||||
"""
|
||||
|
||||
text: str = Field(default="", description="The generated text")
|
||||
files: list[Any] = Field(default_factory=list, description="Files produced during execution")
|
||||
usage: Any | None = Field(default=None, description="LLM usage statistics")
|
||||
finish_reason: str | None = Field(default=None, description="Reason for completion")
|
||||
|
||||
@ -188,7 +188,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
),
|
||||
)
|
||||
|
||||
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
|
||||
assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
|
||||
if tool_calls:
|
||||
assistant_message.tool_calls = [
|
||||
AssistantPromptMessage.ToolCall(
|
||||
@ -200,8 +200,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
else:
|
||||
assistant_message.content = response
|
||||
|
||||
self._current_thoughts.append(assistant_message)
|
||||
|
||||
|
||||
55
api/core/agent/patterns/README.md
Normal file
55
api/core/agent/patterns/README.md
Normal file
@ -0,0 +1,55 @@
|
||||
# Agent Patterns
|
||||
|
||||
A unified agent pattern module that powers both Agent V2 workflow nodes and agent applications. Strategies share a common execution contract while adapting to model capabilities and tool availability.
|
||||
|
||||
## Overview
|
||||
|
||||
The module applies a strategy pattern around LLM/tool orchestration. `StrategyFactory` auto-selects the best implementation based on model features or an explicit agent strategy, and each strategy streams logs and usage consistently.
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Dual strategies**
|
||||
- `FunctionCallStrategy`: uses native LLM function/tool calling when the model exposes `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL`.
|
||||
- `ReActStrategy`: ReAct (reasoning + acting) flow driven by `CotAgentOutputParser`, used when function calling is unavailable or explicitly requested.
|
||||
- **Explicit or auto selection**
|
||||
- `StrategyFactory.create_strategy` prefers an explicit `AgentEntity.Strategy` (FUNCTION_CALLING or CHAIN_OF_THOUGHT).
|
||||
- Otherwise it falls back to function calling when tool-call features exist, or ReAct when they do not.
|
||||
- **Unified execution contract**
|
||||
- `AgentPattern.run` yields streaming `AgentLog` entries and `LLMResultChunk` data, returning an `AgentResult` with text, files, usage, and `finish_reason`.
|
||||
- Iterations are configurable and hard-capped at 99 rounds; the last round forces a final answer by withholding tools.
|
||||
- **Tool handling and hooks**
|
||||
- Tools convert to `PromptMessageTool` objects before invocation.
|
||||
- Optional `tool_invoke_hook` lets callers override tool execution (e.g., agent apps) while workflow runs use `ToolEngine.generic_invoke`.
|
||||
- Tool outputs support text, links, JSON, variables, blobs, retriever resources, and file attachments; `target=="self"` files are reloaded into model context, others are returned as outputs.
|
||||
- **File-aware arguments**
|
||||
- Tool args accept `[File: <id>]` or `[Files: <id1, id2>]` placeholders that resolve to `File` objects before invocation, enabling models to reference uploaded files safely.
|
||||
- **ReAct prompt shaping**
|
||||
- System prompts replace `{{instruction}}`, `{{tools}}`, and `{{tool_names}}` placeholders.
|
||||
- Adds `Observation` to stop sequences and appends scratchpad text so the model sees prior Thought/Action/Observation history.
|
||||
- **Observability and accounting**
|
||||
- Standardized `AgentLog` entries for rounds, model thoughts, and tool calls, including usage aggregation (`LLMUsage`) across streaming and non-streaming paths.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
agent/patterns/
|
||||
├── base.py # Shared utilities: logging, usage, tool invocation, file handling
|
||||
├── function_call.py # Native function-calling loop with tool execution
|
||||
├── react.py # ReAct loop with CoT parsing and scratchpad wiring
|
||||
└── strategy_factory.py # Strategy selection by model features or explicit override
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
- For auto-selection:
|
||||
- Call `StrategyFactory.create_strategy(model_features, model_instance, context, tools, files, ...)` and run the returned strategy with prompt messages and model params.
|
||||
- For explicit behavior:
|
||||
- Pass `agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING` to force native calls (falls back to ReAct if unsupported), or `CHAIN_OF_THOUGHT` to force ReAct.
|
||||
- Both strategies stream chunks and logs; collect the generator output until it returns an `AgentResult`.
|
||||
|
||||
## Integration Points
|
||||
|
||||
- **Model runtime**: delegates to `ModelInstance.invoke_llm` for both streaming and non-streaming calls.
|
||||
- **Tool system**: defaults to `ToolEngine.generic_invoke`, with `tool_invoke_hook` for custom callers.
|
||||
- **Files**: flows through `File` objects for tool inputs/outputs and model-context attachments.
|
||||
- **Execution context**: `ExecutionContext` fields (user/app/conversation/message) propagate to tool invocations and logging.
|
||||
19
api/core/agent/patterns/__init__.py
Normal file
19
api/core/agent/patterns/__init__.py
Normal file
@ -0,0 +1,19 @@
|
||||
"""Agent patterns module.
|
||||
|
||||
This module provides different strategies for agent execution:
|
||||
- FunctionCallStrategy: Uses native function/tool calling
|
||||
- ReActStrategy: Uses ReAct (Reasoning + Acting) approach
|
||||
- StrategyFactory: Factory for creating strategies based on model features
|
||||
"""
|
||||
|
||||
from .base import AgentPattern
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
from .strategy_factory import StrategyFactory
|
||||
|
||||
__all__ = [
|
||||
"AgentPattern",
|
||||
"FunctionCallStrategy",
|
||||
"ReActStrategy",
|
||||
"StrategyFactory",
|
||||
]
|
||||
474
api/core/agent/patterns/base.py
Normal file
474
api/core/agent/patterns/base.py
Normal file
@ -0,0 +1,474 @@
|
||||
"""Base class for agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
|
||||
from core.file import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
# Type alias for tool invoke hook
|
||||
# Returns: (response_content, message_file_ids, tool_invoke_meta)
|
||||
ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]]
|
||||
|
||||
|
||||
class AgentPattern(ABC):
|
||||
"""Base class for agent execution strategies."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] = [],
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
):
|
||||
"""Initialize the agent strategy."""
|
||||
self.model_instance = model_instance
|
||||
self.tools = tools
|
||||
self.context = context
|
||||
self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.files: list[File] = files
|
||||
self.tool_invoke_hook = tool_invoke_hook
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the agent strategy."""
|
||||
pass
|
||||
|
||||
def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None:
|
||||
"""Accumulate LLM usage statistics."""
|
||||
if not total_usage.get("usage"):
|
||||
# Create a copy to avoid modifying the original
|
||||
total_usage["usage"] = LLMUsage(
|
||||
prompt_tokens=delta_usage.prompt_tokens,
|
||||
prompt_unit_price=delta_usage.prompt_unit_price,
|
||||
prompt_price_unit=delta_usage.prompt_price_unit,
|
||||
prompt_price=delta_usage.prompt_price,
|
||||
completion_tokens=delta_usage.completion_tokens,
|
||||
completion_unit_price=delta_usage.completion_unit_price,
|
||||
completion_price_unit=delta_usage.completion_price_unit,
|
||||
completion_price=delta_usage.completion_price,
|
||||
total_tokens=delta_usage.total_tokens,
|
||||
total_price=delta_usage.total_price,
|
||||
currency=delta_usage.currency,
|
||||
latency=delta_usage.latency,
|
||||
)
|
||||
else:
|
||||
current: LLMUsage = total_usage["usage"]
|
||||
current.prompt_tokens += delta_usage.prompt_tokens
|
||||
current.completion_tokens += delta_usage.completion_tokens
|
||||
current.total_tokens += delta_usage.total_tokens
|
||||
current.prompt_price += delta_usage.prompt_price
|
||||
current.completion_price += delta_usage.completion_price
|
||||
current.total_price += delta_usage.total_price
|
||||
|
||||
def _extract_content(self, content: Any) -> str:
|
||||
"""Extract text content from message content."""
|
||||
if isinstance(content, list):
|
||||
# Content items are PromptMessageContentUnionTypes
|
||||
text_parts = []
|
||||
for c in content:
|
||||
# Check if it's a TextPromptMessageContent (which has data attribute)
|
||||
if isinstance(c, TextPromptMessageContent):
|
||||
text_parts.append(c.data)
|
||||
return "".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
def _has_tool_calls(self, chunk: LLMResultChunk) -> bool:
|
||||
"""Check if chunk contains tool calls."""
|
||||
# LLMResultChunk always has delta attribute
|
||||
return bool(chunk.delta.message and chunk.delta.message.tool_calls)
|
||||
|
||||
def _has_tool_calls_result(self, result: LLMResult) -> bool:
|
||||
"""Check if result contains tool calls (non-streaming)."""
|
||||
# LLMResult always has message attribute
|
||||
return bool(result.message and result.message.tool_calls)
|
||||
|
||||
def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from streaming chunk."""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
if chunk.delta.message and chunk.delta.message.tool_calls:
|
||||
for tool_call in chunk.delta.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from non-streaming result."""
|
||||
tool_calls = []
|
||||
if result.message and result.message.tool_calls:
|
||||
for tool_call in result.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_text_from_message(self, message: PromptMessage) -> str:
|
||||
"""Extract text content from a prompt message."""
|
||||
# PromptMessage always has content attribute
|
||||
content = message.content
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
# Extract text from content list
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
text_parts.append(item.data)
|
||||
return " ".join(text_parts)
|
||||
return ""
|
||||
|
||||
def _get_tool_metadata(self, tool_instance: Tool) -> dict[AgentLog.LogMetadata, Any]:
|
||||
"""Get metadata for a tool including provider and icon info."""
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
metadata: dict[AgentLog.LogMetadata, Any] = {}
|
||||
if tool_instance.entity and tool_instance.entity.identity:
|
||||
identity = tool_instance.entity.identity
|
||||
if identity.provider:
|
||||
metadata[AgentLog.LogMetadata.PROVIDER] = identity.provider
|
||||
|
||||
# Get icon using ToolManager for proper URL generation
|
||||
tenant_id = self.context.tenant_id
|
||||
if tenant_id and identity.provider:
|
||||
try:
|
||||
provider_type = tool_instance.tool_provider_type()
|
||||
icon = ToolManager.get_tool_icon(tenant_id, provider_type, identity.provider)
|
||||
if isinstance(icon, str):
|
||||
metadata[AgentLog.LogMetadata.ICON] = icon
|
||||
elif isinstance(icon, dict):
|
||||
# Handle icon dict with background/content or light/dark variants
|
||||
metadata[AgentLog.LogMetadata.ICON] = icon
|
||||
except Exception:
|
||||
# Fallback to identity.icon if ToolManager fails
|
||||
if identity.icon:
|
||||
metadata[AgentLog.LogMetadata.ICON] = identity.icon
|
||||
elif identity.icon:
|
||||
metadata[AgentLog.LogMetadata.ICON] = identity.icon
|
||||
return metadata
|
||||
|
||||
def _create_log(
|
||||
self,
|
||||
label: str,
|
||||
log_type: AgentLog.LogType,
|
||||
status: AgentLog.LogStatus,
|
||||
data: dict[str, Any] | None = None,
|
||||
parent_id: str | None = None,
|
||||
extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None,
|
||||
) -> AgentLog:
|
||||
"""Create a new AgentLog with standard metadata."""
|
||||
metadata: dict[AgentLog.LogMetadata, Any] = {
|
||||
AgentLog.LogMetadata.STARTED_AT: time.perf_counter(),
|
||||
}
|
||||
if extra_metadata:
|
||||
metadata.update(extra_metadata)
|
||||
|
||||
return AgentLog(
|
||||
label=label,
|
||||
log_type=log_type,
|
||||
status=status,
|
||||
data=data or {},
|
||||
parent_id=parent_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _finish_log(
|
||||
self,
|
||||
log: AgentLog,
|
||||
data: dict[str, Any] | None = None,
|
||||
usage: LLMUsage | None = None,
|
||||
) -> AgentLog:
|
||||
"""Finish an AgentLog by updating its status and metadata."""
|
||||
log.status = AgentLog.LogStatus.SUCCESS
|
||||
|
||||
if data is not None:
|
||||
log.data = data
|
||||
|
||||
# Calculate elapsed time
|
||||
started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter())
|
||||
finished_at = time.perf_counter()
|
||||
|
||||
# Update metadata
|
||||
log.metadata = {
|
||||
**log.metadata,
|
||||
AgentLog.LogMetadata.FINISHED_AT: finished_at,
|
||||
# Calculate elapsed time in seconds
|
||||
AgentLog.LogMetadata.ELAPSED_TIME: round(finished_at - started_at, 4),
|
||||
}
|
||||
|
||||
# Add usage information if provided
|
||||
if usage:
|
||||
log.metadata.update(
|
||||
{
|
||||
AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price,
|
||||
AgentLog.LogMetadata.CURRENCY: usage.currency,
|
||||
AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens,
|
||||
AgentLog.LogMetadata.LLM_USAGE: usage,
|
||||
}
|
||||
)
|
||||
|
||||
return log
|
||||
|
||||
def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Replace file references in tool arguments with actual File objects.
|
||||
|
||||
Args:
|
||||
tool_args: Dictionary of tool arguments
|
||||
|
||||
Returns:
|
||||
Updated tool arguments with file references replaced
|
||||
"""
|
||||
# Process each argument in the dictionary
|
||||
processed_args: dict[str, Any] = {}
|
||||
for key, value in tool_args.items():
|
||||
processed_args[key] = self._process_file_reference(value)
|
||||
return processed_args
|
||||
|
||||
def _process_file_reference(self, data: Any) -> Any:
|
||||
"""
|
||||
Recursively process data to replace file references.
|
||||
Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...].
|
||||
|
||||
Args:
|
||||
data: The data to process (can be dict, list, str, or other types)
|
||||
|
||||
Returns:
|
||||
Processed data with file references replaced
|
||||
"""
|
||||
single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$")
|
||||
multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$")
|
||||
|
||||
if isinstance(data, dict):
|
||||
# Process dictionary recursively
|
||||
return {key: self._process_file_reference(value) for key, value in data.items()}
|
||||
elif isinstance(data, list):
|
||||
# Process list recursively
|
||||
return [self._process_file_reference(item) for item in data]
|
||||
elif isinstance(data, str):
|
||||
# Check for single file pattern [File: file_id]
|
||||
single_match = single_file_pattern.match(data.strip())
|
||||
if single_match:
|
||||
file_id = single_match.group(1).strip()
|
||||
# Find the file in self.files
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
return file
|
||||
# If file not found, return original value
|
||||
return data
|
||||
|
||||
# Check for multiple files pattern [Files: file_id1, file_id2, ...]
|
||||
multiple_match = multiple_files_pattern.match(data.strip())
|
||||
if multiple_match:
|
||||
file_ids_str = multiple_match.group(1).strip()
|
||||
# Split by comma and strip whitespace
|
||||
file_ids = [fid.strip() for fid in file_ids_str.split(",")]
|
||||
|
||||
# Find all matching files
|
||||
matched_files: list[File] = []
|
||||
for file_id in file_ids:
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
matched_files.append(file)
|
||||
break
|
||||
|
||||
# Return list of files if any were found, otherwise return original
|
||||
return matched_files or data
|
||||
|
||||
return data
|
||||
else:
|
||||
# Return other types as-is
|
||||
return data
|
||||
|
||||
def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk:
|
||||
"""Create a text chunk for streaming."""
|
||||
return LLMResultChunk(
|
||||
model=self.model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
usage=None,
|
||||
),
|
||||
system_fingerprint="",
|
||||
)
|
||||
|
||||
def _invoke_tool(
|
||||
self,
|
||||
tool_instance: Tool,
|
||||
tool_args: dict[str, Any],
|
||||
tool_name: str,
|
||||
) -> tuple[str, list[File], ToolInvokeMeta | None]:
|
||||
"""
|
||||
Invoke a tool and collect its response.
|
||||
|
||||
Args:
|
||||
tool_instance: The tool instance to invoke
|
||||
tool_args: Tool arguments
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
Tuple of (response_content, tool_files, tool_invoke_meta)
|
||||
"""
|
||||
# Process tool_args to replace file references with actual File objects
|
||||
tool_args = self._replace_file_references(tool_args)
|
||||
|
||||
# If a tool invoke hook is set, use it instead of generic_invoke
|
||||
if self.tool_invoke_hook:
|
||||
response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name)
|
||||
# Note: message_file_ids are stored in DB, we don't convert them to File objects here
|
||||
# The caller (AgentAppRunner) handles file publishing
|
||||
return response_content, [], tool_invoke_meta
|
||||
|
||||
# Default: use generic_invoke for workflow scenarios
|
||||
# Import here to avoid circular import
|
||||
from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine
|
||||
|
||||
tool_response = ToolEngine().generic_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_args,
|
||||
user_id=self.context.user_id or "",
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
app_id=self.context.app_id,
|
||||
conversation_id=self.context.conversation_id,
|
||||
message_id=self.context.message_id,
|
||||
)
|
||||
|
||||
# Collect response and files
|
||||
response_content = ""
|
||||
tool_files: list[File] = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
|
||||
response_content += response.message.text
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# Handle link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Link: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
# Handle image URL messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK:
|
||||
# Handle image link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK:
|
||||
# Handle binary file link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
filename = response.meta.get("filename", "file") if response.meta else "file"
|
||||
response_content += f"[File: {filename} - {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
# Handle JSON messages
|
||||
if isinstance(response.message, ToolInvokeMessage.JsonMessage):
|
||||
response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2)
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# Handle blob messages - convert to text representation
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobMessage):
|
||||
mime_type = (
|
||||
response.meta.get("mime_type", "application/octet-stream")
|
||||
if response.meta
|
||||
else "application/octet-stream"
|
||||
)
|
||||
size = len(response.message.blob)
|
||||
response_content += f"[Binary data: {mime_type}, size: {size} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
# Handle variable messages
|
||||
if isinstance(response.message, ToolInvokeMessage.VariableMessage):
|
||||
var_name = response.message.variable_name
|
||||
var_value = response.message.variable_value
|
||||
if isinstance(var_value, str):
|
||||
response_content += var_value
|
||||
else:
|
||||
response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
|
||||
# Handle blob chunk messages - these are parts of a larger blob
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage):
|
||||
response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
|
||||
# Handle retriever resources messages
|
||||
if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage):
|
||||
response_content += response.message.context
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.FILE:
|
||||
# Extract file from meta
|
||||
if response.meta and "file" in response.meta:
|
||||
file = response.meta["file"]
|
||||
if isinstance(file, File):
|
||||
# Check if file is for model or tool output
|
||||
if response.meta.get("target") == "self":
|
||||
# File is for model - add to files for next prompt
|
||||
self.files.append(file)
|
||||
response_content += f"File '{file.filename}' has been loaded into your context."
|
||||
else:
|
||||
# File is tool output
|
||||
tool_files.append(file)
|
||||
|
||||
return response_content, tool_files, None
|
||||
|
||||
def _find_tool_by_name(self, tool_name: str) -> Tool | None:
|
||||
"""Find a tool instance by its name."""
|
||||
for tool in self.tools:
|
||||
if tool.entity.identity.name == tool_name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]:
|
||||
"""Convert tools to prompt message format."""
|
||||
prompt_tools: list[PromptMessageTool] = []
|
||||
for tool in self.tools:
|
||||
prompt_tools.append(tool.to_prompt_message_tool())
|
||||
return prompt_tools
|
||||
|
||||
def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None:
|
||||
"""Initialize usage tracking with empty usage if not set."""
|
||||
if "usage" not in llm_usage or llm_usage["usage"] is None:
|
||||
llm_usage["usage"] = LLMUsage.empty_usage()
|
||||
299
api/core/agent/patterns/function_call.py
Normal file
299
api/core/agent/patterns/function_call.py
Normal file
@ -0,0 +1,299 @@
|
||||
"""Function Call strategy implementation."""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
from .base import AgentPattern
|
||||
|
||||
|
||||
class FunctionCallStrategy(AgentPattern):
|
||||
"""Function Call strategy using model's native tool calling capability."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the function call agent strategy."""
|
||||
# Convert tools to prompt format
|
||||
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
|
||||
|
||||
# Initialize tracking
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
function_call_state: bool = True
|
||||
total_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
|
||||
while function_call_state and iteration_step <= max_iterations:
|
||||
function_call_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
# On last iteration, remove tools to force final answer
|
||||
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=current_tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=self.context.user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log
|
||||
)
|
||||
messages.append(self._create_assistant_message(response_content, tool_calls))
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update final text if no tool calls (this is likely the final answer)
|
||||
if not tool_calls:
|
||||
final_text = response_content
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Process tool calls
|
||||
tool_outputs: dict[str, str] = {}
|
||||
if tool_calls:
|
||||
function_call_state = True
|
||||
# Execute tools
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
tool_response, tool_files, _ = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"llm_result": response_content,
|
||||
"tool_calls": [
|
||||
{"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls
|
||||
]
|
||||
if tool_calls
|
||||
else [],
|
||||
"final_answer": final_text if not function_call_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
return AgentResult(
|
||||
text=final_text,
|
||||
files=output_files,
|
||||
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, LLMUsage | None],
|
||||
start_log: AgentLog,
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[list[tuple[str, str, dict[str, Any]]], str, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract tool calls and content.
|
||||
|
||||
Returns a tuple of (tool_calls, response_content, finish_reason).
|
||||
"""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
response_content: str = ""
|
||||
finish_reason: str | None = None
|
||||
if isinstance(chunks, Generator):
|
||||
# Streaming response
|
||||
for chunk in chunks:
|
||||
# Extract tool calls
|
||||
if self._has_tool_calls(chunk):
|
||||
tool_calls.extend(self._extract_tool_calls(chunk))
|
||||
|
||||
# Extract content
|
||||
if chunk.delta.message and chunk.delta.message.content:
|
||||
response_content += self._extract_content(chunk.delta.message.content)
|
||||
|
||||
# Track usage
|
||||
if chunk.delta.usage:
|
||||
self._accumulate_usage(llm_usage, chunk.delta.usage)
|
||||
|
||||
# Capture finish reason
|
||||
if chunk.delta.finish_reason:
|
||||
finish_reason = chunk.delta.finish_reason
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
# Non-streaming response
|
||||
result: LLMResult = chunks
|
||||
|
||||
if self._has_tool_calls_result(result):
|
||||
tool_calls.extend(self._extract_tool_calls_result(result))
|
||||
|
||||
if result.message and result.message.content:
|
||||
response_content += self._extract_content(result.message.content)
|
||||
|
||||
if result.usage:
|
||||
self._accumulate_usage(llm_usage, result.usage)
|
||||
|
||||
# Convert to streaming format
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
||||
)
|
||||
yield self._finish_log(
|
||||
start_log,
|
||||
data={
|
||||
"result": response_content,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
return tool_calls, response_content, finish_reason
|
||||
|
||||
def _create_assistant_message(
|
||||
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
|
||||
) -> AssistantPromptMessage:
|
||||
"""Create assistant message with tool calls."""
|
||||
if tool_calls is None:
|
||||
return AssistantPromptMessage(content=content)
|
||||
return AssistantPromptMessage(
|
||||
content=content or "",
|
||||
tool_calls=[
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tc[0],
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])),
|
||||
)
|
||||
for tc in tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any],
|
||||
tool_call_id: str,
|
||||
messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None]]:
|
||||
"""Handle a single tool call and return response with files and meta."""
|
||||
# Find tool
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
if not tool_instance:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
|
||||
# Get tool metadata (provider, icon, etc.)
|
||||
tool_metadata = self._get_tool_metadata(tool_instance)
|
||||
|
||||
# Create tool call log
|
||||
tool_call_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata=tool_metadata,
|
||||
)
|
||||
yield tool_call_log
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
|
||||
|
||||
yield self._finish_log(
|
||||
tool_call_log,
|
||||
data={
|
||||
**tool_call_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
final_content = response_content or "Tool executed successfully"
|
||||
# Add tool response to messages
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=final_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return response_content, tool_files, tool_invoke_meta
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_call_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_call_log.error = error_message
|
||||
tool_call_log.data = {
|
||||
**tool_call_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_call_log
|
||||
|
||||
# Add error message to conversation
|
||||
error_content = f"Tool execution failed: {error_message}"
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=error_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return error_content, [], None
|
||||
418
api/core/agent/patterns/react.py
Normal file
418
api/core/agent/patterns/react.py
Normal file
@ -0,0 +1,418 @@
|
||||
"""ReAct strategy implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.file import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
)
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class ReActStrategy(AgentPattern):
|
||||
"""ReAct strategy using reasoning and acting approach."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] = [],
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
):
|
||||
"""Initialize the ReAct strategy with instruction support."""
|
||||
super().__init__(
|
||||
model_instance=model_instance,
|
||||
tools=tools,
|
||||
context=context,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
files=files,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
self.instruction = instruction
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the ReAct agent strategy."""
|
||||
# Initialize tracking
|
||||
agent_scratchpad: list[AgentScratchpadUnit] = []
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
react_state: bool = True
|
||||
total_usage: dict[str, Any] = {"usage": None}
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Add "Observation" to stop sequences
|
||||
if "Observation" not in stop:
|
||||
stop = stop.copy()
|
||||
stop.append("Observation")
|
||||
|
||||
while react_state and iteration_step <= max_iterations:
|
||||
react_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
|
||||
# Build prompt with/without tools based on iteration
|
||||
include_tools = iteration_step < max_iterations
|
||||
current_messages = self._build_prompt_with_react_format(
|
||||
prompt_messages, agent_scratchpad, include_tools, self.instruction
|
||||
)
|
||||
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, Any] = {"usage": None}
|
||||
|
||||
# Use current messages directly (files are handled by base class if needed)
|
||||
messages_to_use = current_messages
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=self.context.user_id or "",
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log, current_messages
|
||||
)
|
||||
agent_scratchpad.append(scratchpad)
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Check if we have an action to execute
|
||||
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
||||
react_state = True
|
||||
# Execute tool
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
# Add observation to scratchpad for display
|
||||
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
|
||||
else:
|
||||
# Extract final answer
|
||||
if scratchpad.action and scratchpad.action.action_input:
|
||||
final_answer = scratchpad.action.action_input
|
||||
if isinstance(final_answer, dict):
|
||||
final_answer = json.dumps(final_answer, ensure_ascii=False)
|
||||
final_text = str(final_answer)
|
||||
elif scratchpad.thought:
|
||||
# If no action but we have thought, use thought as final answer
|
||||
final_text = scratchpad.thought
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
"observation": scratchpad.observation or None,
|
||||
"final_answer": final_text if not react_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
return AgentResult(
|
||||
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
|
||||
)
|
||||
|
||||
def _build_prompt_with_react_format(
|
||||
self,
|
||||
original_messages: list[PromptMessage],
|
||||
agent_scratchpad: list[AgentScratchpadUnit],
|
||||
include_tools: bool = True,
|
||||
instruction: str = "",
|
||||
) -> list[PromptMessage]:
|
||||
"""Build prompt messages with ReAct format."""
|
||||
# Copy messages to avoid modifying original
|
||||
messages = list(original_messages)
|
||||
|
||||
# Find and update the system prompt that should already exist
|
||||
system_prompt_found = False
|
||||
for i, msg in enumerate(messages):
|
||||
if isinstance(msg, SystemPromptMessage):
|
||||
system_prompt_found = True
|
||||
# The system prompt from frontend already has the template, just replace placeholders
|
||||
|
||||
# Format tools
|
||||
tools_str = ""
|
||||
tool_names = []
|
||||
if include_tools and self.tools:
|
||||
# Convert tools to prompt message tools format
|
||||
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
|
||||
tool_names = [tool.name for tool in prompt_tools]
|
||||
|
||||
# Format tools as JSON for comprehensive information
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2)
|
||||
tool_names_str = ", ".join(f'"{name}"' for name in tool_names)
|
||||
else:
|
||||
tools_str = "No tools available"
|
||||
tool_names_str = ""
|
||||
|
||||
# Replace placeholders in the existing system prompt
|
||||
updated_content = msg.content
|
||||
assert isinstance(updated_content, str)
|
||||
updated_content = updated_content.replace("{{instruction}}", instruction)
|
||||
updated_content = updated_content.replace("{{tools}}", tools_str)
|
||||
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
|
||||
|
||||
# Create new SystemPromptMessage with updated content
|
||||
messages[i] = SystemPromptMessage(content=updated_content)
|
||||
break
|
||||
|
||||
# If no system prompt found, that's unexpected but add scratchpad anyway
|
||||
if not system_prompt_found:
|
||||
# This shouldn't happen if frontend is working correctly
|
||||
pass
|
||||
|
||||
# Format agent scratchpad
|
||||
scratchpad_str = ""
|
||||
if agent_scratchpad:
|
||||
scratchpad_parts: list[str] = []
|
||||
for unit in agent_scratchpad:
|
||||
if unit.thought:
|
||||
scratchpad_parts.append(f"Thought: {unit.thought}")
|
||||
if unit.action_str:
|
||||
scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```")
|
||||
if unit.observation:
|
||||
scratchpad_parts.append(f"Observation: {unit.observation}")
|
||||
scratchpad_str = "\n".join(scratchpad_parts)
|
||||
|
||||
# If there's a scratchpad, append it to the last message
|
||||
if scratchpad_str:
|
||||
messages.append(AssistantPromptMessage(content=scratchpad_str))
|
||||
|
||||
return messages
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, Any],
|
||||
model_log: AgentLog,
|
||||
current_messages: list[PromptMessage],
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[AgentScratchpadUnit, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract action/thought.
|
||||
|
||||
Returns a tuple of (scratchpad_unit, finish_reason).
|
||||
"""
|
||||
usage_dict: dict[str, Any] = {}
|
||||
|
||||
# Convert non-streaming to streaming format if needed
|
||||
if isinstance(chunks, LLMResult):
|
||||
# Create a generator from the LLMResult
|
||||
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
model=chunks.model,
|
||||
prompt_messages=chunks.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=chunks.message,
|
||||
usage=chunks.usage,
|
||||
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
|
||||
),
|
||||
system_fingerprint=chunks.system_fingerprint or "",
|
||||
)
|
||||
|
||||
streaming_chunks = result_to_chunks()
|
||||
else:
|
||||
streaming_chunks = chunks
|
||||
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
|
||||
|
||||
# Initialize scratchpad unit
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
thought="",
|
||||
action_str="",
|
||||
observation="",
|
||||
action=None,
|
||||
)
|
||||
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Process chunks
|
||||
for chunk in react_chunks:
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
# Action detected
|
||||
action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + action_str
|
||||
scratchpad.action_str = action_str
|
||||
scratchpad.action = chunk
|
||||
|
||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||
else:
|
||||
# Text chunk
|
||||
chunk_text = str(chunk)
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
|
||||
scratchpad.thought = (scratchpad.thought or "") + chunk_text
|
||||
|
||||
yield self._create_text_chunk(chunk_text, current_messages)
|
||||
|
||||
# Update usage
|
||||
if usage_dict.get("usage"):
|
||||
if llm_usage.get("usage"):
|
||||
self._accumulate_usage(llm_usage, usage_dict["usage"])
|
||||
else:
|
||||
llm_usage["usage"] = usage_dict["usage"]
|
||||
|
||||
# Clean up thought
|
||||
scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you"
|
||||
|
||||
# Finish model log
|
||||
yield self._finish_log(
|
||||
model_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
|
||||
return scratchpad, finish_reason
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
prompt_messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File]]]:
|
||||
"""Handle tool call and return observation with files."""
|
||||
tool_name = action.action_name
|
||||
tool_args: dict[str, Any] | str = action.action_input
|
||||
|
||||
# Find tool instance first to get metadata
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
tool_metadata = self._get_tool_metadata(tool_instance) if tool_instance else {}
|
||||
|
||||
# Start tool log with tool metadata
|
||||
tool_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata=tool_metadata,
|
||||
)
|
||||
yield tool_log
|
||||
|
||||
if not tool_instance:
|
||||
# Finish tool log with error
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"error": f"Tool {tool_name} not found",
|
||||
},
|
||||
)
|
||||
return f"Tool {tool_name} not found", []
|
||||
|
||||
# Ensure tool_args is a dict
|
||||
tool_args_dict: dict[str, Any]
|
||||
if isinstance(tool_args, str):
|
||||
try:
|
||||
tool_args_dict = json.loads(tool_args)
|
||||
except json.JSONDecodeError:
|
||||
tool_args_dict = {"input": tool_args}
|
||||
elif not isinstance(tool_args, dict):
|
||||
tool_args_dict = {"input": str(tool_args)}
|
||||
else:
|
||||
tool_args_dict = tool_args
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
|
||||
|
||||
# Finish tool log
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
|
||||
return response_content or "Tool executed successfully", tool_files
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_log.error = error_message
|
||||
tool_log.data = {
|
||||
**tool_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_log
|
||||
|
||||
return f"Tool execution failed: {error_message}", []
|
||||
107
api/core/agent/patterns/strategy_factory.py
Normal file
107
api/core/agent/patterns/strategy_factory.py
Normal file
@ -0,0 +1,107 @@
|
||||
"""Strategy factory for creating agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.agent.entities import AgentEntity, ExecutionContext
|
||||
from core.file.models import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class StrategyFactory:
|
||||
"""Factory for creating agent strategies based on model features."""
|
||||
|
||||
# Tool calling related features
|
||||
TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
|
||||
|
||||
@staticmethod
|
||||
def create_strategy(
|
||||
model_features: list[ModelFeature],
|
||||
model_instance: ModelInstance,
|
||||
context: ExecutionContext,
|
||||
tools: list[Tool],
|
||||
files: list[File],
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
agent_strategy: AgentEntity.Strategy | None = None,
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
) -> AgentPattern:
|
||||
"""
|
||||
Create an appropriate strategy based on model features.
|
||||
|
||||
Args:
|
||||
model_features: List of model features/capabilities
|
||||
model_instance: Model instance to use
|
||||
context: Execution context containing trace/audit information
|
||||
tools: Available tools
|
||||
files: Available files
|
||||
max_iterations: Maximum iterations for the strategy
|
||||
workflow_call_depth: Depth of workflow calls
|
||||
agent_strategy: Optional explicit strategy override
|
||||
tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke)
|
||||
instruction: Optional instruction for ReAct strategy
|
||||
|
||||
Returns:
|
||||
AgentStrategy instance
|
||||
"""
|
||||
# If explicit strategy is provided and it's Function Calling, try to use it if supported
|
||||
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
# Fallback to ReAct if FC is requested but not supported
|
||||
|
||||
# If explicit strategy is Chain of Thought (ReAct)
|
||||
if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
# Default auto-selection logic
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
# Model supports native function calling
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
else:
|
||||
# Use ReAct strategy for models without function calling
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
@ -1,4 +1,3 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Literal
|
||||
@ -121,7 +120,7 @@ class VariableEntity(BaseModel):
|
||||
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
|
||||
json_schema: str | None = Field(default=None)
|
||||
json_schema: dict | None = Field(default=None)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
@ -135,17 +134,11 @@ class VariableEntity(BaseModel):
|
||||
|
||||
@field_validator("json_schema")
|
||||
@classmethod
|
||||
def validate_json_schema(cls, schema: str | None) -> str | None:
|
||||
def validate_json_schema(cls, schema: dict | None) -> dict | None:
|
||||
if schema is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
json_schema = json.loads(schema)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"invalid json_schema value {schema}")
|
||||
|
||||
try:
|
||||
Draft7Validator.check_schema(json_schema)
|
||||
Draft7Validator.check_schema(schema)
|
||||
except SchemaError as e:
|
||||
raise ValueError(f"Invalid JSON schema: {e.message}")
|
||||
return schema
|
||||
|
||||
@ -26,7 +26,6 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
||||
@classmethod
|
||||
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
|
||||
features_dict = workflow.features_dict
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
app_config = AdvancedChatAppConfig(
|
||||
tenant_id=app_model.tenant_id,
|
||||
|
||||
@ -4,6 +4,7 @@ import re
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Thread
|
||||
from typing import Any, Union
|
||||
|
||||
@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
ChunkType,
|
||||
MessageQueueMessage,
|
||||
QueueAdvancedChatMessageEndEvent,
|
||||
QueueAgentLogEvent,
|
||||
@ -70,13 +72,122 @@ from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models import Account, Conversation, EndUser, LLMGenerationDetail, Message, MessageFile
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamEventBuffer:
|
||||
"""
|
||||
Buffer for recording stream events in order to reconstruct the generation sequence.
|
||||
Records the exact order of text chunks, thoughts, and tool calls as they stream.
|
||||
"""
|
||||
|
||||
# Accumulated reasoning content (each thought block is a separate element)
|
||||
reasoning_content: list[str] = field(default_factory=list)
|
||||
# Current reasoning buffer (accumulates until we see a different event type)
|
||||
_current_reasoning: str = ""
|
||||
# Tool calls with their details
|
||||
tool_calls: list[dict] = field(default_factory=list)
|
||||
# Tool call ID to index mapping for updating results
|
||||
_tool_call_id_map: dict[str, int] = field(default_factory=dict)
|
||||
# Sequence of events in stream order
|
||||
sequence: list[dict] = field(default_factory=list)
|
||||
# Current position in answer text
|
||||
_content_position: int = 0
|
||||
# Track last event type to detect transitions
|
||||
_last_event_type: str | None = None
|
||||
|
||||
def _flush_current_reasoning(self) -> None:
|
||||
"""Flush accumulated reasoning to the list and add to sequence."""
|
||||
if self._current_reasoning.strip():
|
||||
self.reasoning_content.append(self._current_reasoning.strip())
|
||||
self.sequence.append({"type": "reasoning", "index": len(self.reasoning_content) - 1})
|
||||
self._current_reasoning = ""
|
||||
|
||||
def record_text_chunk(self, text: str) -> None:
|
||||
"""Record a text chunk event."""
|
||||
if not text:
|
||||
return
|
||||
|
||||
# Flush any pending reasoning first
|
||||
if self._last_event_type == "thought":
|
||||
self._flush_current_reasoning()
|
||||
|
||||
text_len = len(text)
|
||||
start_pos = self._content_position
|
||||
|
||||
# If last event was also content, extend it; otherwise create new
|
||||
if self.sequence and self.sequence[-1].get("type") == "content":
|
||||
self.sequence[-1]["end"] = start_pos + text_len
|
||||
else:
|
||||
self.sequence.append({"type": "content", "start": start_pos, "end": start_pos + text_len})
|
||||
|
||||
self._content_position += text_len
|
||||
self._last_event_type = "content"
|
||||
|
||||
def record_thought_chunk(self, text: str) -> None:
|
||||
"""Record a thought/reasoning chunk event."""
|
||||
if not text:
|
||||
return
|
||||
|
||||
# Accumulate thought content
|
||||
self._current_reasoning += text
|
||||
self._last_event_type = "thought"
|
||||
|
||||
def record_tool_call(self, tool_call_id: str, tool_name: str, tool_arguments: str) -> None:
|
||||
"""Record a tool call event."""
|
||||
if not tool_call_id:
|
||||
return
|
||||
|
||||
# Flush any pending reasoning first
|
||||
if self._last_event_type == "thought":
|
||||
self._flush_current_reasoning()
|
||||
|
||||
# Check if this tool call already exists (we might get multiple chunks)
|
||||
if tool_call_id in self._tool_call_id_map:
|
||||
idx = self._tool_call_id_map[tool_call_id]
|
||||
# Update arguments if provided
|
||||
if tool_arguments:
|
||||
self.tool_calls[idx]["arguments"] = tool_arguments
|
||||
else:
|
||||
# New tool call
|
||||
tool_call = {
|
||||
"id": tool_call_id or "",
|
||||
"name": tool_name or "",
|
||||
"arguments": tool_arguments or "",
|
||||
"result": "",
|
||||
"elapsed_time": None,
|
||||
}
|
||||
self.tool_calls.append(tool_call)
|
||||
idx = len(self.tool_calls) - 1
|
||||
self._tool_call_id_map[tool_call_id] = idx
|
||||
self.sequence.append({"type": "tool_call", "index": idx})
|
||||
|
||||
self._last_event_type = "tool_call"
|
||||
|
||||
def record_tool_result(self, tool_call_id: str, result: str, tool_elapsed_time: float | None = None) -> None:
|
||||
"""Record a tool result event (update existing tool call)."""
|
||||
if not tool_call_id:
|
||||
return
|
||||
if tool_call_id in self._tool_call_id_map:
|
||||
idx = self._tool_call_id_map[tool_call_id]
|
||||
self.tool_calls[idx]["result"] = result
|
||||
self.tool_calls[idx]["elapsed_time"] = tool_elapsed_time
|
||||
|
||||
def finalize(self) -> None:
|
||||
"""Finalize the buffer, flushing any pending data."""
|
||||
if self._last_event_type == "thought":
|
||||
self._flush_current_reasoning()
|
||||
|
||||
def has_data(self) -> bool:
|
||||
"""Check if there's any meaningful data recorded."""
|
||||
return bool(self.reasoning_content or self.tool_calls or self.sequence)
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
"""
|
||||
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
@ -144,6 +255,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self._workflow_run_id: str = ""
|
||||
self._draft_var_saver_factory = draft_var_saver_factory
|
||||
self._graph_runtime_state: GraphRuntimeState | None = None
|
||||
# Stream event buffer for recording generation sequence
|
||||
self._stream_buffer = StreamEventBuffer()
|
||||
self._seed_graph_runtime_state_from_queue_manager()
|
||||
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
@ -358,25 +471,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if node_finish_resp:
|
||||
yield node_finish_resp
|
||||
|
||||
# For ANSWER nodes, check if we need to send a message_replace event
|
||||
# Only send if the final output differs from the accumulated task_state.answer
|
||||
# This happens when variables were updated by variable_assigner during workflow execution
|
||||
if event.node_type == NodeType.ANSWER and event.outputs:
|
||||
final_answer = event.outputs.get("answer")
|
||||
if final_answer is not None and final_answer != self._task_state.answer:
|
||||
logger.info(
|
||||
"ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event",
|
||||
final_answer,
|
||||
self._task_state.answer,
|
||||
)
|
||||
# Update the task state answer
|
||||
self._task_state.answer = str(final_answer)
|
||||
# Send message_replace event to update the UI
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||
answer=str(final_answer),
|
||||
reason="variable_update",
|
||||
)
|
||||
|
||||
def _handle_node_failed_events(
|
||||
self,
|
||||
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
|
||||
@ -402,7 +496,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle text chunk events."""
|
||||
"""Handle text chunk events and record to stream buffer for sequence reconstruction."""
|
||||
delta_text = event.text
|
||||
if delta_text is None:
|
||||
return
|
||||
@ -424,9 +518,52 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if tts_publisher and queue_message:
|
||||
tts_publisher.publish(queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
tool_call = event.tool_call
|
||||
tool_result = event.tool_result
|
||||
tool_payload = tool_call or tool_result
|
||||
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else ""
|
||||
tool_name = tool_payload.name if tool_payload and tool_payload.name else ""
|
||||
tool_arguments = tool_call.arguments if tool_call and tool_call.arguments else ""
|
||||
tool_files = tool_result.files if tool_result else []
|
||||
tool_elapsed_time = tool_result.elapsed_time if tool_result else None
|
||||
tool_icon = tool_payload.icon if tool_payload else None
|
||||
tool_icon_dark = tool_payload.icon_dark if tool_payload else None
|
||||
# Record stream event based on chunk type
|
||||
chunk_type = event.chunk_type or ChunkType.TEXT
|
||||
match chunk_type:
|
||||
case ChunkType.TEXT:
|
||||
self._stream_buffer.record_text_chunk(delta_text)
|
||||
self._task_state.answer += delta_text
|
||||
case ChunkType.THOUGHT:
|
||||
# Reasoning should not be part of final answer text
|
||||
self._stream_buffer.record_thought_chunk(delta_text)
|
||||
case ChunkType.TOOL_CALL:
|
||||
self._stream_buffer.record_tool_call(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
tool_arguments=tool_arguments,
|
||||
)
|
||||
case ChunkType.TOOL_RESULT:
|
||||
self._stream_buffer.record_tool_result(
|
||||
tool_call_id=tool_call_id,
|
||||
result=delta_text,
|
||||
tool_elapsed_time=tool_elapsed_time,
|
||||
)
|
||||
self._task_state.answer += delta_text
|
||||
case _:
|
||||
pass
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||
answer=delta_text,
|
||||
message_id=self._message_id,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
chunk_type=event.chunk_type.value if event.chunk_type else None,
|
||||
tool_call_id=tool_call_id or None,
|
||||
tool_name=tool_name or None,
|
||||
tool_arguments=tool_arguments or None,
|
||||
tool_files=tool_files,
|
||||
tool_elapsed_time=tool_elapsed_time,
|
||||
tool_icon=tool_icon,
|
||||
tool_icon_dark=tool_icon_dark,
|
||||
)
|
||||
|
||||
def _handle_iteration_start_event(
|
||||
@ -794,6 +931,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
|
||||
# If there are assistant files, remove markdown image links from answer
|
||||
answer_text = self._task_state.answer
|
||||
answer_text = self._strip_think_blocks(answer_text)
|
||||
if self._recorded_files:
|
||||
# Remove markdown image links since we're storing files separately
|
||||
answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip()
|
||||
@ -845,6 +983,54 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
]
|
||||
session.add_all(message_files)
|
||||
|
||||
# Save generation detail (reasoning/tool calls/sequence) from stream buffer
|
||||
self._save_generation_detail(session=session, message=message)
|
||||
|
||||
@staticmethod
|
||||
def _strip_think_blocks(text: str) -> str:
|
||||
"""Remove <think>...</think> blocks (including their content) from text."""
|
||||
if not text or "<think" not in text.lower():
|
||||
return text
|
||||
|
||||
clean_text = re.sub(r"<think[^>]*>.*?</think>", "", text, flags=re.IGNORECASE | re.DOTALL)
|
||||
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
|
||||
return clean_text
|
||||
|
||||
def _save_generation_detail(self, *, session: Session, message: Message) -> None:
|
||||
"""
|
||||
Save LLM generation detail for Chatflow using stream event buffer.
|
||||
The buffer records the exact order of events as they streamed,
|
||||
allowing accurate reconstruction of the generation sequence.
|
||||
"""
|
||||
# Finalize the stream buffer to flush any pending data
|
||||
self._stream_buffer.finalize()
|
||||
|
||||
# Only save if there's meaningful data
|
||||
if not self._stream_buffer.has_data():
|
||||
return
|
||||
|
||||
reasoning_content = self._stream_buffer.reasoning_content
|
||||
tool_calls = self._stream_buffer.tool_calls
|
||||
sequence = self._stream_buffer.sequence
|
||||
|
||||
# Check if generation detail already exists for this message
|
||||
existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first()
|
||||
|
||||
if existing:
|
||||
existing.reasoning_content = json.dumps(reasoning_content) if reasoning_content else None
|
||||
existing.tool_calls = json.dumps(tool_calls) if tool_calls else None
|
||||
existing.sequence = json.dumps(sequence) if sequence else None
|
||||
else:
|
||||
generation_detail = LLMGenerationDetail(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
message_id=message.id,
|
||||
reasoning_content=json.dumps(reasoning_content) if reasoning_content else None,
|
||||
tool_calls=json.dumps(tool_calls) if tool_calls else None,
|
||||
sequence=json.dumps(sequence) if sequence else None,
|
||||
)
|
||||
session.add(generation_detail)
|
||||
|
||||
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
|
||||
"""Bootstrap the cached runtime state from the queue manager when present."""
|
||||
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
|
||||
|
||||
@ -3,10 +3,8 @@ from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
from core.agent.entities import AgentEntity
|
||||
from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
@ -14,8 +12,7 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.moderation.base import ModerationError
|
||||
from extensions.ext_database import db
|
||||
@ -194,22 +191,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
raise ValueError("Message not found")
|
||||
db.session.close()
|
||||
|
||||
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
|
||||
# start agent runner
|
||||
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
# check LLM mode
|
||||
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
|
||||
runner_cls = CotChatAgentRunner
|
||||
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
|
||||
runner_cls = CotCompletionAgentRunner
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
|
||||
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
runner_cls = FunctionCallAgentRunner
|
||||
else:
|
||||
raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")
|
||||
|
||||
runner = runner_cls(
|
||||
runner = AgentAppRunner(
|
||||
tenant_id=app_config.tenant_id,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation_result,
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Union, final
|
||||
|
||||
@ -76,12 +75,24 @@ class BaseAppGenerator:
|
||||
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
|
||||
|
||||
# Check if all files are converted to File
|
||||
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
|
||||
raise ValueError("Invalid input type")
|
||||
if any(
|
||||
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
|
||||
):
|
||||
raise ValueError("Invalid input type")
|
||||
invalid_dict_keys = [
|
||||
k
|
||||
for k, v in user_inputs.items()
|
||||
if isinstance(v, dict)
|
||||
and entity_dictionary[k].type not in {VariableEntityType.FILE, VariableEntityType.JSON_OBJECT}
|
||||
]
|
||||
if invalid_dict_keys:
|
||||
raise ValueError(f"Invalid input type for {invalid_dict_keys}")
|
||||
|
||||
invalid_list_dict_keys = [
|
||||
k
|
||||
for k, v in user_inputs.items()
|
||||
if isinstance(v, list)
|
||||
and any(isinstance(item, dict) for item in v)
|
||||
and entity_dictionary[k].type != VariableEntityType.FILE_LIST
|
||||
]
|
||||
if invalid_list_dict_keys:
|
||||
raise ValueError(f"Invalid input type for {invalid_list_dict_keys}")
|
||||
|
||||
return user_inputs
|
||||
|
||||
@ -178,12 +189,8 @@ class BaseAppGenerator:
|
||||
elif value == 0:
|
||||
value = False
|
||||
case VariableEntityType.JSON_OBJECT:
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a string")
|
||||
try:
|
||||
json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
|
||||
if value and not isinstance(value, dict):
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a dict")
|
||||
case _:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
|
||||
@ -671,7 +671,7 @@ class WorkflowResponseConverter:
|
||||
task_id=task_id,
|
||||
data=AgentLogStreamResponse.Data(
|
||||
node_execution_id=event.node_execution_id,
|
||||
id=event.id,
|
||||
message_id=event.id,
|
||||
parent_id=event.parent_id,
|
||||
label=event.label,
|
||||
error=event.error,
|
||||
|
||||
@ -13,6 +13,7 @@ from core.app.apps.common.workflow_response_converter import WorkflowResponseCon
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
ChunkType,
|
||||
MessageQueueMessage,
|
||||
QueueAgentLogEvent,
|
||||
QueueErrorEvent,
|
||||
@ -483,11 +484,33 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if delta_text is None:
|
||||
return
|
||||
|
||||
tool_call = event.tool_call
|
||||
tool_result = event.tool_result
|
||||
tool_payload = tool_call or tool_result
|
||||
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else None
|
||||
tool_name = tool_payload.name if tool_payload and tool_payload.name else None
|
||||
tool_arguments = tool_call.arguments if tool_call else None
|
||||
tool_elapsed_time = tool_result.elapsed_time if tool_result else None
|
||||
tool_files = tool_result.files if tool_result else []
|
||||
tool_icon = tool_payload.icon if tool_payload else None
|
||||
tool_icon_dark = tool_payload.icon_dark if tool_payload else None
|
||||
|
||||
# only publish tts message at text chunk streaming
|
||||
if tts_publisher and queue_message:
|
||||
tts_publisher.publish(queue_message)
|
||||
|
||||
yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector)
|
||||
yield self._text_chunk_to_stream_response(
|
||||
text=delta_text,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
chunk_type=event.chunk_type,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
tool_arguments=tool_arguments,
|
||||
tool_files=tool_files,
|
||||
tool_elapsed_time=tool_elapsed_time,
|
||||
tool_icon=tool_icon,
|
||||
tool_icon_dark=tool_icon_dark,
|
||||
)
|
||||
|
||||
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle agent log events."""
|
||||
@ -650,16 +673,61 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
session.add(workflow_app_log)
|
||||
|
||||
def _text_chunk_to_stream_response(
|
||||
self, text: str, from_variable_selector: list[str] | None = None
|
||||
self,
|
||||
text: str,
|
||||
from_variable_selector: list[str] | None = None,
|
||||
chunk_type: ChunkType | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
tool_arguments: str | None = None,
|
||||
tool_files: list[str] | None = None,
|
||||
tool_error: str | None = None,
|
||||
tool_elapsed_time: float | None = None,
|
||||
tool_icon: str | dict | None = None,
|
||||
tool_icon_dark: str | dict | None = None,
|
||||
) -> TextChunkStreamResponse:
|
||||
"""
|
||||
Handle completed event.
|
||||
:param text: text
|
||||
:return:
|
||||
"""
|
||||
from core.app.entities.task_entities import ChunkType as ResponseChunkType
|
||||
|
||||
response_chunk_type = ResponseChunkType(chunk_type.value) if chunk_type else ResponseChunkType.TEXT
|
||||
|
||||
data = TextChunkStreamResponse.Data(
|
||||
text=text,
|
||||
from_variable_selector=from_variable_selector,
|
||||
chunk_type=response_chunk_type,
|
||||
)
|
||||
|
||||
if response_chunk_type == ResponseChunkType.TOOL_CALL:
|
||||
data = data.model_copy(
|
||||
update={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_arguments": tool_arguments,
|
||||
"tool_icon": tool_icon,
|
||||
"tool_icon_dark": tool_icon_dark,
|
||||
}
|
||||
)
|
||||
elif response_chunk_type == ResponseChunkType.TOOL_RESULT:
|
||||
data = data.model_copy(
|
||||
update={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_arguments": tool_arguments,
|
||||
"tool_files": tool_files,
|
||||
"tool_error": tool_error,
|
||||
"tool_elapsed_time": tool_elapsed_time,
|
||||
"tool_icon": tool_icon,
|
||||
"tool_icon_dark": tool_icon_dark,
|
||||
}
|
||||
)
|
||||
|
||||
response = TextChunkStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
|
||||
data=data,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@ -463,12 +463,20 @@ class WorkflowBasedAppRunner:
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
from core.app.entities.queue_entities import ChunkType as QueueChunkType
|
||||
|
||||
if event.is_final and not event.chunk:
|
||||
return
|
||||
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
text=event.chunk,
|
||||
from_variable_selector=list(event.selector),
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
chunk_type=QueueChunkType(event.chunk_type.value),
|
||||
tool_call=event.tool_call,
|
||||
tool_result=event.tool_result,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetrieverResourceEvent):
|
||||
|
||||
@ -9,7 +9,6 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAp
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from core.file import File, FileUploadConfig
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
@ -81,11 +80,6 @@ class InvokeFrom(StrEnum):
|
||||
|
||||
return "dev"
|
||||
|
||||
def to_creator_user_role(self) -> CreatorUserRole:
|
||||
if self in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}:
|
||||
return CreatorUserRole.ACCOUNT
|
||||
return CreatorUserRole.END_USER
|
||||
|
||||
|
||||
class ModelConfigWithCredentialsEntity(BaseModel):
|
||||
"""
|
||||
|
||||
70
api/core/app/entities/llm_generation_entities.py
Normal file
70
api/core/app/entities/llm_generation_entities.py
Normal file
@ -0,0 +1,70 @@
|
||||
"""
|
||||
LLM Generation Detail entities.
|
||||
|
||||
Defines the structure for storing and transmitting LLM generation details
|
||||
including reasoning content, tool calls, and their sequence.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ContentSegment(BaseModel):
|
||||
"""Represents a content segment in the generation sequence."""
|
||||
|
||||
type: Literal["content"] = "content"
|
||||
start: int = Field(..., description="Start position in the text")
|
||||
end: int = Field(..., description="End position in the text")
|
||||
|
||||
|
||||
class ReasoningSegment(BaseModel):
|
||||
"""Represents a reasoning segment in the generation sequence."""
|
||||
|
||||
type: Literal["reasoning"] = "reasoning"
|
||||
index: int = Field(..., description="Index into reasoning_content array")
|
||||
|
||||
|
||||
class ToolCallSegment(BaseModel):
|
||||
"""Represents a tool call segment in the generation sequence."""
|
||||
|
||||
type: Literal["tool_call"] = "tool_call"
|
||||
index: int = Field(..., description="Index into tool_calls array")
|
||||
|
||||
|
||||
SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment
|
||||
|
||||
|
||||
class ToolCallDetail(BaseModel):
|
||||
"""Represents a tool call with its arguments and result."""
|
||||
|
||||
id: str = Field(default="", description="Unique identifier for the tool call")
|
||||
name: str = Field(..., description="Name of the tool")
|
||||
arguments: str = Field(default="", description="JSON string of tool arguments")
|
||||
result: str = Field(default="", description="Result from the tool execution")
|
||||
elapsed_time: float | None = Field(default=None, description="Elapsed time in seconds")
|
||||
|
||||
|
||||
class LLMGenerationDetailData(BaseModel):
|
||||
"""
|
||||
Domain model for LLM generation detail.
|
||||
|
||||
Contains the structured data for reasoning content, tool calls,
|
||||
and their display sequence.
|
||||
"""
|
||||
|
||||
reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments")
|
||||
tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details")
|
||||
sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments")
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if there's any meaningful generation detail."""
|
||||
return not self.reasoning_content and not self.tool_calls
|
||||
|
||||
def to_response_dict(self) -> dict:
|
||||
"""Convert to dictionary for API response."""
|
||||
return {
|
||||
"reasoning_content": self.reasoning_content,
|
||||
"tool_calls": [tc.model_dump() for tc in self.tool_calls],
|
||||
"sequence": [seg.model_dump() for seg in self.sequence],
|
||||
}
|
||||
@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes import NodeType
|
||||
|
||||
@ -177,6 +177,17 @@ class QueueLoopCompletedEvent(AppQueueEvent):
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ChunkType(StrEnum):
|
||||
"""Stream chunk type for LLM-related events."""
|
||||
|
||||
TEXT = "text" # Normal text streaming
|
||||
TOOL_CALL = "tool_call" # Tool call arguments streaming
|
||||
TOOL_RESULT = "tool_result" # Tool execution result
|
||||
THOUGHT = "thought" # Agent thinking process (ReAct)
|
||||
THOUGHT_START = "thought_start" # Agent thought start
|
||||
THOUGHT_END = "thought_end" # Agent thought end
|
||||
|
||||
|
||||
class QueueTextChunkEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueTextChunkEvent entity
|
||||
@ -191,6 +202,16 @@ class QueueTextChunkEvent(AppQueueEvent):
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
# Extended fields for Agent/Tool streaming
|
||||
chunk_type: ChunkType = ChunkType.TEXT
|
||||
"""type of the chunk"""
|
||||
|
||||
# Tool streaming payloads
|
||||
tool_call: ToolCall | None = None
|
||||
"""structured tool call info"""
|
||||
tool_result: ToolResult | None = None
|
||||
"""structured tool result info"""
|
||||
|
||||
|
||||
class QueueAgentMessageEvent(AppQueueEvent):
|
||||
"""
|
||||
|
||||
@ -113,6 +113,38 @@ class MessageStreamResponse(StreamResponse):
|
||||
answer: str
|
||||
from_variable_selector: list[str] | None = None
|
||||
|
||||
# Extended fields for Agent/Tool streaming (imported at runtime to avoid circular import)
|
||||
chunk_type: str | None = None
|
||||
"""type of the chunk: text, tool_call, tool_result, thought"""
|
||||
|
||||
# Tool call fields (when chunk_type == "tool_call")
|
||||
tool_call_id: str | None = None
|
||||
"""unique identifier for this tool call"""
|
||||
tool_name: str | None = None
|
||||
"""name of the tool being called"""
|
||||
tool_arguments: str | None = None
|
||||
"""accumulated tool arguments JSON"""
|
||||
|
||||
# Tool result fields (when chunk_type == "tool_result")
|
||||
tool_files: list[str] | None = None
|
||||
"""file IDs produced by tool"""
|
||||
tool_error: str | None = None
|
||||
"""error message if tool failed"""
|
||||
tool_elapsed_time: float | None = None
|
||||
"""elapsed time spent executing the tool"""
|
||||
tool_icon: str | dict | None = None
|
||||
"""icon of the tool"""
|
||||
tool_icon_dark: str | dict | None = None
|
||||
"""dark theme icon of the tool"""
|
||||
|
||||
def model_dump(self, *args, **kwargs) -> dict[str, object]:
|
||||
kwargs.setdefault("exclude_none", True)
|
||||
return super().model_dump(*args, **kwargs)
|
||||
|
||||
def model_dump_json(self, *args, **kwargs) -> str:
|
||||
kwargs.setdefault("exclude_none", True)
|
||||
return super().model_dump_json(*args, **kwargs)
|
||||
|
||||
|
||||
class MessageAudioStreamResponse(StreamResponse):
|
||||
"""
|
||||
@ -582,6 +614,17 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class ChunkType(StrEnum):
|
||||
"""Stream chunk type for LLM-related events."""
|
||||
|
||||
TEXT = "text" # Normal text streaming
|
||||
TOOL_CALL = "tool_call" # Tool call arguments streaming
|
||||
TOOL_RESULT = "tool_result" # Tool execution result
|
||||
THOUGHT = "thought" # Agent thinking process (ReAct)
|
||||
THOUGHT_START = "thought_start" # Agent thought start
|
||||
THOUGHT_END = "thought_end" # Agent thought end
|
||||
|
||||
|
||||
class TextChunkStreamResponse(StreamResponse):
|
||||
"""
|
||||
TextChunkStreamResponse entity
|
||||
@ -595,6 +638,36 @@ class TextChunkStreamResponse(StreamResponse):
|
||||
text: str
|
||||
from_variable_selector: list[str] | None = None
|
||||
|
||||
# Extended fields for Agent/Tool streaming
|
||||
chunk_type: ChunkType = ChunkType.TEXT
|
||||
"""type of the chunk"""
|
||||
|
||||
# Tool call fields (when chunk_type == TOOL_CALL)
|
||||
tool_call_id: str | None = None
|
||||
"""unique identifier for this tool call"""
|
||||
tool_name: str | None = None
|
||||
"""name of the tool being called"""
|
||||
tool_arguments: str | None = None
|
||||
"""accumulated tool arguments JSON"""
|
||||
|
||||
# Tool result fields (when chunk_type == TOOL_RESULT)
|
||||
tool_files: list[str] | None = None
|
||||
"""file IDs produced by tool"""
|
||||
tool_error: str | None = None
|
||||
"""error message if tool failed"""
|
||||
|
||||
# Tool elapsed time fields (when chunk_type == TOOL_RESULT)
|
||||
tool_elapsed_time: float | None = None
|
||||
"""elapsed time spent executing the tool"""
|
||||
|
||||
def model_dump(self, *args, **kwargs) -> dict[str, object]:
|
||||
kwargs.setdefault("exclude_none", True)
|
||||
return super().model_dump(*args, **kwargs)
|
||||
|
||||
def model_dump_json(self, *args, **kwargs) -> str:
|
||||
kwargs.setdefault("exclude_none", True)
|
||||
return super().model_dump_json(*args, **kwargs)
|
||||
|
||||
event: StreamEvent = StreamEvent.TEXT_CHUNK
|
||||
data: Data
|
||||
|
||||
@ -743,7 +816,7 @@ class AgentLogStreamResponse(StreamResponse):
|
||||
"""
|
||||
|
||||
node_execution_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
label: str
|
||||
parent_id: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from threading import Thread
|
||||
@ -58,7 +59,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import AppMode, Conversation, Message, MessageAgentThought
|
||||
from models.model import AppMode, Conversation, LLMGenerationDetail, Message, MessageAgentThought
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -68,6 +69,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
_task_state: EasyUITaskState
|
||||
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
|
||||
|
||||
@ -409,11 +412,136 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
)
|
||||
)
|
||||
|
||||
# Save LLM generation detail if there's reasoning_content
|
||||
self._save_generation_detail(session=session, message=message, llm_result=llm_result)
|
||||
|
||||
message_was_created.send(
|
||||
message,
|
||||
application_generate_entity=self._application_generate_entity,
|
||||
)
|
||||
|
||||
def _save_generation_detail(self, *, session: Session, message: Message, llm_result: LLMResult) -> None:
|
||||
"""
|
||||
Save LLM generation detail for Completion/Chat/Agent-Chat applications.
|
||||
For Agent-Chat, also merges MessageAgentThought records.
|
||||
"""
|
||||
import json
|
||||
|
||||
reasoning_list: list[str] = []
|
||||
tool_calls_list: list[dict] = []
|
||||
sequence: list[dict] = []
|
||||
answer = message.answer or ""
|
||||
|
||||
# Check if this is Agent-Chat mode by looking for agent thoughts
|
||||
agent_thoughts = (
|
||||
session.query(MessageAgentThought)
|
||||
.filter_by(message_id=message.id)
|
||||
.order_by(MessageAgentThought.position.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
if agent_thoughts:
|
||||
# Agent-Chat mode: merge MessageAgentThought records
|
||||
content_pos = 0
|
||||
cleaned_answer_parts: list[str] = []
|
||||
for thought in agent_thoughts:
|
||||
# Add thought/reasoning
|
||||
if thought.thought:
|
||||
reasoning_text = thought.thought
|
||||
if "<think" in reasoning_text.lower():
|
||||
clean_text, extracted_reasoning = self._split_reasoning_from_answer(reasoning_text)
|
||||
if extracted_reasoning:
|
||||
reasoning_text = extracted_reasoning
|
||||
thought.thought = clean_text or extracted_reasoning
|
||||
reasoning_list.append(reasoning_text)
|
||||
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})
|
||||
|
||||
# Add tool calls
|
||||
if thought.tool:
|
||||
tool_calls_list.append(
|
||||
{
|
||||
"name": thought.tool,
|
||||
"arguments": thought.tool_input or "",
|
||||
"result": thought.observation or "",
|
||||
}
|
||||
)
|
||||
sequence.append({"type": "tool_call", "index": len(tool_calls_list) - 1})
|
||||
|
||||
# Add answer content if present
|
||||
if thought.answer:
|
||||
content_text = thought.answer
|
||||
if "<think" in content_text.lower():
|
||||
clean_answer, extracted_reasoning = self._split_reasoning_from_answer(content_text)
|
||||
if extracted_reasoning:
|
||||
reasoning_list.append(extracted_reasoning)
|
||||
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})
|
||||
content_text = clean_answer
|
||||
thought.answer = clean_answer or content_text
|
||||
|
||||
if content_text:
|
||||
start = content_pos
|
||||
end = content_pos + len(content_text)
|
||||
sequence.append({"type": "content", "start": start, "end": end})
|
||||
content_pos = end
|
||||
cleaned_answer_parts.append(content_text)
|
||||
|
||||
if cleaned_answer_parts:
|
||||
merged_answer = "".join(cleaned_answer_parts)
|
||||
message.answer = merged_answer
|
||||
llm_result.message.content = merged_answer
|
||||
else:
|
||||
# Completion/Chat mode: use reasoning_content from llm_result
|
||||
reasoning_content = llm_result.reasoning_content
|
||||
if not reasoning_content and answer:
|
||||
# Extract reasoning from <think> blocks and clean the final answer
|
||||
clean_answer, reasoning_content = self._split_reasoning_from_answer(answer)
|
||||
if reasoning_content:
|
||||
answer = clean_answer
|
||||
llm_result.message.content = clean_answer
|
||||
llm_result.reasoning_content = reasoning_content
|
||||
message.answer = clean_answer
|
||||
if reasoning_content:
|
||||
reasoning_list = [reasoning_content]
|
||||
# Content comes first, then reasoning
|
||||
if answer:
|
||||
sequence.append({"type": "content", "start": 0, "end": len(answer)})
|
||||
sequence.append({"type": "reasoning", "index": 0})
|
||||
|
||||
# Only save if there's meaningful generation detail
|
||||
if not reasoning_list and not tool_calls_list:
|
||||
return
|
||||
|
||||
# Check if generation detail already exists
|
||||
existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first()
|
||||
|
||||
if existing:
|
||||
existing.reasoning_content = json.dumps(reasoning_list) if reasoning_list else None
|
||||
existing.tool_calls = json.dumps(tool_calls_list) if tool_calls_list else None
|
||||
existing.sequence = json.dumps(sequence) if sequence else None
|
||||
else:
|
||||
generation_detail = LLMGenerationDetail(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
message_id=message.id,
|
||||
reasoning_content=json.dumps(reasoning_list) if reasoning_list else None,
|
||||
tool_calls=json.dumps(tool_calls_list) if tool_calls_list else None,
|
||||
sequence=json.dumps(sequence) if sequence else None,
|
||||
)
|
||||
session.add(generation_detail)
|
||||
|
||||
@classmethod
|
||||
def _split_reasoning_from_answer(cls, text: str) -> tuple[str, str]:
|
||||
"""
|
||||
Extract reasoning segments from <think> blocks and return (clean_text, reasoning).
|
||||
"""
|
||||
matches = cls._THINK_PATTERN.findall(text)
|
||||
reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""
|
||||
|
||||
clean_text = cls._THINK_PATTERN.sub("", text)
|
||||
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
|
||||
|
||||
return clean_text, reasoning_content or ""
|
||||
|
||||
def _handle_stop(self, event: QueueStopEvent):
|
||||
"""
|
||||
Handle stop.
|
||||
|
||||
@ -232,15 +232,31 @@ class MessageCycleManager:
|
||||
answer: str,
|
||||
message_id: str,
|
||||
from_variable_selector: list[str] | None = None,
|
||||
chunk_type: str | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
tool_arguments: str | None = None,
|
||||
tool_files: list[str] | None = None,
|
||||
tool_error: str | None = None,
|
||||
tool_elapsed_time: float | None = None,
|
||||
tool_icon: str | dict | None = None,
|
||||
tool_icon_dark: str | dict | None = None,
|
||||
event_type: StreamEvent | None = None,
|
||||
) -> MessageStreamResponse:
|
||||
"""
|
||||
Message to stream response.
|
||||
:param answer: answer
|
||||
:param message_id: message id
|
||||
:param from_variable_selector: from variable selector
|
||||
:param chunk_type: type of the chunk (text, function_call, tool_result, thought)
|
||||
:param tool_call_id: unique identifier for this tool call
|
||||
:param tool_name: name of the tool being called
|
||||
:param tool_arguments: accumulated tool arguments JSON
|
||||
:param tool_files: file IDs produced by tool
|
||||
:param tool_error: error message if tool failed
|
||||
:return:
|
||||
"""
|
||||
return MessageStreamResponse(
|
||||
response = MessageStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_id,
|
||||
answer=answer,
|
||||
@ -248,6 +264,35 @@ class MessageCycleManager:
|
||||
event=event_type or StreamEvent.MESSAGE,
|
||||
)
|
||||
|
||||
if chunk_type:
|
||||
response = response.model_copy(update={"chunk_type": chunk_type})
|
||||
|
||||
if chunk_type == "tool_call":
|
||||
response = response.model_copy(
|
||||
update={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_arguments": tool_arguments,
|
||||
"tool_icon": tool_icon,
|
||||
"tool_icon_dark": tool_icon_dark,
|
||||
}
|
||||
)
|
||||
elif chunk_type == "tool_result":
|
||||
response = response.model_copy(
|
||||
update={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_arguments": tool_arguments,
|
||||
"tool_files": tool_files,
|
||||
"tool_error": tool_error,
|
||||
"tool_elapsed_time": tool_elapsed_time,
|
||||
"tool_icon": tool_icon,
|
||||
"tool_icon_dark": tool_icon_dark,
|
||||
}
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
|
||||
"""
|
||||
Message replace to stream response.
|
||||
|
||||
@ -5,7 +5,6 @@ from sqlalchemy import select
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.models.document import Document
|
||||
@ -37,7 +36,9 @@ class DatasetIndexToolCallbackHandler:
|
||||
content=query,
|
||||
source="app",
|
||||
source_app_id=self._app_id,
|
||||
created_by_role=self._invoke_from.to_creator_user_role(),
|
||||
created_by_role=(
|
||||
"account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
|
||||
),
|
||||
created_by=self._user_id,
|
||||
)
|
||||
|
||||
@ -88,6 +89,8 @@ class DatasetIndexToolCallbackHandler:
|
||||
# TODO(-LAN-): Improve type check
|
||||
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
|
||||
"""Handle return_retriever_resource_info."""
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
|
||||
self._queue_manager.publish(
|
||||
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
@ -3,6 +3,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
class PreviewDetail(BaseModel):
|
||||
content: str
|
||||
summary: str | None = None
|
||||
child_chunks: list[str] | None = None
|
||||
|
||||
|
||||
|
||||
@ -311,14 +311,18 @@ class IndexingRunner:
|
||||
qa_preview_texts: list[QAPreviewDetail] = []
|
||||
|
||||
total_segments = 0
|
||||
# doc_form represents the segmentation method (general, parent-child, QA)
|
||||
index_type = doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
# one extract_setting is one source document
|
||||
for extract_setting in extract_settings:
|
||||
# extract
|
||||
processing_rule = DatasetProcessRule(
|
||||
mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
|
||||
)
|
||||
# Extract document content
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
|
||||
# Cleaning and segmentation
|
||||
documents = index_processor.transform(
|
||||
text_docs,
|
||||
current_user=None,
|
||||
@ -361,6 +365,12 @@ class IndexingRunner:
|
||||
|
||||
if doc_form and doc_form == "qa_model":
|
||||
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
|
||||
|
||||
# Generate summary preview
|
||||
summary_index_setting = tmp_processing_rule["summary_index_setting"] if "summary_index_setting" in tmp_processing_rule else None
|
||||
if summary_index_setting and summary_index_setting.get('enable') and preview_texts:
|
||||
preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
|
||||
|
||||
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
|
||||
|
||||
def _extract(
|
||||
|
||||
@ -434,3 +434,6 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
|
||||
You should edit the prompt according to the IDEAL OUTPUT."""
|
||||
|
||||
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
|
||||
|
||||
DEFAULT_GENERATOR_SUMMARY_PROMPT = """
|
||||
You are a helpful assistant that summarizes long pieces of text into concise summaries. Given the following text, generate a brief summary that captures the main points and key information. The summary should be clear, concise, and written in complete sentences. """
|
||||
|
||||
@ -251,10 +251,7 @@ class AssistantPromptMessage(PromptMessage):
|
||||
|
||||
:return: True if prompt message is empty, False otherwise
|
||||
"""
|
||||
if not super().is_empty() and not self.tool_calls:
|
||||
return False
|
||||
|
||||
return True
|
||||
return super().is_empty() and not self.tool_calls
|
||||
|
||||
|
||||
class SystemPromptMessage(PromptMessage):
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from opentelemetry.trace import SpanKind
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
@ -151,6 +152,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
),
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
@ -456,6 +458,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
),
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
@ -475,6 +478,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
),
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
|
||||
)
|
||||
self.trace_client.add_span(workflow_span)
|
||||
|
||||
|
||||
@ -166,7 +166,7 @@ class SpanBuilder:
|
||||
attributes=span_data.attributes,
|
||||
events=span_data.events,
|
||||
links=span_data.links,
|
||||
kind=trace_api.SpanKind.INTERNAL,
|
||||
kind=span_data.span_kind,
|
||||
status=span_data.status,
|
||||
start_time=span_data.start_time,
|
||||
end_time=span_data.end_time,
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Any
|
||||
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@ -34,3 +34,4 @@ class SpanData(BaseModel):
|
||||
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
|
||||
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
|
||||
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
|
||||
span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.")
|
||||
|
||||
@ -392,6 +392,69 @@ class RetrievalService:
|
||||
records = []
|
||||
include_segment_ids = set()
|
||||
segment_child_map = {}
|
||||
segment_file_map = {}
|
||||
segment_summary_map = {} # Map segment_id to summary content
|
||||
summary_segment_ids = set() # Track segments retrieved via summary
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# Process documents
|
||||
for document in documents:
|
||||
segment_id = None
|
||||
attachment_info = None
|
||||
child_chunk = None
|
||||
document_id = document.metadata.get("document_id")
|
||||
if document_id not in dataset_documents:
|
||||
continue
|
||||
|
||||
dataset_document = dataset_documents[document_id]
|
||||
if not dataset_document:
|
||||
continue
|
||||
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
# Handle parent-child documents
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
attachment_info_dict = cls.get_segment_attachment_info(
|
||||
dataset_document.dataset_id,
|
||||
dataset_document.tenant_id,
|
||||
document.metadata.get("doc_id") or "",
|
||||
session,
|
||||
)
|
||||
if attachment_info_dict:
|
||||
attachment_info = attachment_info_dict["attachment_info"]
|
||||
segment_id = attachment_info_dict["segment_id"]
|
||||
else:
|
||||
# Check if this is a summary document
|
||||
is_summary = document.metadata.get("is_summary", False)
|
||||
if is_summary:
|
||||
# For summary documents, find the original chunk via original_chunk_id
|
||||
original_chunk_id = document.metadata.get("original_chunk_id")
|
||||
if not original_chunk_id:
|
||||
continue
|
||||
segment_id = original_chunk_id
|
||||
# Track that this segment was retrieved via summary
|
||||
summary_segment_ids.add(segment_id)
|
||||
else:
|
||||
# For normal documents, find by child chunk index_node_id
|
||||
child_index_node_id = document.metadata.get("doc_id")
|
||||
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
|
||||
child_chunk = session.scalar(child_chunk_stmt)
|
||||
|
||||
if not child_chunk:
|
||||
continue
|
||||
segment_id = child_chunk.segment_id
|
||||
|
||||
if not segment_id:
|
||||
continue
|
||||
|
||||
segment = (
|
||||
session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id == segment_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
valid_dataset_documents = {}
|
||||
image_doc_ids: list[Any] = []
|
||||
@ -507,7 +570,47 @@ class RetrievalService:
|
||||
max_score = max(
|
||||
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
|
||||
)
|
||||
segment = session.scalar(document_segment_stmt)
|
||||
if segment:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
else:
|
||||
# Check if this is a summary document
|
||||
is_summary = document.metadata.get("is_summary", False)
|
||||
if is_summary:
|
||||
# For summary documents, find the original chunk via original_chunk_id
|
||||
original_chunk_id = document.metadata.get("original_chunk_id")
|
||||
if not original_chunk_id:
|
||||
continue
|
||||
# Track that this segment was retrieved via summary
|
||||
summary_segment_ids.add(original_chunk_id)
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id == original_chunk_id,
|
||||
)
|
||||
segment = session.scalar(document_segment_stmt)
|
||||
else:
|
||||
# For normal documents, find by index_node_id
|
||||
index_node_id = document.metadata.get("doc_id")
|
||||
if not index_node_id:
|
||||
continue
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.index_node_id == index_node_id,
|
||||
)
|
||||
segment = session.scalar(document_segment_stmt)
|
||||
|
||||
if not segment:
|
||||
continue
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": document.metadata.get("score"), # type: ignore
|
||||
}
|
||||
map_detail = {
|
||||
"max_score": max_score,
|
||||
"child_chunks": child_chunk_details,
|
||||
@ -542,6 +645,23 @@ class RetrievalService:
|
||||
if record["segment"].id in attachment_map:
|
||||
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
|
||||
|
||||
# Batch query summaries for segments retrieved via summary (only enabled summaries)
|
||||
if summary_segment_ids:
|
||||
from models.dataset import DocumentSegmentSummary
|
||||
|
||||
summaries = (
|
||||
session.query(DocumentSegmentSummary)
|
||||
.filter(
|
||||
DocumentSegmentSummary.chunk_id.in_(summary_segment_ids),
|
||||
DocumentSegmentSummary.status == "completed",
|
||||
DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries
|
||||
)
|
||||
.all()
|
||||
)
|
||||
for summary in summaries:
|
||||
if summary.summary_content:
|
||||
segment_summary_map[summary.chunk_id] = summary.summary_content
|
||||
|
||||
result: list[RetrievalSegments] = []
|
||||
for record in records:
|
||||
# Extract segment
|
||||
@ -576,9 +696,16 @@ class RetrievalService:
|
||||
else None
|
||||
)
|
||||
|
||||
# Extract summary if this segment was retrieved via summary
|
||||
summary_content = segment_summary_map.get(segment.id)
|
||||
|
||||
# Create RetrievalSegments object
|
||||
retrieval_segment = RetrievalSegments(
|
||||
segment=segment, child_chunks=child_chunks_list, score=score, files=files
|
||||
segment=segment,
|
||||
child_chunks=child_chunks_list,
|
||||
score=score,
|
||||
files=files,
|
||||
summary=summary_content
|
||||
)
|
||||
result.append(retrieval_segment)
|
||||
|
||||
|
||||
@ -20,3 +20,4 @@ class RetrievalSegments(BaseModel):
|
||||
child_chunks: list[RetrievalChildChunk] | None = None
|
||||
score: float | None = None
|
||||
files: list[dict[str, str | int]] | None = None
|
||||
summary: str | None = None # Summary content if retrieved via summary index
|
||||
|
||||
@ -13,6 +13,7 @@ from urllib.parse import unquote, urlparse
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
@ -45,6 +46,15 @@ class BaseIndexProcessor(ABC):
|
||||
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def generate_summary_preview(self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict) -> list[PreviewDetail]:
|
||||
"""
|
||||
For each segment in preview_texts, generate a summary using LLM and attach it to the segment.
|
||||
The summary can be stored in a new attribute, e.g., summary.
|
||||
This method should be implemented by subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load(
|
||||
self,
|
||||
|
||||
@ -1,9 +1,13 @@
|
||||
"""Paragraph index processor."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
@ -17,12 +21,19 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.account_service import AccountService
|
||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
from core.model_runtime.entities.message_entities import UserPromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.model_manager import ModelInstance
|
||||
|
||||
|
||||
class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
@ -108,6 +119,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
keyword.add_texts(documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
|
||||
# For disable operations, disable_summaries_for_segments is called directly in the task.
|
||||
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
|
||||
delete_summaries = kwargs.get("delete_summaries", False)
|
||||
if delete_summaries:
|
||||
if node_ids:
|
||||
# Find segments by index_node_id
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(node_ids),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
if segment_ids:
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
|
||||
else:
|
||||
# Delete all summaries for the dataset
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, None)
|
||||
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
if node_ids:
|
||||
@ -227,3 +261,70 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
}
|
||||
else:
|
||||
raise ValueError("Chunks is not a list")
|
||||
|
||||
def generate_summary_preview(self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict) -> list[PreviewDetail]:
|
||||
"""
|
||||
For each segment, concurrently call generate_summary to generate a summary
|
||||
and write it to the summary attribute of PreviewDetail.
|
||||
"""
|
||||
import concurrent.futures
|
||||
from flask import current_app
|
||||
|
||||
# Capture Flask app context for worker threads
|
||||
flask_app = None
|
||||
try:
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
except RuntimeError:
|
||||
logger.warning("No Flask application context available, summary generation may fail")
|
||||
|
||||
def process(preview: PreviewDetail) -> None:
|
||||
"""Generate summary for a single preview item."""
|
||||
try:
|
||||
if flask_app:
|
||||
# Ensure Flask app context in worker thread
|
||||
with flask_app.app_context():
|
||||
summary = self.generate_summary(tenant_id, preview.content, summary_index_setting)
|
||||
preview.summary = summary
|
||||
else:
|
||||
# Fallback: try without app context (may fail)
|
||||
summary = self.generate_summary(tenant_id, preview.content, summary_index_setting)
|
||||
preview.summary = summary
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate summary for preview: {str(e)}")
|
||||
# Don't fail the entire preview if summary generation fails
|
||||
preview.summary = None
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
list(executor.map(process, preview_texts))
|
||||
return preview_texts
|
||||
|
||||
@staticmethod
|
||||
def generate_summary(tenant_id: str, text: str, summary_index_setting: dict = None) -> str:
|
||||
"""
|
||||
Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt.
|
||||
"""
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
raise ValueError("summary_index_setting is required and must be enabled to generate summary.")
|
||||
|
||||
model_name = summary_index_setting.get("model_name")
|
||||
model_provider_name = summary_index_setting.get("model_provider_name")
|
||||
summary_prompt = summary_index_setting.get("summary_prompt")
|
||||
|
||||
# Import default summary prompt
|
||||
if not summary_prompt:
|
||||
summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
|
||||
prompt = f"{summary_prompt}\n{text}"
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(tenant_id, model_provider_name, ModelType.LLM)
|
||||
model_instance = ModelInstance(provider_model_bundle, model_name)
|
||||
prompt_messages = [UserPromptMessage(content=prompt)]
|
||||
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters={},
|
||||
stream=False
|
||||
)
|
||||
|
||||
return getattr(result.message, "content", "")
|
||||
|
||||
@ -25,6 +25,7 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.account_service import AccountService
|
||||
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
|
||||
class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
@ -135,6 +136,29 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
||||
# node_ids is segment's node_ids
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
|
||||
# For disable operations, disable_summaries_for_segments is called directly in the task.
|
||||
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
|
||||
delete_summaries = kwargs.get("delete_summaries", False)
|
||||
if delete_summaries:
|
||||
if node_ids:
|
||||
# Find segments by index_node_id
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(node_ids),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
if segment_ids:
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
|
||||
else:
|
||||
# Delete all summaries for the dataset
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, None)
|
||||
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
delete_child_chunks = kwargs.get("delete_child_chunks") or False
|
||||
precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")
|
||||
|
||||
@ -25,9 +25,10 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from libs import helper
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -144,6 +145,30 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
vector.create_multimodal(multimodal_documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
|
||||
# For disable operations, disable_summaries_for_segments is called directly in the task.
|
||||
# Note: qa_model doesn't generate summaries, but we clean them for completeness
|
||||
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
|
||||
delete_summaries = kwargs.get("delete_summaries", False)
|
||||
if delete_summaries:
|
||||
if node_ids:
|
||||
# Find segments by index_node_id
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(node_ids),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
if segment_ids:
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
|
||||
else:
|
||||
# Delete all summaries for the dataset
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, None)
|
||||
|
||||
vector = Vector(dataset)
|
||||
if node_ids:
|
||||
vector.delete_by_ids(node_ids)
|
||||
|
||||
@ -63,7 +63,6 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models import UploadFile
|
||||
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import CreatorUserRole
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
@ -177,13 +176,13 @@ class DatasetRetrieval:
|
||||
)
|
||||
|
||||
all_documents = []
|
||||
creator_user_role = invoke_from.to_creator_user_role()
|
||||
user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
all_documents = self.single_retrieve(
|
||||
app_id,
|
||||
tenant_id,
|
||||
user_id,
|
||||
creator_user_role,
|
||||
user_from,
|
||||
query,
|
||||
available_datasets,
|
||||
model_instance,
|
||||
@ -198,7 +197,7 @@ class DatasetRetrieval:
|
||||
app_id,
|
||||
tenant_id,
|
||||
user_id,
|
||||
creator_user_role,
|
||||
user_from,
|
||||
available_datasets,
|
||||
query,
|
||||
retrieve_config.top_k or 0,
|
||||
@ -335,7 +334,7 @@ class DatasetRetrieval:
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
creator_user_role: CreatorUserRole,
|
||||
user_from: str,
|
||||
query: str,
|
||||
available_datasets: list,
|
||||
model_instance: ModelInstance,
|
||||
@ -445,7 +444,7 @@ class DatasetRetrieval:
|
||||
weights=retrieval_model_config.get("weights", None),
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
self._on_query(query, None, [dataset_id], app_id, creator_user_role, user_id)
|
||||
self._on_query(query, None, [dataset_id], app_id, user_from, user_id)
|
||||
|
||||
if results:
|
||||
thread = threading.Thread(
|
||||
@ -467,7 +466,7 @@ class DatasetRetrieval:
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
creator_user_role: CreatorUserRole,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
query: str | None,
|
||||
top_k: int,
|
||||
@ -585,7 +584,7 @@ class DatasetRetrieval:
|
||||
|
||||
if thread_exceptions:
|
||||
raise thread_exceptions[0]
|
||||
self._on_query(query, attachment_ids, dataset_ids, app_id, creator_user_role, user_id)
|
||||
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
|
||||
|
||||
if all_documents:
|
||||
# add thread to call _on_retrieval_end
|
||||
@ -734,7 +733,7 @@ class DatasetRetrieval:
|
||||
attachment_ids: list[str] | None,
|
||||
dataset_ids: list[str],
|
||||
app_id: str,
|
||||
creator_user_role: CreatorUserRole,
|
||||
user_from: str,
|
||||
user_id: str,
|
||||
):
|
||||
"""
|
||||
@ -756,7 +755,7 @@ class DatasetRetrieval:
|
||||
content=json.dumps(contents),
|
||||
source="app",
|
||||
source_app_id=app_id,
|
||||
created_by_role=creator_user_role,
|
||||
created_by_role=user_from,
|
||||
created_by=user_id,
|
||||
)
|
||||
dataset_queries.append(dataset_query)
|
||||
|
||||
@ -29,6 +29,7 @@ from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
EndUser,
|
||||
LLMGenerationDetail,
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
)
|
||||
@ -457,6 +458,113 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
session.merge(db_model)
|
||||
session.flush()
|
||||
|
||||
# Save LLMGenerationDetail for LLM nodes with successful execution
|
||||
if (
|
||||
domain_model.node_type == NodeType.LLM
|
||||
and domain_model.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
and domain_model.outputs is not None
|
||||
):
|
||||
self._save_llm_generation_detail(session, domain_model)
|
||||
|
||||
def _save_llm_generation_detail(self, session, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Save LLM generation detail for LLM nodes.
|
||||
Extracts reasoning_content, tool_calls, and sequence from outputs and metadata.
|
||||
"""
|
||||
outputs = execution.outputs or {}
|
||||
metadata = execution.metadata or {}
|
||||
|
||||
reasoning_list = self._extract_reasoning(outputs)
|
||||
tool_calls_list = self._extract_tool_calls(metadata.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG))
|
||||
|
||||
if not reasoning_list and not tool_calls_list:
|
||||
return
|
||||
|
||||
sequence = self._build_generation_sequence(outputs.get("text", ""), reasoning_list, tool_calls_list)
|
||||
self._upsert_generation_detail(session, execution, reasoning_list, tool_calls_list, sequence)
|
||||
|
||||
def _extract_reasoning(self, outputs: Mapping[str, Any]) -> list[str]:
|
||||
"""Extract reasoning_content as a clean list of non-empty strings."""
|
||||
reasoning_content = outputs.get("reasoning_content")
|
||||
if isinstance(reasoning_content, str):
|
||||
trimmed = reasoning_content.strip()
|
||||
return [trimmed] if trimmed else []
|
||||
if isinstance(reasoning_content, list):
|
||||
return [item.strip() for item in reasoning_content if isinstance(item, str) and item.strip()]
|
||||
return []
|
||||
|
||||
def _extract_tool_calls(self, agent_log: Any) -> list[dict[str, str]]:
|
||||
"""Extract tool call records from agent logs."""
|
||||
if not agent_log or not isinstance(agent_log, list):
|
||||
return []
|
||||
|
||||
tool_calls: list[dict[str, str]] = []
|
||||
for log in agent_log:
|
||||
log_data = log.data if hasattr(log, "data") else (log.get("data", {}) if isinstance(log, dict) else {})
|
||||
tool_name = log_data.get("tool_name")
|
||||
if tool_name and str(tool_name).strip():
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": log_data.get("tool_call_id", ""),
|
||||
"name": tool_name,
|
||||
"arguments": json.dumps(log_data.get("tool_args", {})),
|
||||
"result": str(log_data.get("output", "")),
|
||||
}
|
||||
)
|
||||
return tool_calls
|
||||
|
||||
def _build_generation_sequence(
|
||||
self, text: str, reasoning_list: list[str], tool_calls_list: list[dict[str, str]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build a simple content/reasoning/tool_call sequence."""
|
||||
sequence: list[dict[str, Any]] = []
|
||||
if text:
|
||||
sequence.append({"type": "content", "start": 0, "end": len(text)})
|
||||
for index in range(len(reasoning_list)):
|
||||
sequence.append({"type": "reasoning", "index": index})
|
||||
for index in range(len(tool_calls_list)):
|
||||
sequence.append({"type": "tool_call", "index": index})
|
||||
return sequence
|
||||
|
||||
def _upsert_generation_detail(
|
||||
self,
|
||||
session,
|
||||
execution: WorkflowNodeExecution,
|
||||
reasoning_list: list[str],
|
||||
tool_calls_list: list[dict[str, str]],
|
||||
sequence: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Insert or update LLMGenerationDetail with serialized fields."""
|
||||
existing = (
|
||||
session.query(LLMGenerationDetail)
|
||||
.filter_by(
|
||||
workflow_run_id=execution.workflow_execution_id,
|
||||
node_id=execution.node_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
reasoning_json = json.dumps(reasoning_list) if reasoning_list else None
|
||||
tool_calls_json = json.dumps(tool_calls_list) if tool_calls_list else None
|
||||
sequence_json = json.dumps(sequence) if sequence else None
|
||||
|
||||
if existing:
|
||||
existing.reasoning_content = reasoning_json
|
||||
existing.tool_calls = tool_calls_json
|
||||
existing.sequence = sequence_json
|
||||
return
|
||||
|
||||
generation_detail = LLMGenerationDetail(
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_run_id=execution.workflow_execution_id,
|
||||
node_id=execution.node_id,
|
||||
reasoning_content=reasoning_json,
|
||||
tool_calls=tool_calls_json,
|
||||
sequence=sequence_json,
|
||||
)
|
||||
session.add(generation_detail)
|
||||
|
||||
def get_db_models_by_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any
|
||||
if TYPE_CHECKING:
|
||||
from models.model import File
|
||||
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolEntity,
|
||||
@ -154,6 +155,60 @@ class Tool(ABC):
|
||||
|
||||
return parameters
|
||||
|
||||
def to_prompt_message_tool(self) -> PromptMessageTool:
|
||||
message_tool = PromptMessageTool(
|
||||
name=self.entity.identity.name,
|
||||
description=self.entity.description.llm if self.entity.description else "",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
|
||||
parameters = self.get_merged_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||
continue
|
||||
|
||||
parameter_type = parameter.type.as_normal_type()
|
||||
if parameter.type in {
|
||||
ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||
ToolParameter.ToolParameterType.FILE,
|
||||
ToolParameter.ToolParameterType.FILES,
|
||||
}:
|
||||
# Determine the description based on parameter type
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
file_format_desc = " Input the file id with format: [File: file_id]."
|
||||
else:
|
||||
file_format_desc = "Input the file id with format: [Files: file_id1, file_id2, ...]. "
|
||||
|
||||
message_tool.parameters["properties"][parameter.name] = {
|
||||
"type": "string",
|
||||
"description": (parameter.llm_description or "") + file_format_desc,
|
||||
}
|
||||
continue
|
||||
enum = []
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options] if parameter.options else []
|
||||
|
||||
message_tool.parameters["properties"][parameter.name] = (
|
||||
{
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
if parameter.input_schema is None
|
||||
else parameter.input_schema
|
||||
)
|
||||
|
||||
if len(enum) > 0:
|
||||
message_tool.parameters["properties"][parameter.name]["enum"] = enum
|
||||
|
||||
if parameter.required:
|
||||
message_tool.parameters["required"].append(parameter.name)
|
||||
|
||||
return message_tool
|
||||
|
||||
def create_image_message(
|
||||
self,
|
||||
image: str,
|
||||
|
||||
@ -1,11 +1,16 @@
|
||||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .tool_entities import ToolCall, ToolCallResult, ToolResult, ToolResultStatus
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"ToolCall",
|
||||
"ToolCallResult",
|
||||
"ToolResult",
|
||||
"ToolResultStatus",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
]
|
||||
|
||||
39
api/core/workflow/entities/tool_entities.py
Normal file
39
api/core/workflow/entities/tool_entities.py
Normal file
@ -0,0 +1,39 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.file import File
|
||||
|
||||
|
||||
class ToolResultStatus(StrEnum):
|
||||
SUCCESS = "success"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
id: str | None = Field(default=None, description="Unique identifier for this tool call")
|
||||
name: str | None = Field(default=None, description="Name of the tool being called")
|
||||
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
|
||||
icon: str | dict | None = Field(default=None, description="Icon of the tool")
|
||||
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to")
|
||||
name: str | None = Field(default=None, description="Name of the tool")
|
||||
output: str | None = Field(default=None, description="Tool output text, error or success message")
|
||||
files: list[str] = Field(default_factory=list, description="File produced by tool")
|
||||
status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
|
||||
elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
|
||||
icon: str | dict | None = Field(default=None, description="Icon of the tool")
|
||||
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
|
||||
|
||||
|
||||
class ToolCallResult(BaseModel):
|
||||
id: str | None = Field(default=None, description="Identifier for the tool call")
|
||||
name: str | None = Field(default=None, description="Name of the tool")
|
||||
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
|
||||
output: str | None = Field(default=None, description="Tool output text, error or success message")
|
||||
files: list[File] = Field(default_factory=list, description="File produced by tool")
|
||||
status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
|
||||
elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
|
||||
@ -211,6 +211,10 @@ class WorkflowExecutionStatus(StrEnum):
|
||||
def is_ended(self) -> bool:
|
||||
return self in _END_STATE
|
||||
|
||||
@classmethod
|
||||
def ended_values(cls) -> list[str]:
|
||||
return [status.value for status in _END_STATE]
|
||||
|
||||
|
||||
_END_STATE = frozenset(
|
||||
[
|
||||
@ -247,6 +251,8 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
LLM_CONTENT_SEQUENCE = "llm_content_sequence"
|
||||
LLM_TRACE = "llm_trace"
|
||||
COMPLETED_REASON = "completed_reason" # completed reason for loop node
|
||||
|
||||
|
||||
|
||||
@ -16,7 +16,13 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
|
||||
from core.workflow.graph_events import (
|
||||
ChunkType,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ToolCall,
|
||||
ToolResult,
|
||||
)
|
||||
from core.workflow.nodes.base.template import TextSegment, VariableSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
@ -321,11 +327,24 @@ class ResponseStreamCoordinator:
|
||||
selector: Sequence[str],
|
||||
chunk: str,
|
||||
is_final: bool = False,
|
||||
chunk_type: ChunkType = ChunkType.TEXT,
|
||||
tool_call: ToolCall | None = None,
|
||||
tool_result: ToolResult | None = None,
|
||||
) -> NodeRunStreamChunkEvent:
|
||||
"""Create a stream chunk event with consistent structure.
|
||||
|
||||
For selectors with special prefixes (sys, env, conversation), we use the
|
||||
active response node's information since these are not actual node IDs.
|
||||
|
||||
Args:
|
||||
node_id: The node ID to attribute the event to
|
||||
execution_id: The execution ID for this node
|
||||
selector: The variable selector
|
||||
chunk: The chunk content
|
||||
is_final: Whether this is the final chunk
|
||||
chunk_type: The semantic type of the chunk being streamed
|
||||
tool_call: Structured data for tool_call chunks
|
||||
tool_result: Structured data for tool_result chunks
|
||||
"""
|
||||
# Check if this is a special selector that doesn't correspond to a node
|
||||
if selector and selector[0] not in self._graph.nodes and self._active_session:
|
||||
@ -338,6 +357,9 @@ class ResponseStreamCoordinator:
|
||||
selector=selector,
|
||||
chunk=chunk,
|
||||
is_final=is_final,
|
||||
chunk_type=chunk_type,
|
||||
tool_call=tool_call,
|
||||
tool_result=tool_result,
|
||||
)
|
||||
|
||||
# Standard case: selector refers to an actual node
|
||||
@ -349,6 +371,9 @@ class ResponseStreamCoordinator:
|
||||
selector=selector,
|
||||
chunk=chunk,
|
||||
is_final=is_final,
|
||||
chunk_type=chunk_type,
|
||||
tool_call=tool_call,
|
||||
tool_result=tool_result,
|
||||
)
|
||||
|
||||
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
|
||||
@ -356,6 +381,8 @@ class ResponseStreamCoordinator:
|
||||
|
||||
Handles both regular node selectors and special system selectors (sys, env, conversation).
|
||||
For special selectors, we attribute the output to the active response node.
|
||||
|
||||
For object-type variables, automatically streams all child fields that have stream events.
|
||||
"""
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
source_selector_prefix = segment.selector[0] if segment.selector else ""
|
||||
@ -364,60 +391,81 @@ class ResponseStreamCoordinator:
|
||||
# Determine which node to attribute the output to
|
||||
# For special selectors (sys, env, conversation), use the active response node
|
||||
# For regular selectors, use the source node
|
||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
||||
# Special selector - use active response node
|
||||
output_node_id = self._active_session.node_id
|
||||
else:
|
||||
# Regular node selector
|
||||
output_node_id = source_selector_prefix
|
||||
active_session = self._active_session
|
||||
special_selector = bool(active_session and source_selector_prefix not in self._graph.nodes)
|
||||
output_node_id = active_session.node_id if special_selector and active_session else source_selector_prefix
|
||||
execution_id = self._get_or_create_execution_id(output_node_id)
|
||||
|
||||
# Stream all available chunks
|
||||
while self._has_unread_stream(segment.selector):
|
||||
if event := self._pop_stream_chunk(segment.selector):
|
||||
# For special selectors, we need to update the event to use
|
||||
# the active response node's information
|
||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
||||
response_node = self._graph.nodes[self._active_session.node_id]
|
||||
# Create a new event with the response node's information
|
||||
# but keep the original selector
|
||||
updated_event = NodeRunStreamChunkEvent(
|
||||
id=execution_id,
|
||||
node_id=response_node.id,
|
||||
node_type=response_node.node_type,
|
||||
selector=event.selector, # Keep original selector
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
)
|
||||
events.append(updated_event)
|
||||
else:
|
||||
# Regular node selector - use event as is
|
||||
events.append(event)
|
||||
# Check if there's a direct stream for this selector
|
||||
has_direct_stream = (
|
||||
tuple(segment.selector) in self._stream_buffers or tuple(segment.selector) in self._closed_streams
|
||||
)
|
||||
|
||||
# Check if this is the last chunk by looking ahead
|
||||
stream_closed = self._is_stream_closed(segment.selector)
|
||||
# Check if stream is closed to determine if segment is complete
|
||||
if stream_closed:
|
||||
is_complete = True
|
||||
stream_targets = [segment.selector] if has_direct_stream else sorted(self._find_child_streams(segment.selector))
|
||||
|
||||
elif value := self._variable_pool.get(segment.selector):
|
||||
# Process scalar value
|
||||
is_last_segment = bool(
|
||||
self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
|
||||
)
|
||||
events.append(
|
||||
self._create_stream_chunk_event(
|
||||
node_id=output_node_id,
|
||||
execution_id=execution_id,
|
||||
selector=segment.selector,
|
||||
chunk=value.markdown,
|
||||
is_final=is_last_segment,
|
||||
if stream_targets:
|
||||
all_complete = True
|
||||
|
||||
for target_selector in stream_targets:
|
||||
while self._has_unread_stream(target_selector):
|
||||
if event := self._pop_stream_chunk(target_selector):
|
||||
events.append(
|
||||
self._rewrite_stream_event(
|
||||
event=event,
|
||||
output_node_id=output_node_id,
|
||||
execution_id=execution_id,
|
||||
special_selector=bool(special_selector),
|
||||
)
|
||||
)
|
||||
|
||||
if not self._is_stream_closed(target_selector):
|
||||
all_complete = False
|
||||
|
||||
is_complete = all_complete
|
||||
|
||||
# Fallback: check if scalar value exists in variable pool
|
||||
if not is_complete and not has_direct_stream:
|
||||
if value := self._variable_pool.get(segment.selector):
|
||||
# Process scalar value
|
||||
is_last_segment = bool(
|
||||
self._active_session
|
||||
and self._active_session.index == len(self._active_session.template.segments) - 1
|
||||
)
|
||||
)
|
||||
is_complete = True
|
||||
events.append(
|
||||
self._create_stream_chunk_event(
|
||||
node_id=output_node_id,
|
||||
execution_id=execution_id,
|
||||
selector=segment.selector,
|
||||
chunk=value.markdown,
|
||||
is_final=is_last_segment,
|
||||
)
|
||||
)
|
||||
is_complete = True
|
||||
|
||||
return events, is_complete
|
||||
|
||||
def _rewrite_stream_event(
|
||||
self,
|
||||
event: NodeRunStreamChunkEvent,
|
||||
output_node_id: str,
|
||||
execution_id: str,
|
||||
special_selector: bool,
|
||||
) -> NodeRunStreamChunkEvent:
|
||||
"""Rewrite event to attribute to active response node when selector is special."""
|
||||
if not special_selector:
|
||||
return event
|
||||
|
||||
return self._create_stream_chunk_event(
|
||||
node_id=output_node_id,
|
||||
execution_id=execution_id,
|
||||
selector=event.selector,
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
chunk_type=event.chunk_type,
|
||||
tool_call=event.tool_call,
|
||||
tool_result=event.tool_result,
|
||||
)
|
||||
|
||||
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
"""Process a text segment. Returns (events, is_complete)."""
|
||||
assert self._active_session is not None
|
||||
@ -513,6 +561,36 @@ class ResponseStreamCoordinator:
|
||||
|
||||
# ============= Internal Stream Management Methods =============
|
||||
|
||||
def _find_child_streams(self, parent_selector: Sequence[str]) -> list[tuple[str, ...]]:
|
||||
"""Find all child stream selectors that are descendants of the parent selector.
|
||||
|
||||
For example, if parent_selector is ['llm', 'generation'], this will find:
|
||||
- ['llm', 'generation', 'content']
|
||||
- ['llm', 'generation', 'tool_calls']
|
||||
- ['llm', 'generation', 'tool_results']
|
||||
- ['llm', 'generation', 'thought']
|
||||
|
||||
Args:
|
||||
parent_selector: The parent selector to search for children
|
||||
|
||||
Returns:
|
||||
List of child selector tuples found in stream buffers or closed streams
|
||||
"""
|
||||
parent_key = tuple(parent_selector)
|
||||
parent_len = len(parent_key)
|
||||
child_streams: set[tuple[str, ...]] = set()
|
||||
|
||||
# Search in both active buffers and closed streams
|
||||
all_selectors = set(self._stream_buffers.keys()) | self._closed_streams
|
||||
|
||||
for selector_key in all_selectors:
|
||||
# Check if this selector is a direct child of the parent
|
||||
# Direct child means: len(child) == len(parent) + 1 and child starts with parent
|
||||
if len(selector_key) == parent_len + 1 and selector_key[:parent_len] == parent_key:
|
||||
child_streams.add(selector_key)
|
||||
|
||||
return sorted(child_streams)
|
||||
|
||||
def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
Append a stream chunk to the internal buffer.
|
||||
|
||||
@ -36,6 +36,7 @@ from .loop import (
|
||||
|
||||
# Node events
|
||||
from .node import (
|
||||
ChunkType,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
@ -44,10 +45,13 @@ from .node import (
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ToolCall,
|
||||
ToolResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseGraphEvent",
|
||||
"ChunkType",
|
||||
"GraphEngineEvent",
|
||||
"GraphNodeEventBase",
|
||||
"GraphRunAbortedEvent",
|
||||
@ -73,4 +77,6 @@ __all__ = [
|
||||
"NodeRunStartedEvent",
|
||||
"NodeRunStreamChunkEvent",
|
||||
"NodeRunSucceededEvent",
|
||||
"ToolCall",
|
||||
"ToolResult",
|
||||
]
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
|
||||
from .base import GraphNodeEventBase
|
||||
@ -21,13 +22,39 @@ class NodeRunStartedEvent(GraphNodeEventBase):
|
||||
provider_id: str = ""
|
||||
|
||||
|
||||
class ChunkType(StrEnum):
|
||||
"""Stream chunk type for LLM-related events."""
|
||||
|
||||
TEXT = "text" # Normal text streaming
|
||||
TOOL_CALL = "tool_call" # Tool call arguments streaming
|
||||
TOOL_RESULT = "tool_result" # Tool execution result
|
||||
THOUGHT = "thought" # Agent thinking process (ReAct)
|
||||
THOUGHT_START = "thought_start" # Agent thought start
|
||||
THOUGHT_END = "thought_end" # Agent thought end
|
||||
|
||||
|
||||
class NodeRunStreamChunkEvent(GraphNodeEventBase):
|
||||
# Spec-compliant fields
|
||||
"""Stream chunk event for workflow node execution."""
|
||||
|
||||
# Base fields
|
||||
selector: Sequence[str] = Field(
|
||||
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
|
||||
)
|
||||
chunk: str = Field(..., description="the actual chunk content")
|
||||
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
|
||||
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
|
||||
|
||||
# Tool call fields (when chunk_type == TOOL_CALL)
|
||||
tool_call: ToolCall | None = Field(
|
||||
default=None,
|
||||
description="structured payload for tool_call chunks",
|
||||
)
|
||||
|
||||
# Tool result fields (when chunk_type == TOOL_RESULT)
|
||||
tool_result: ToolResult | None = Field(
|
||||
default=None,
|
||||
description="structured payload for tool_result chunks",
|
||||
)
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
|
||||
|
||||
@ -13,16 +13,21 @@ from .loop import (
|
||||
LoopSucceededEvent,
|
||||
)
|
||||
from .node import (
|
||||
ChunkType,
|
||||
ModelInvokeCompletedEvent,
|
||||
PauseRequestedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
RunRetryEvent,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
ThoughtChunkEvent,
|
||||
ToolCallChunkEvent,
|
||||
ToolResultChunkEvent,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentLogEvent",
|
||||
"ChunkType",
|
||||
"IterationFailedEvent",
|
||||
"IterationNextEvent",
|
||||
"IterationStartedEvent",
|
||||
@ -39,4 +44,7 @@ __all__ = [
|
||||
"RunRetryEvent",
|
||||
"StreamChunkEvent",
|
||||
"StreamCompletedEvent",
|
||||
"ThoughtChunkEvent",
|
||||
"ToolCallChunkEvent",
|
||||
"ToolResultChunkEvent",
|
||||
]
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.file import File
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import ToolCall, ToolResult
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
@ -32,13 +34,60 @@ class RunRetryEvent(NodeEventBase):
|
||||
start_at: datetime = Field(..., description="Retry start time")
|
||||
|
||||
|
||||
class ChunkType(StrEnum):
|
||||
"""Stream chunk type for LLM-related events."""
|
||||
|
||||
TEXT = "text" # Normal text streaming
|
||||
TOOL_CALL = "tool_call" # Tool call arguments streaming
|
||||
TOOL_RESULT = "tool_result" # Tool execution result
|
||||
THOUGHT = "thought" # Agent thinking process (ReAct)
|
||||
THOUGHT_START = "thought_start" # Agent thought start
|
||||
THOUGHT_END = "thought_end" # Agent thought end
|
||||
|
||||
|
||||
class StreamChunkEvent(NodeEventBase):
|
||||
# Spec-compliant fields
|
||||
"""Base stream chunk event - normal text streaming output."""
|
||||
|
||||
selector: Sequence[str] = Field(
|
||||
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
|
||||
)
|
||||
chunk: str = Field(..., description="the actual chunk content")
|
||||
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
|
||||
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
|
||||
tool_call: ToolCall | None = Field(default=None, description="structured payload for tool_call chunks")
|
||||
tool_result: ToolResult | None = Field(default=None, description="structured payload for tool_result chunks")
|
||||
|
||||
|
||||
class ToolCallChunkEvent(StreamChunkEvent):
|
||||
"""Tool call streaming event - tool call arguments streaming output."""
|
||||
|
||||
chunk_type: ChunkType = Field(default=ChunkType.TOOL_CALL, frozen=True)
|
||||
tool_call: ToolCall | None = Field(default=None, description="structured tool call payload")
|
||||
|
||||
|
||||
class ToolResultChunkEvent(StreamChunkEvent):
|
||||
"""Tool result event - tool execution result."""
|
||||
|
||||
chunk_type: ChunkType = Field(default=ChunkType.TOOL_RESULT, frozen=True)
|
||||
tool_result: ToolResult | None = Field(default=None, description="structured tool result payload")
|
||||
|
||||
|
||||
class ThoughtStartChunkEvent(StreamChunkEvent):
|
||||
"""Agent thought start streaming event - Agent thinking process (ReAct)."""
|
||||
|
||||
chunk_type: ChunkType = Field(default=ChunkType.THOUGHT_START, frozen=True)
|
||||
|
||||
|
||||
class ThoughtEndChunkEvent(StreamChunkEvent):
|
||||
"""Agent thought end streaming event - Agent thinking process (ReAct)."""
|
||||
|
||||
chunk_type: ChunkType = Field(default=ChunkType.THOUGHT_END, frozen=True)
|
||||
|
||||
|
||||
class ThoughtChunkEvent(StreamChunkEvent):
|
||||
"""Agent thought streaming event - Agent thinking process (ReAct)."""
|
||||
|
||||
chunk_type: ChunkType = Field(default=ChunkType.THOUGHT, frozen=True)
|
||||
|
||||
|
||||
class StreamCompletedEvent(NodeEventBase):
|
||||
|
||||
@ -48,6 +48,9 @@ from core.workflow.node_events import (
|
||||
RunRetrieverResourceEvent,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
ThoughtChunkEvent,
|
||||
ToolCallChunkEvent,
|
||||
ToolResultChunkEvent,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -564,6 +567,8 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
from core.workflow.graph_events import ChunkType
|
||||
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
@ -571,6 +576,60 @@ class Node(Generic[NodeDataT]):
|
||||
selector=event.selector,
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
chunk_type=ChunkType(event.chunk_type.value),
|
||||
tool_call=event.tool_call,
|
||||
tool_result=event.tool_result,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: ToolCallChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
from core.workflow.graph_events import ChunkType
|
||||
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
selector=event.selector,
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
chunk_type=ChunkType.TOOL_CALL,
|
||||
tool_call=event.tool_call,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: ToolResultChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
from core.workflow.entities import ToolResult, ToolResultStatus
|
||||
from core.workflow.graph_events import ChunkType
|
||||
|
||||
tool_result = event.tool_result or ToolResult()
|
||||
status: ToolResultStatus = tool_result.status or ToolResultStatus.SUCCESS
|
||||
tool_result = tool_result.model_copy(
|
||||
update={"status": status, "files": tool_result.files or []},
|
||||
)
|
||||
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
selector=event.selector,
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
chunk_type=ChunkType.TOOL_RESULT,
|
||||
tool_result=tool_result,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: ThoughtChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
from core.workflow.graph_events import ChunkType
|
||||
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
selector=event.selector,
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
chunk_type=ChunkType.THOUGHT,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
|
||||
@ -62,6 +62,21 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
inputs = {"variable_selector": variable_selector}
|
||||
process_data = {"documents": value if isinstance(value, list) else [value]}
|
||||
|
||||
# Ensure storage_key is loaded for File objects
|
||||
files_to_check = value if isinstance(value, list) else [value]
|
||||
files_needing_storage_key = [
|
||||
f for f in files_to_check
|
||||
if isinstance(f, File) and not f.storage_key and f.related_id
|
||||
]
|
||||
if files_needing_storage_key:
|
||||
from factories.file_factory import StorageKeyLoader
|
||||
from extensions.ext_database import db
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
with Session(bind=db.engine) as session:
|
||||
storage_key_loader = StorageKeyLoader(session, tenant_id=self.tenant_id)
|
||||
storage_key_loader.load_storage_keys(files_needing_storage_key)
|
||||
|
||||
try:
|
||||
if isinstance(value, list):
|
||||
extracted_text_list = list(map(_extract_text_from_file, value))
|
||||
@ -415,6 +430,15 @@ def _download_file_content(file: File) -> bytes:
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
else:
|
||||
# Check if storage_key is set
|
||||
if not file.storage_key:
|
||||
raise FileDownloadError(f"File storage_key is missing for file: {file.filename}")
|
||||
|
||||
# Check if file exists before downloading
|
||||
from extensions.ext_storage import storage
|
||||
if not storage.exists(file.storage_key):
|
||||
raise FileDownloadError(f"File not found in storage: {file.storage_key}")
|
||||
|
||||
return file_manager.download(file)
|
||||
except Exception as e:
|
||||
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
|
||||
|
||||
@ -158,3 +158,5 @@ class KnowledgeIndexNodeData(BaseNodeData):
|
||||
type: str = "knowledge-index"
|
||||
chunk_structure: str
|
||||
index_chunk_variable_selector: list[str]
|
||||
indexing_technique: str | None = None
|
||||
summary_index_setting: dict | None = None
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
import concurrent.futures
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -16,7 +18,9 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
|
||||
from .entities import KnowledgeIndexNodeData
|
||||
from .exc import (
|
||||
@ -67,7 +71,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
# index knowledge
|
||||
try:
|
||||
if is_preview:
|
||||
outputs = self._get_preview_output(node_data.chunk_structure, chunks)
|
||||
# Preview mode: generate summaries for chunks directly without saving to database
|
||||
# Format preview and generate summaries on-the-fly
|
||||
# Get indexing_technique and summary_index_setting from node_data (workflow graph config)
|
||||
# or fallback to dataset if not available in node_data
|
||||
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
|
||||
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
|
||||
|
||||
outputs = self._get_preview_output_with_summaries(
|
||||
node_data.chunk_structure, chunks, dataset=dataset,
|
||||
indexing_technique=indexing_technique,
|
||||
summary_index_setting=summary_index_setting
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
@ -163,6 +178,9 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# Generate summary index if enabled
|
||||
self._handle_summary_index_generation(dataset, document, variable_pool)
|
||||
|
||||
return {
|
||||
"dataset_id": ds_id_value,
|
||||
"dataset_name": dataset_name_value,
|
||||
@ -173,9 +191,269 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
"display_status": "completed",
|
||||
}
|
||||
|
||||
def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:
|
||||
def _handle_summary_index_generation(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
document: Document,
|
||||
variable_pool: VariablePool,
|
||||
) -> None:
|
||||
"""
|
||||
Handle summary index generation based on mode (debug/preview or production).
|
||||
|
||||
Args:
|
||||
dataset: Dataset containing the document
|
||||
document: Document to generate summaries for
|
||||
variable_pool: Variable pool to check invoke_from
|
||||
"""
|
||||
# Only generate summary index for high_quality indexing technique
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
return
|
||||
|
||||
# Check if summary index is enabled
|
||||
summary_index_setting = dataset.summary_index_setting
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
return
|
||||
|
||||
# Skip qa_model documents
|
||||
if document.doc_form == "qa_model":
|
||||
return
|
||||
|
||||
# Determine if in preview/debug mode
|
||||
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
|
||||
is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER
|
||||
|
||||
# Determine if only parent chunks should be processed
|
||||
only_parent_chunks = dataset.chunk_structure == "parent_child_index"
|
||||
|
||||
if is_preview:
|
||||
try:
|
||||
# Query segments that need summary generation
|
||||
query = db.session.query(DocumentSegment).filter_by(
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
status="completed",
|
||||
enabled=True,
|
||||
)
|
||||
segments = query.all()
|
||||
|
||||
if not segments:
|
||||
logger.info(f"No segments found for document {document.id}")
|
||||
return
|
||||
|
||||
# Filter segments based on mode
|
||||
segments_to_process = []
|
||||
for segment in segments:
|
||||
# Skip if summary already exists
|
||||
existing_summary = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed")
|
||||
.first()
|
||||
)
|
||||
if existing_summary:
|
||||
continue
|
||||
|
||||
# For parent-child mode, all segments are parent chunks, so process all
|
||||
segments_to_process.append(segment)
|
||||
|
||||
if not segments_to_process:
|
||||
logger.info(f"No segments need summary generation for document {document.id}")
|
||||
return
|
||||
|
||||
# Use ThreadPoolExecutor for concurrent generation
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
max_workers = min(10, len(segments_to_process)) # Limit to 10 workers
|
||||
|
||||
def process_segment(segment: DocumentSegment) -> None:
|
||||
"""Process a single segment in a thread with Flask app context."""
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
SummaryIndexService.generate_and_vectorize_summary(
|
||||
segment, dataset, summary_index_setting
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate summary for segment {segment.id}: {str(e)}")
|
||||
# Continue processing other segments
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = [
|
||||
executor.submit(process_segment, segment) for segment in segments_to_process
|
||||
]
|
||||
# Wait for all tasks to complete
|
||||
concurrent.futures.wait(futures, timeout=300)
|
||||
|
||||
logger.info(
|
||||
f"Successfully generated summary index for {len(segments_to_process)} segments "
|
||||
f"in document {document.id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to generate summary index for document {document.id}: {str(e)}")
|
||||
# Don't fail the entire indexing process if summary generation fails
|
||||
else:
|
||||
# Production mode: asynchronous generation
|
||||
logger.info(f"Queuing summary index generation task for document {document.id} (production mode)")
|
||||
try:
|
||||
generate_summary_index_task.delay(dataset.id, document.id, None)
|
||||
logger.info(f"Summary index generation task queued for document {document.id}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to queue summary index generation task for document {document.id}: {str(e)}")
|
||||
# Don't fail the entire indexing process if task queuing fails
|
||||
|
||||
def _get_preview_output_with_summaries(
|
||||
self, chunk_structure: str, chunks: Any, dataset: Dataset,
|
||||
indexing_technique: str | None = None,
|
||||
summary_index_setting: dict | None = None
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Generate preview output with summaries for chunks in preview mode.
|
||||
This method generates summaries on-the-fly without saving to database.
|
||||
|
||||
Args:
|
||||
chunk_structure: Chunk structure type
|
||||
chunks: Chunks to generate preview for
|
||||
dataset: Dataset object (for tenant_id)
|
||||
indexing_technique: Indexing technique from node config or dataset
|
||||
summary_index_setting: Summary index setting from node config or dataset
|
||||
"""
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
return index_processor.format_preview(chunks)
|
||||
preview_output = index_processor.format_preview(chunks)
|
||||
|
||||
# Check if summary index is enabled
|
||||
if indexing_technique != "high_quality":
|
||||
return preview_output
|
||||
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
return preview_output
|
||||
|
||||
# Generate summaries for chunks
|
||||
if "preview" in preview_output and isinstance(preview_output["preview"], list):
|
||||
chunk_count = len(preview_output["preview"])
|
||||
logger.info(
|
||||
f"Generating summaries for {chunk_count} chunks in preview mode "
|
||||
f"(dataset: {dataset.id})"
|
||||
)
|
||||
# Use ParagraphIndexProcessor's generate_summary method
|
||||
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
|
||||
# Get Flask app for application context in worker threads
|
||||
flask_app = None
|
||||
try:
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
except RuntimeError:
|
||||
logger.warning("No Flask application context available, summary generation may fail")
|
||||
|
||||
def generate_summary_for_chunk(preview_item: dict) -> None:
|
||||
"""Generate summary for a single chunk."""
|
||||
if "content" in preview_item:
|
||||
try:
|
||||
# Set Flask application context in worker thread
|
||||
if flask_app:
|
||||
with flask_app.app_context():
|
||||
summary = ParagraphIndexProcessor.generate_summary(
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=preview_item["content"],
|
||||
summary_index_setting=summary_index_setting,
|
||||
)
|
||||
if summary:
|
||||
preview_item["summary"] = summary
|
||||
else:
|
||||
# Fallback: try without app context (may fail)
|
||||
summary = ParagraphIndexProcessor.generate_summary(
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=preview_item["content"],
|
||||
summary_index_setting=summary_index_setting,
|
||||
)
|
||||
if summary:
|
||||
preview_item["summary"] = summary
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate summary for chunk: {str(e)}")
|
||||
# Don't fail the entire preview if summary generation fails
|
||||
|
||||
# Generate summaries concurrently using ThreadPoolExecutor
|
||||
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
|
||||
timeout_seconds = min(300, 60 * len(preview_output["preview"]))
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor:
|
||||
futures = [
|
||||
executor.submit(generate_summary_for_chunk, preview_item)
|
||||
for preview_item in preview_output["preview"]
|
||||
]
|
||||
# Wait for all tasks to complete with timeout
|
||||
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
|
||||
|
||||
# Cancel tasks that didn't complete in time
|
||||
if not_done:
|
||||
logger.warning(
|
||||
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s. "
|
||||
"Cancelling remaining tasks..."
|
||||
)
|
||||
for future in not_done:
|
||||
future.cancel()
|
||||
# Wait a bit for cancellation to take effect
|
||||
concurrent.futures.wait(not_done, timeout=5)
|
||||
|
||||
completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None)
|
||||
logger.info(
|
||||
f"Completed summary generation for preview chunks: {completed_count}/{len(preview_output['preview'])} succeeded"
|
||||
)
|
||||
|
||||
return preview_output
|
||||
|
||||
def _get_preview_output(
|
||||
self, chunk_structure: str, chunks: Any, dataset: Dataset | None = None, variable_pool: VariablePool | None = None
|
||||
) -> Mapping[str, Any]:
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
preview_output = index_processor.format_preview(chunks)
|
||||
|
||||
# If dataset is provided, try to enrich preview with summaries
|
||||
if dataset and variable_pool:
|
||||
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||
if document_id:
|
||||
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||
if document:
|
||||
# Query summaries for this document
|
||||
summaries = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.filter_by(
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
status="completed",
|
||||
enabled=True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if summaries:
|
||||
# Create a map of segment content to summary for matching
|
||||
# Use content matching as chunks in preview might not be indexed yet
|
||||
summary_by_content = {}
|
||||
for summary in summaries:
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter_by(id=summary.chunk_id, dataset_id=dataset.id)
|
||||
.first()
|
||||
)
|
||||
if segment:
|
||||
# Normalize content for matching (strip whitespace)
|
||||
normalized_content = segment.content.strip()
|
||||
summary_by_content[normalized_content] = summary.summary_content
|
||||
|
||||
# Enrich preview with summaries by content matching
|
||||
if "preview" in preview_output and isinstance(preview_output["preview"], list):
|
||||
matched_count = 0
|
||||
for preview_item in preview_output["preview"]:
|
||||
if "content" in preview_item:
|
||||
# Normalize content for matching
|
||||
normalized_chunk_content = preview_item["content"].strip()
|
||||
if normalized_chunk_content in summary_by_content:
|
||||
preview_item["summary"] = summary_by_content[normalized_chunk_content]
|
||||
matched_count += 1
|
||||
|
||||
if matched_count > 0:
|
||||
logger.info(
|
||||
f"Enriched preview with {matched_count} existing summaries "
|
||||
f"(dataset: {dataset.id}, document: {document.id})"
|
||||
)
|
||||
|
||||
return preview_output
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
|
||||
@ -268,7 +268,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
usage = self._merge_usage(usage, metadata_usage)
|
||||
all_documents = []
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
creator_user_role = self.user_from.to_creator_user_role()
|
||||
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
|
||||
# fetch model config
|
||||
if node_data.single_retrieval_config is None:
|
||||
@ -293,7 +292,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
creator_user_role=creator_user_role,
|
||||
user_from=self.user_from.value,
|
||||
query=query,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
@ -335,7 +334,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
creator_user_role=creator_user_role,
|
||||
user_from=self.user_from.value,
|
||||
available_datasets=available_datasets,
|
||||
query=query,
|
||||
top_k=node_data.multiple_retrieval_config.top_k,
|
||||
|
||||
@ -3,6 +3,7 @@ from .entities import (
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
ToolMetadata,
|
||||
VisionConfig,
|
||||
)
|
||||
from .node import LLMNode
|
||||
@ -13,5 +14,6 @@ __all__ = [
|
||||
"LLMNodeCompletionModelPromptTemplate",
|
||||
"LLMNodeData",
|
||||
"ModelConfig",
|
||||
"ToolMetadata",
|
||||
"VisionConfig",
|
||||
]
|
||||
|
||||
@ -1,10 +1,17 @@
|
||||
import re
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.entities import ToolCall, ToolCallResult
|
||||
from core.workflow.node_events import AgentLogEvent
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
@ -58,6 +65,268 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||
jinja2_text: str | None = None
|
||||
|
||||
|
||||
class ToolMetadata(BaseModel):
|
||||
"""
|
||||
Tool metadata for LLM node with tool support.
|
||||
|
||||
Defines the essential fields needed for tool configuration,
|
||||
particularly the 'type' field to identify tool provider type.
|
||||
"""
|
||||
|
||||
# Core fields
|
||||
enabled: bool = True
|
||||
type: ToolProviderType = Field(..., description="Tool provider type: builtin, api, mcp, workflow")
|
||||
provider_name: str = Field(..., description="Tool provider name/identifier")
|
||||
tool_name: str = Field(..., description="Tool name")
|
||||
|
||||
# Optional fields
|
||||
plugin_unique_identifier: str | None = Field(None, description="Plugin unique identifier for plugin tools")
|
||||
credential_id: str | None = Field(None, description="Credential ID for tools requiring authentication")
|
||||
|
||||
# Configuration fields
|
||||
parameters: dict[str, Any] = Field(default_factory=dict, description="Tool parameters")
|
||||
settings: dict[str, Any] = Field(default_factory=dict, description="Tool settings configuration")
|
||||
extra: dict[str, Any] = Field(default_factory=dict, description="Extra tool configuration like custom description")
|
||||
|
||||
|
||||
class ModelTraceSegment(BaseModel):
|
||||
"""Model invocation trace segment with token usage and output."""
|
||||
|
||||
text: str | None = Field(None, description="Model output text content")
|
||||
reasoning: str | None = Field(None, description="Reasoning/thought content from model")
|
||||
tool_calls: list[ToolCall] = Field(default_factory=list, description="Tool calls made by the model")
|
||||
|
||||
|
||||
class ToolTraceSegment(BaseModel):
|
||||
"""Tool invocation trace segment with call details and result."""
|
||||
|
||||
id: str | None = Field(default=None, description="Unique identifier for this tool call")
|
||||
name: str | None = Field(default=None, description="Name of the tool being called")
|
||||
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
|
||||
output: str | None = Field(default=None, description="Tool call result")
|
||||
|
||||
|
||||
class LLMTraceSegment(BaseModel):
|
||||
"""
|
||||
Streaming trace segment for LLM tool-enabled runs.
|
||||
|
||||
Represents alternating model and tool invocations in sequence:
|
||||
model -> tool -> model -> tool -> ...
|
||||
|
||||
Each segment records its execution duration.
|
||||
"""
|
||||
|
||||
type: Literal["model", "tool"]
|
||||
duration: float = Field(..., description="Execution duration in seconds")
|
||||
usage: LLMUsage | None = Field(default=None, description="Token usage statistics for this model call")
|
||||
output: ModelTraceSegment | ToolTraceSegment = Field(..., description="Output of the segment")
|
||||
|
||||
# Common metadata for both model and tool segments
|
||||
provider: str | None = Field(default=None, description="Model or tool provider identifier")
|
||||
name: str | None = Field(default=None, description="Name of the model or tool")
|
||||
icon: str | None = Field(default=None, description="Icon for the provider")
|
||||
icon_dark: str | None = Field(default=None, description="Dark theme icon for the provider")
|
||||
error: str | None = Field(default=None, description="Error message if segment failed")
|
||||
status: Literal["success", "error"] | None = Field(default=None, description="Tool execution status")
|
||||
|
||||
|
||||
class LLMGenerationData(BaseModel):
|
||||
"""Generation data from LLM invocation with tools.
|
||||
|
||||
For multi-turn tool calls like: thought1 -> text1 -> tool_call1 -> thought2 -> text2 -> tool_call2
|
||||
- reasoning_contents: [thought1, thought2, ...] - one element per turn
|
||||
- tool_calls: [{id, name, arguments, result}, ...] - all tool calls with results
|
||||
"""
|
||||
|
||||
text: str = Field(..., description="Accumulated text content from all turns")
|
||||
reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn")
|
||||
tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results")
|
||||
sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering")
|
||||
usage: LLMUsage = Field(..., description="LLM usage statistics")
|
||||
finish_reason: str | None = Field(None, description="Finish reason from LLM")
|
||||
files: list[File] = Field(default_factory=list, description="Generated files")
|
||||
trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")
|
||||
|
||||
|
||||
class ThinkTagStreamParser:
|
||||
"""Lightweight state machine to split streaming chunks by <think> tags."""
|
||||
|
||||
_START_PATTERN = re.compile(r"<think(?:\s[^>]*)?>", re.IGNORECASE)
|
||||
_END_PATTERN = re.compile(r"</think>", re.IGNORECASE)
|
||||
_START_PREFIX = "<think"
|
||||
_END_PREFIX = "</think"
|
||||
|
||||
def __init__(self):
|
||||
self._buffer = ""
|
||||
self._in_think = False
|
||||
|
||||
@staticmethod
|
||||
def _suffix_prefix_len(text: str, prefix: str) -> int:
|
||||
"""Return length of the longest suffix of `text` that is a prefix of `prefix`."""
|
||||
max_len = min(len(text), len(prefix) - 1)
|
||||
for i in range(max_len, 0, -1):
|
||||
if text[-i:].lower() == prefix[:i].lower():
|
||||
return i
|
||||
return 0
|
||||
|
||||
def process(self, chunk: str) -> list[tuple[str, str]]:
|
||||
"""
|
||||
Split incoming chunk into ('thought' | 'text', content) tuples.
|
||||
Content excludes the <think> tags themselves and handles split tags across chunks.
|
||||
"""
|
||||
parts: list[tuple[str, str]] = []
|
||||
self._buffer += chunk
|
||||
|
||||
while self._buffer:
|
||||
if self._in_think:
|
||||
end_match = self._END_PATTERN.search(self._buffer)
|
||||
if end_match:
|
||||
thought_text = self._buffer[: end_match.start()]
|
||||
if thought_text:
|
||||
parts.append(("thought", thought_text))
|
||||
parts.append(("thought_end", ""))
|
||||
self._buffer = self._buffer[end_match.end() :]
|
||||
self._in_think = False
|
||||
continue
|
||||
|
||||
hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX)
|
||||
emit = self._buffer[: len(self._buffer) - hold_len]
|
||||
if emit:
|
||||
parts.append(("thought", emit))
|
||||
self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
|
||||
break
|
||||
|
||||
start_match = self._START_PATTERN.search(self._buffer)
|
||||
if start_match:
|
||||
prefix = self._buffer[: start_match.start()]
|
||||
if prefix:
|
||||
parts.append(("text", prefix))
|
||||
self._buffer = self._buffer[start_match.end() :]
|
||||
parts.append(("thought_start", ""))
|
||||
self._in_think = True
|
||||
continue
|
||||
|
||||
hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX)
|
||||
emit = self._buffer[: len(self._buffer) - hold_len]
|
||||
if emit:
|
||||
parts.append(("text", emit))
|
||||
self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
|
||||
break
|
||||
|
||||
cleaned_parts: list[tuple[str, str]] = []
|
||||
for kind, content in parts:
|
||||
# Extra safeguard: strip any stray tags that slipped through.
|
||||
content = self._START_PATTERN.sub("", content)
|
||||
content = self._END_PATTERN.sub("", content)
|
||||
if content or kind in {"thought_start", "thought_end"}:
|
||||
cleaned_parts.append((kind, content))
|
||||
|
||||
return cleaned_parts
|
||||
|
||||
def flush(self) -> list[tuple[str, str]]:
|
||||
"""Flush remaining buffer when the stream ends."""
|
||||
if not self._buffer:
|
||||
return []
|
||||
kind = "thought" if self._in_think else "text"
|
||||
content = self._buffer
|
||||
# Drop dangling partial tags instead of emitting them
|
||||
if content.lower().startswith(self._START_PREFIX) or content.lower().startswith(self._END_PREFIX):
|
||||
content = ""
|
||||
self._buffer = ""
|
||||
if not content and not self._in_think:
|
||||
return []
|
||||
# Strip any complete tags that might still be present.
|
||||
content = self._START_PATTERN.sub("", content)
|
||||
content = self._END_PATTERN.sub("", content)
|
||||
|
||||
result: list[tuple[str, str]] = []
|
||||
if content:
|
||||
result.append((kind, content))
|
||||
if self._in_think:
|
||||
result.append(("thought_end", ""))
|
||||
self._in_think = False
|
||||
return result
|
||||
|
||||
|
||||
class StreamBuffers(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
think_parser: ThinkTagStreamParser = Field(default_factory=ThinkTagStreamParser)
|
||||
pending_thought: list[str] = Field(default_factory=list)
|
||||
pending_content: list[str] = Field(default_factory=list)
|
||||
pending_tool_calls: list[ToolCall] = Field(default_factory=list)
|
||||
current_turn_reasoning: list[str] = Field(default_factory=list)
|
||||
reasoning_per_turn: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TraceState(BaseModel):
|
||||
trace_segments: list[LLMTraceSegment] = Field(default_factory=list)
|
||||
tool_trace_map: dict[str, LLMTraceSegment] = Field(default_factory=dict)
|
||||
tool_call_index_map: dict[str, int] = Field(default_factory=dict)
|
||||
model_segment_start_time: float | None = Field(default=None, description="Start time for current model segment")
|
||||
pending_usage: LLMUsage | None = Field(default=None, description="Pending usage for current model segment")
|
||||
|
||||
|
||||
class AggregatedResult(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
text: str = ""
|
||||
files: list[File] = Field(default_factory=list)
|
||||
usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
|
||||
finish_reason: str | None = None
|
||||
|
||||
|
||||
class AgentContext(BaseModel):
|
||||
agent_logs: list[AgentLogEvent] = Field(default_factory=list)
|
||||
agent_result: AgentResult | None = None
|
||||
|
||||
|
||||
class ToolOutputState(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
stream: StreamBuffers = Field(default_factory=StreamBuffers)
|
||||
trace: TraceState = Field(default_factory=TraceState)
|
||||
aggregate: AggregatedResult = Field(default_factory=AggregatedResult)
|
||||
agent: AgentContext = Field(default_factory=AgentContext)
|
||||
|
||||
|
||||
class ToolLogPayload(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
tool_name: str = ""
|
||||
tool_call_id: str = ""
|
||||
tool_args: dict[str, Any] = Field(default_factory=dict)
|
||||
tool_output: Any = None
|
||||
tool_error: Any = None
|
||||
files: list[Any] = Field(default_factory=list)
|
||||
meta: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_log(cls, log: AgentLog) -> "ToolLogPayload":
|
||||
data = log.data or {}
|
||||
return cls(
|
||||
tool_name=data.get("tool_name", ""),
|
||||
tool_call_id=data.get("tool_call_id", ""),
|
||||
tool_args=data.get("tool_args") or {},
|
||||
tool_output=data.get("output"),
|
||||
tool_error=data.get("error"),
|
||||
files=data.get("files") or [],
|
||||
meta=data.get("meta") or {},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_mapping(cls, data: Mapping[str, Any]) -> "ToolLogPayload":
|
||||
return cls(
|
||||
tool_name=data.get("tool_name", ""),
|
||||
tool_call_id=data.get("tool_call_id", ""),
|
||||
tool_args=data.get("tool_args") or {},
|
||||
tool_output=data.get("output"),
|
||||
tool_error=data.get("error"),
|
||||
files=data.get("files") or [],
|
||||
meta=data.get("meta") or {},
|
||||
)
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
model: ModelConfig
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
@ -86,6 +355,10 @@ class LLMNodeData(BaseNodeData):
|
||||
),
|
||||
)
|
||||
|
||||
# Tool support
|
||||
tools: Sequence[ToolMetadata] = Field(default_factory=list)
|
||||
max_iterations: int | None = Field(default=None, description="Maximum number of iterations for the LLM node")
|
||||
|
||||
@field_validator("prompt_config", mode="before")
|
||||
@classmethod
|
||||
def convert_none_prompt_config(cls, v: Any):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,3 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from jsonschema import Draft7Validator, ValidationError
|
||||
@ -43,25 +42,22 @@ class StartNode(Node[StartNodeData]):
|
||||
if value is None and variable.required:
|
||||
raise ValueError(f"{key} is required in input form")
|
||||
|
||||
# If no value provided, skip further processing for this key
|
||||
if not value:
|
||||
continue
|
||||
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(f"JSON object for '{key}' must be an object")
|
||||
|
||||
# Overwrite with normalized dict to ensure downstream consistency
|
||||
node_inputs[key] = value
|
||||
|
||||
# If schema exists, then validate against it
|
||||
schema = variable.json_schema
|
||||
if not schema:
|
||||
continue
|
||||
|
||||
if not value:
|
||||
continue
|
||||
|
||||
try:
|
||||
json_schema = json.loads(schema)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"{schema} must be a valid JSON object")
|
||||
|
||||
try:
|
||||
json_value = json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"{value} must be a valid JSON object")
|
||||
|
||||
try:
|
||||
Draft7Validator(json_schema).validate(json_value)
|
||||
Draft7Validator(schema).validate(value)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
|
||||
node_inputs[key] = json_value
|
||||
|
||||
@ -33,6 +33,15 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
|
||||
"""
|
||||
Check if this Variable Assigner node blocks the output of specific variables.
|
||||
|
||||
Returns True if this node updates any of the requested conversation variables.
|
||||
"""
|
||||
assigned_selector = tuple(self.node_data.assigned_variable_selector)
|
||||
return assigned_selector in variable_selectors
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
@ -19,6 +19,7 @@ from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
@ -136,13 +137,11 @@ class WorkflowEntry:
|
||||
:param user_inputs: user inputs
|
||||
:return:
|
||||
"""
|
||||
node_config = workflow.get_node_config_by_id(node_id)
|
||||
node_config = dict(workflow.get_node_config_by_id(node_id))
|
||||
node_config_data = node_config.get("data", {})
|
||||
|
||||
# Get node class
|
||||
# Get node type
|
||||
node_type = NodeType(node_config_data.get("type"))
|
||||
node_version = node_config_data.get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
@ -158,12 +157,12 @@ class WorkflowEntry:
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init workflow run state
|
||||
node = node_cls(
|
||||
id=str(uuid.uuid4()),
|
||||
config=node_config,
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node = node_factory.create_node(node_config)
|
||||
node_cls = type(node)
|
||||
|
||||
try:
|
||||
# variable selector to variable mapping
|
||||
|
||||
@ -102,6 +102,8 @@ def init_app(app: DifyApp) -> Celery:
|
||||
imports = [
|
||||
"tasks.async_workflow_tasks", # trigger workers
|
||||
"tasks.trigger_processing_tasks", # async trigger processing
|
||||
"tasks.generate_summary_index_task", # summary index generation
|
||||
"tasks.regenerate_summary_index_task", # summary index regeneration
|
||||
]
|
||||
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
|
||||
|
||||
@ -163,6 +165,13 @@ def init_app(app: DifyApp) -> Celery:
|
||||
"task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise",
|
||||
"schedule": crontab(minute="0", hour="2"),
|
||||
}
|
||||
if dify_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK:
|
||||
# for saas only
|
||||
imports.append("schedule.clean_workflow_runs_task")
|
||||
beat_schedule["clean_workflow_runs_task"] = {
|
||||
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
|
||||
"schedule": crontab(minute="0", hour="0"),
|
||||
}
|
||||
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
|
||||
imports.append("schedule.workflow_schedule_task")
|
||||
beat_schedule["workflow_schedule_task"] = {
|
||||
|
||||
@ -4,6 +4,7 @@ from dify_app import DifyApp
|
||||
def init_app(app: DifyApp):
|
||||
from commands import (
|
||||
add_qdrant_index,
|
||||
clean_workflow_runs,
|
||||
cleanup_orphaned_draft_variables,
|
||||
clear_free_plan_tenant_expired_logs,
|
||||
clear_orphaned_file_records,
|
||||
@ -56,6 +57,7 @@ def init_app(app: DifyApp):
|
||||
setup_datasource_oauth_client,
|
||||
transform_datasource_credentials,
|
||||
install_rag_pipeline_plugins,
|
||||
clean_workflow_runs,
|
||||
]
|
||||
for cmd in cmds_to_register:
|
||||
app.cli.add_command(cmd)
|
||||
|
||||
@ -169,6 +169,7 @@ class MessageDetail(ResponseModel):
|
||||
status: str
|
||||
error: str | None = None
|
||||
parent_message_id: str | None = None
|
||||
generation_detail: JSONValue | None = Field(default=None, validation_alias="generation_detail_dict")
|
||||
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
|
||||
@ -39,6 +39,14 @@ dataset_retrieval_model_fields = {
|
||||
"score_threshold_enabled": fields.Boolean,
|
||||
"score_threshold": fields.Float,
|
||||
}
|
||||
|
||||
dataset_summary_index_fields = {
|
||||
"enable": fields.Boolean,
|
||||
"model_name": fields.String,
|
||||
"model_provider_name": fields.String,
|
||||
"summary_prompt": fields.String,
|
||||
}
|
||||
|
||||
external_retrieval_model_fields = {
|
||||
"top_k": fields.Integer,
|
||||
"score_threshold": fields.Float,
|
||||
@ -83,6 +91,7 @@ dataset_detail_fields = {
|
||||
"embedding_model_provider": fields.String,
|
||||
"embedding_available": fields.Boolean,
|
||||
"retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
|
||||
"summary_index_setting": fields.Nested(dataset_summary_index_fields),
|
||||
"tags": fields.List(fields.Nested(tag_fields)),
|
||||
"doc_form": fields.String,
|
||||
"external_knowledge_info": fields.Nested(external_knowledge_info_fields),
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user