Mirror of https://github.com/langgenius/dify.git (synced 2026-01-22 04:55:36 +08:00)

Compare commits: feat/quota ... deploy/age (497 commits)
| SHA1 | Author | Date | |
|---|---|---|---|
| aac90133d6 | |||
| 0ac847fb3c | |||
| 70f5365398 | |||
| b1eecb7051 | |||
| d82943f48c | |||
| c4249f94de | |||
| 9ed83a808a | |||
| d7ccea8ac5 | |||
| 1fcff5f8d1 | |||
| 78c7be09f8 | |||
| a37adddacd | |||
| ccbf908d22 | |||
| d444a8eadc | |||
| b5e31c0f25 | |||
| c4943ff4f5 | |||
| 699650565e | |||
| 1c90c729bc | |||
| 45a76fa90b | |||
| 911c1852d5 | |||
| e85b0c49d8 | |||
| b0a059250a | |||
| b94b7860d9 | |||
| 478833f069 | |||
| 5657bf52f0 | |||
| c3333006cf | |||
| c2885077c2 | |||
| 8e20ef6cb5 | |||
| 468d84faba | |||
| cf6c089e72 | |||
| 2f70f778c9 | |||
| 9400863949 | |||
| f831d3bbd6 | |||
| 7fd9ef3d22 | |||
| 705d4cbba9 | |||
| c9e53bf78c | |||
| 7cd280557c | |||
| 58da9c3c11 | |||
| 68d36ff3ed | |||
| 0ed5ed20b5 | |||
| 18a589003e | |||
| da6fdc963c | |||
| 1c76ed2c40 | |||
| ceb410fb5c | |||
| 4fa7843050 | |||
| 3205f98d05 | |||
| 0092254007 | |||
| ee91c9d5f1 | |||
| 2151676db1 | |||
| dc9658b003 | |||
| b527921f3f | |||
| 0e66b51ca0 | |||
| 33e96fd11a | |||
| 2e037014c3 | |||
| 8c4aaa8286 | |||
| dc8c018e28 | |||
| 57a8c453b9 | |||
| e5dc56c483 | |||
| 812df81d92 | |||
| 67c29be3c6 | |||
| cf5e8491df | |||
| 53f828f00e | |||
| 357489d444 | |||
| 331c65fd1d | |||
| 56b09d9f72 | |||
| d4ed398e4f | |||
| 951a580907 | |||
| 3b72b45319 | |||
| 2650ceb0a6 | |||
| c5fc3cc08e | |||
| fdaf471a03 | |||
| 27de07e93d | |||
| 8154d0af53 | |||
| 466f76345b | |||
| fc83e2b1c4 | |||
| 552f9a8989 | |||
| 4f5b175e55 | |||
| 13d6923c11 | |||
| 1483a51aa1 | |||
| f5a34e9ee8 | |||
| d69e7eb12a | |||
| c44aaf1883 | |||
| 4b91969d0f | |||
| 92c54d3c9d | |||
| bc9ce23fdc | |||
| cab33d440b | |||
| 267de1861d | |||
| b3793b0198 | |||
| 8486c675c8 | |||
| 5e49b27dba | |||
| b6df7b3afe | |||
| 6f74a66c8a | |||
| 31a7db2657 | |||
| 68fd7c021c | |||
| e1e64ae430 | |||
| 9080607028 | |||
| 6e9a5139b4 | |||
| f44305af0d | |||
| 8f4a4214a1 | |||
| ff210a98db | |||
| 9ad1f30a8c | |||
| 5053fae5b4 | |||
| d297167fef | |||
| 41aec357b0 | |||
| 96da3b9560 | |||
| 3bb9625ced | |||
| 1bdc47220b | |||
| 5aa4088051 | |||
| 9f444f1f6a | |||
| 49effca35d | |||
| fb28f03155 | |||
| 2afc4704ad | |||
| 5496fc014c | |||
| 7756c151ed | |||
| 83c458d2fe | |||
| 956436b943 | |||
| 3bb9c4b280 | |||
| c38463c9a9 | |||
| fc49592769 | |||
| 6643569efc | |||
| fe0ea13f70 | |||
| c979b59e1e | |||
| 144ca11c03 | |||
| a432fa5fcf | |||
| dbc70f8f05 | |||
| 4b67008dba | |||
| f4b683aa2f | |||
| 7de6ecdedf | |||
| bd070857ed | |||
| d3d1ba2488 | |||
| eae82b1085 | |||
| f9fd234cf8 | |||
| 1dfee05b7e | |||
| dd42e7706a | |||
| 066d18df7a | |||
| 06f6ded20f | |||
| 3a775fc2bf | |||
| 0d5e971a0c | |||
| 9aed4f830f | |||
| 5947e04226 | |||
| 611ff05bde | |||
| 0e890e5692 | |||
| 6584dc2480 | |||
| a922e844eb | |||
| 4bd05ed96e | |||
| 0de32f682a | |||
| 245567118c | |||
| 021f055c36 | |||
| 5f707c5585 | |||
| 232da66b53 | |||
| ebeee92e51 | |||
| f481947b0d | |||
| 94ea7031e8 | |||
| 2f081fa6fa | |||
| 3b27d9e819 | |||
| c0a76220dd | |||
| 9d04fb4992 | |||
| 02fcf33067 | |||
| bbf1247f80 | |||
| b82b73ef94 | |||
| 15d6f60f25 | |||
| ad8c5f5452 | |||
| 721d82b91a | |||
| 0c62c39a1d | |||
| 8d643e4b85 | |||
| d542a74733 | |||
| 16078a9df6 | |||
| 0bd17c6d0f | |||
| 77401e6f5c | |||
| 8b42435f7a | |||
| 3147e850be | |||
| 0b33381efb | |||
| ee7a9a34e0 | |||
| 148f92f92d | |||
| 4ee49552ce | |||
| 40caaaab23 | |||
| 1bc1c04be5 | |||
| 18abc66585 | |||
| f79df6982d | |||
| e85e31773a | |||
| e5336a2d75 | |||
| 649283df09 | |||
| 7222a896d8 | |||
| b5712bf8b0 | |||
| 06b6625c01 | |||
| 7bc2e33e83 | |||
| eb4f57fb8b | |||
| 0f5d3f38da | |||
| 76da178cc1 | |||
| 38a2d2fe68 | |||
| 9397ba5bd2 | |||
| 7093962f30 | |||
| 7022e4b9ca | |||
| b8d67a42bd | |||
| 106cb8e373 | |||
| 9492eda5ef | |||
| a7826d9ea4 | |||
| 64ddcc8960 | |||
| c7bca6a3fb | |||
| f1ce933b33 | |||
| 17990512ce | |||
| a30fb5909b | |||
| 3dea5adf5c | |||
| 5aca563a01 | |||
| bf1ebcdf8f | |||
| 3252748345 | |||
| 72eb29c01b | |||
| 0f3156dfbe | |||
| b21875eaaf | |||
| 2591615a3c | |||
| 691554ad1c | |||
| f43fde5797 | |||
| 783cdb1357 | |||
| 2de17cb1a4 | |||
| 3b6946d3da | |||
| b8adc8f498 | |||
| ca7c4d2c86 | |||
| d8bafb0d1c | |||
| cd0724b827 | |||
| 6e66e2591b | |||
| fd0556909f | |||
| ac2120da1e | |||
| f3904a7e39 | |||
| b3923ec3ca | |||
| 9ffdad6465 | |||
| f247ebfbe1 | |||
| 713e040481 | |||
| f58f36fc8f | |||
| 195cd2c898 | |||
| 6bb09dc58c | |||
| 33f3374ea6 | |||
| 41baaca21d | |||
| d650cde323 | |||
| d641c845dd | |||
| e651c6cacf | |||
| 2e10d67610 | |||
| eab395f58a | |||
| 2f92957e15 | |||
| e89d4e14ea | |||
| 5525f63032 | |||
| 7bc1390366 | |||
| e91fb94d0e | |||
| 5c03a2e251 | |||
| 1741fcf84d | |||
| 52215e9166 | |||
| 4cfc135652 | |||
| 8ee643e88d | |||
| ff632bf9b8 | |||
| ce9ed88b03 | |||
| e6a4a08120 | |||
| 388ee087c0 | |||
| 2fb8883918 | |||
| 28ccd42a1c | |||
| fcd814a2c3 | |||
| fe17cbc1a8 | |||
| 63b3e71909 | |||
| c1c8b6af44 | |||
| 3bd434ddf2 | |||
| 834a5df580 | |||
| e40c2354d5 | |||
| b0eca12d88 | |||
| 3a86983207 | |||
| f461ddeb7e | |||
| 7b534baf15 | |||
| 74d8bdd3a7 | |||
| 657739d48b | |||
| f8b27dd662 | |||
| 18c7f4698a | |||
| ccb337e8eb | |||
| 1ff677c300 | |||
| 04145b19a1 | |||
| 6cb8d03bf6 | |||
| 94ff904a04 | |||
| a0c388f283 | |||
| 56e537786f | |||
| 810f9eaaad | |||
| 31427e9c42 | |||
| 384b99435b | |||
| 425d182f21 | |||
| 4394ba1fe1 | |||
| 4828348532 | |||
| be5a4cf5e3 | |||
| d17a92f713 | |||
| 5ac2230c5d | |||
| ab531d946e | |||
| 1a8fd08563 | |||
| c6ddf89980 | |||
| 71c39ae583 | |||
| 7209ef4aa7 | |||
| 6b55e6781f | |||
| c8c048c3a3 | |||
| 4887c9ea6f | |||
| 495d575ebc | |||
| 18170a1de5 | |||
| 7ce144f493 | |||
| 2279b605c6 | |||
| 3b78f9c2a5 | |||
| 7c029ce808 | |||
| b9052bc244 | |||
| b7025ad9d6 | |||
| c5482c2503 | |||
| d394adfaf7 | |||
| bc771d9c50 | |||
| 96ec176b83 | |||
| f57d2ef31f | |||
| f28ded8455 | |||
| e80bc78780 | |||
| c6ba51127f | |||
| ddbbddbd14 | |||
| 9b961fb41e | |||
| 1db995be0d | |||
| 5675a44ffd | |||
| 48295e5161 | |||
| 4f79d09d7b | |||
| dbed937fc6 | |||
| ffc39b0235 | |||
| f72f58dbc4 | |||
| 9d0f4a2152 | |||
| 1ed4ab4299 | |||
| 969c96b070 | |||
| 3f69d348a1 | |||
| 63fff151c7 | |||
| 9920e0b89a | |||
| 3042f29c15 | |||
| 99273e1118 | |||
| 041dbd482d | |||
| b4aa1de10a | |||
| c5a9b98cbe | |||
| 21f47fbe58 | |||
| 49f115dce3 | |||
| a81d0327d2 | |||
| 9eafe982ee | |||
| a46bfdd0fc | |||
| 16f26c4f99 | |||
| 03e0c4c617 | |||
| 47790b49d4 | |||
| b25b069917 | |||
| bb190f9610 | |||
| d65ae68668 | |||
| f625350439 | |||
| f4e8f64bf7 | |||
| 42fd0a0a62 | |||
| b78439b334 | |||
| 1082d73355 | |||
| d91087492d | |||
| cab7cd37b8 | |||
| 201a18d6ba | |||
| f990f4a8d4 | |||
| aa5e37f2db | |||
| e7c89b6153 | |||
| 3e49d6b900 | |||
| 8aaff7fec1 | |||
| 51ac23c9f1 | |||
| 9dd0361d0e | |||
| 3d2840edb6 | |||
| ce0a59b60d | |||
| 2d8acf92f0 | |||
| bc2ffa39fc | |||
| 390c805ef4 | |||
| 5b753dfd6e | |||
| 5c8b80b01a | |||
| 95d62039b1 | |||
| 78acfb0040 | |||
| eb821efda7 | |||
| 925825a41b | |||
| f925266c1b | |||
| 07ff8df58d | |||
| 0a0f02c0c6 | |||
| d2f41ae9ef | |||
| 5a4f5f54a7 | |||
| eabfa8f3af | |||
| 1557f48740 | |||
| 00d787a75b | |||
| 3b454fa95a | |||
| 0da4d64d38 | |||
| 6e2cf23a73 | |||
| 8b0bc6937d | |||
| 872fd98eda | |||
| 5bcd3b6fe6 | |||
| 1aed585a19 | |||
| 831eba8b1c | |||
| b09a831d15 | |||
| 4d3d8b35d9 | |||
| c323028179 | |||
| 94dbda503f | |||
| beefff3d48 | |||
| c2e5081437 | |||
| 786c3e4137 | |||
| 0d33714f28 | |||
| 1fbba38436 | |||
| 15c3d712d3 | |||
| 5b01f544d1 | |||
| 8b8e521c4e | |||
| fe4c591cfd | |||
| 0cd613ae52 | |||
| 0082f468b4 | |||
| eec57e84e4 | |||
| 70149ea05e | |||
| 1d93f41fcf | |||
| cd0f41a3e0 | |||
| 094c9fd802 | |||
| 1584a78fc9 | |||
| 88248ad2d3 | |||
| 1a203031e0 | |||
| 05c3344554 | |||
| 888be71639 | |||
| 3902929d9f | |||
| 760a739e91 | |||
| 1c7c475c43 | |||
| cef7fd484b | |||
| caabca3f02 | |||
| d92c476388 | |||
| 36b7075cf4 | |||
| f3761c26e9 | |||
| 43daf4f82c | |||
| 932be0ad64 | |||
| 9012dced6a | |||
| 50bed78d7a | |||
| 60250355cb | |||
| 75afc2dc0e | |||
| 225b13da93 | |||
| 37c748192d | |||
| b7a2957340 | |||
| a6ce6a249b | |||
| 8834e6e531 | |||
| 04f40303fd | |||
| ececc5ec2c | |||
| 81547c5981 | |||
| a911b268aa | |||
| 39010fd153 | |||
| dc8a618b6a | |||
| f3e7fea628 | |||
| 926349b1f8 | |||
| ec29c24916 | |||
| 3842eade67 | |||
| bd338a9043 | |||
| cf7e2d5d75 | |||
| 2673fe05a5 | |||
| 180fdffab1 | |||
| 62e422f75a | |||
| 41565e91ed | |||
| c9610e9949 | |||
| 29dc083d8d | |||
| 39d6383474 | |||
| f679065d2c | |||
| 0a97e87a8e | |||
| 4d81455a83 | |||
| 39091fe4df | |||
| bac5245cd0 | |||
| 274f9a3f32 | |||
| a513ab9a59 | |||
| e83635ee5a | |||
| d79372a46d | |||
| bbd11c9e89 | |||
| 152fd52cd7 | |||
| ccabdbc83b | |||
| 56c8221b3f | |||
| add8980790 | |||
| 5157e1a96c | |||
| d132abcdb4 | |||
| d60348572e | |||
| f55faae31b | |||
| 0cff94d90e | |||
| 7fc25cafb2 | |||
| a7859de625 | |||
| 4bb76acc37 | |||
| b513933040 | |||
| 18ea9d3f18 | |||
| 7b660a9ebc | |||
| 783a49bd97 | |||
| d3c6b09354 | |||
| 3d61496d25 | |||
| 16bff9e82f | |||
| 22f25731e8 | |||
| 035f51ad58 | |||
| e9795bd772 | |||
| 93b516a4ec | |||
| fc9d5b2a62 | |||
| e3bfb95c52 | |||
| 047ea8c143 | |||
| 752cb9e4f4 | |||
| f54b9b12b0 | |||
| cb99b8f04d | |||
| 7c03bcba2b | |||
| 92fa7271ed | |||
| d3486cab31 | |||
| dd0a870969 | |||
| 0c4c268003 | |||
| ff57848268 | |||
| d223fee9b9 | |||
| ad18d084f3 | |||
| 9941d1f160 | |||
| 13fa56b5b1 | |||
| 9ce48b4dc4 | |||
| abb2b860f2 | |||
| 930c36e757 | |||
| 2d2ce5df85 | |||
| 2b23c43434 |
.gitignore (vendored, 1 change)
@@ -209,6 +209,7 @@ api/.vscode
.history

.idea/
web/migration/

# pnpm
/.pnpm-store
@@ -717,3 +717,13 @@ SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000

# Sandbox Dify CLI configuration
# Directory containing dify CLI binaries (dify-cli-<os>-<arch>). Defaults to api/bin when unset.
SANDBOX_DIFY_CLI_ROOT=

# CLI API URL for sandbox (dify-sandbox or e2b) to call back to Dify API.
# This URL must be accessible from the sandbox environment.
# For local development: use http://localhost:5001 or http://127.0.0.1:5001
# For Docker deployment: use http://api:5001 (internal Docker network)
# For external sandbox (e.g., e2b): use a publicly accessible URL
CLI_API_URL=http://localhost:5001
api/agent-notes/core/app_assets/packager/zip_packager.py.md (new file, 14 lines)
@@ -0,0 +1,14 @@
# Zip Packager Notes

## Purpose
- Builds a ZIP archive of asset contents stored via the configured storage backend.

## Key Decisions
- Packaging writes assets into an in-memory zip buffer returned as bytes.
- Asset fetch + zip writing are executed via a thread pool with a lock guarding `ZipFile` writes.

## Edge Cases
- ZIP writes are serialized by the lock; storage reads still run in parallel.

## Tests/Verification
- None yet.
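The locking decision above is easy to misread, so here is a minimal illustrative sketch of the pattern it describes: storage reads fan out across a thread pool while a lock serializes writes into one in-memory `ZipFile`. The `fetch_asset` callable and the dict-of-arcnames input are assumptions for the example, not the actual packager interface.

```python
import io
import zipfile
from concurrent.futures import ThreadPoolExecutor
from threading import Lock


def package_assets(assets: dict[str, str], fetch_asset) -> bytes:
    """Illustrative only: fetch assets in parallel, serialize ZIP writes with a lock."""
    buffer = io.BytesIO()
    lock = Lock()

    with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as archive:

        def fetch_and_write(arcname: str, storage_key: str) -> None:
            data = fetch_asset(storage_key)  # storage read runs in parallel
            with lock:  # ZipFile is not thread-safe for concurrent writes
                archive.writestr(arcname, data)

        with ThreadPoolExecutor(max_workers=8) as pool:
            list(pool.map(lambda item: fetch_and_write(*item), assets.items()))

    return buffer.getvalue()
```

Serializing only the `writestr` calls keeps the archive consistent while still overlapping the storage reads, which matches the edge case noted above.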
@@ -0,0 +1,16 @@
# E2B Sandbox Provider Notes

## Purpose
- Implements the E2B-backed `VirtualEnvironment` provider and bootstraps sandbox metadata, file I/O, and command execution.

## Key Decisions
- Sandbox metadata is gathered during `_construct_environment` using the E2B SDK before returning `Metadata`.
- Architecture/OS detection uses a single `uname -m -s` call split by whitespace to reduce round-trips.
- Command execution streams stdout/stderr through `QueueTransportReadCloser`; stdin is unsupported.

## Edge Cases
- `release_environment` raises when sandbox termination fails.
- `execute_command` runs in a background thread; consumers must read stdout/stderr until EOF.

## Tests/Verification
- None yet. Add targeted service tests when behavior changes.
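A rough sketch of the single round-trip detection mentioned in the notes, assuming a `run_command` helper that executes a shell command inside the sandbox and returns its stdout (the helper name and the exact E2B SDK call are assumptions):

```python
def detect_os_and_arch(run_command) -> tuple[str, str]:
    """Illustrative only: derive <os>/<arch> from a single `uname -m -s` round-trip."""
    # uname prints its fields in a fixed order regardless of flag order,
    # e.g. "Linux x86_64" or "Darwin arm64".
    system, machine = run_command("uname -m -s").strip().split()

    os_name = system.lower()  # "linux" or "darwin"
    arch = {"x86_64": "amd64", "aarch64": "arm64"}.get(machine.lower(), machine.lower())
    return os_name, arch
```

Mapping `x86_64` to `amd64` lines up with the `dify-cli-<os>-<arch>` binary naming used elsewhere in this change.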
BIN api/bin/dify-cli-darwin-amd64 (new executable file; binary not shown)
BIN api/bin/dify-cli-darwin-arm64 (new executable file; binary not shown)
BIN api/bin/dify-cli-linux-amd64 (new executable file; binary not shown)
BIN api/bin/dify-cli-linux-arm64 (new executable file; binary not shown)
@@ -23,7 +23,8 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.models.document import Document
-from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
+from core.sandbox import SandboxBuilder, SandboxType
+from core.tools.utils.system_encryption import encrypt_system_params
from events.app_event import app_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client

@@ -1245,7 +1246,7 @@ def remove_orphaned_files_on_storage(force: bool):
            click.echo(click.style(f"- Scanning files on storage path {storage_path}", fg="white"))
            files = storage.scan(path=storage_path, files=True, directories=False)
            all_files_on_storage.extend(files)
-        except FileNotFoundError as e:
+        except FileNotFoundError:
            click.echo(click.style(f" -> Skipping path {storage_path} as it does not exist.", fg="yellow"))
            continue
        except Exception as e:
@@ -1493,6 +1494,57 @@ def file_usage(
        click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white"))


@click.command("setup-sandbox-system-config", help="Setup system-level sandbox provider configuration.")
@click.option(
    "--provider-type", prompt=True, type=click.Choice(["e2b", "docker", "local"]), help="Sandbox provider type"
)
@click.option("--config", prompt=True, help='Configuration JSON (e.g., {"api_key": "xxx"} for e2b)')
def setup_sandbox_system_config(provider_type: str, config: str):
    """
    Setup system-level sandbox provider configuration.

    Examples:
        flask setup-sandbox-system-config --provider-type e2b --config '{"api_key": "e2b_xxx"}'
        flask setup-sandbox-system-config --provider-type docker --config '{"docker_sock": "unix:///var/run/docker.sock"}'
        flask setup-sandbox-system-config --provider-type local --config '{}'
    """
    from models.sandbox import SandboxProviderSystemConfig

    try:
        click.echo(click.style(f"Validating config: {config}", fg="yellow"))
        config_dict = TypeAdapter(dict[str, Any]).validate_json(config)
        click.echo(click.style("Config validated successfully.", fg="green"))

        click.echo(click.style(f"Validating config schema for provider type: {provider_type}", fg="yellow"))
        SandboxBuilder.validate(SandboxType(provider_type), config_dict)
        click.echo(click.style("Config schema validated successfully.", fg="green"))

        click.echo(click.style("Encrypting config...", fg="yellow"))
        click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
        encrypted_config = encrypt_system_params(config_dict)
        click.echo(click.style("Config encrypted successfully.", fg="green"))
    except Exception as e:
        click.echo(click.style(f"Error validating/encrypting config: {str(e)}", fg="red"))
        return

    deleted_count = db.session.query(SandboxProviderSystemConfig).filter_by(provider_type=provider_type).delete()
    if deleted_count > 0:
        click.echo(
            click.style(
                f"Deleted {deleted_count} existing system config for provider type: {provider_type}", fg="yellow"
            )
        )

    system_config = SandboxProviderSystemConfig(
        provider_type=provider_type,
        encrypted_config=encrypted_config,
    )
    db.session.add(system_config)
    db.session.commit()
    click.echo(click.style(f"Sandbox system config setup successfully. id: {system_config.id}", fg="green"))
    click.echo(click.style(f"Provider type: {provider_type}", fg="green"))


@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")

@@ -1512,7 +1564,7 @@ def setup_system_tool_oauth_client(provider, client_params):

        click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
        click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
-       oauth_client_params = encrypt_system_oauth_params(client_params_dict)
+       oauth_client_params = encrypt_system_params(client_params_dict)
        click.echo(click.style("Client params encrypted successfully.", fg="green"))
    except Exception as e:
        click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))

@@ -1561,7 +1613,7 @@ def setup_system_trigger_oauth_client(provider, client_params):

        click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
        click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
-       oauth_client_params = encrypt_system_oauth_params(client_params_dict)
+       oauth_client_params = encrypt_system_params(client_params_dict)
        click.echo(click.style("Client params encrypted successfully.", fg="green"))
    except Exception as e:
        click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@@ -2,6 +2,7 @@ import logging
from pathlib import Path
from typing import Any

from pydantic import Field
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, TomlConfigSettingsSource

@@ -82,6 +83,14 @@ class DifyConfig(
        extra="ignore",
    )

    SANDBOX_DIFY_CLI_ROOT: str | None = Field(
        default=None,
        description=(
            "Filesystem directory containing dify CLI binaries named dify-cli-<os>-<arch>. "
            "Defaults to api/bin when unset."
        ),
    )

    # Before adding any config,
    # please consider to arrange it in the proper config group of existed or added
    # for better readability and maintainability.
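The SANDBOX_DIFY_CLI_ROOT description implies a lookup of a per-platform binary. A minimal sketch of how such a path could be resolved (the helper name and the default-root fallback are assumptions based on the config description, not the actual Dify code):

```python
from pathlib import Path


def resolve_dify_cli_path(cli_root: str | None, os_name: str, arch: str) -> Path:
    """Illustrative only: locate dify-cli-<os>-<arch> under the configured root (default api/bin).

    os_name/arch would typically come from the target environment, for example the
    uname-based detection described in the E2B provider notes above.
    """
    root = Path(cli_root) if cli_root else Path("api/bin")
    binary = root / f"dify-cli-{os_name}-{arch}"
    if not binary.exists():
        raise FileNotFoundError(f"dify CLI binary not found: {binary}")
    return binary
```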
@@ -244,6 +244,17 @@ class PluginConfig(BaseSettings):
    )


class CliApiConfig(BaseSettings):
    """
    Configuration for CLI API (for dify-cli to call back from external sandbox environments)
    """

    CLI_API_URL: str = Field(
        description="CLI API URL for external sandbox (e.g., e2b) to call back.",
        default="http://localhost:5001",
    )


class MarketplaceConfig(BaseSettings):
    """
    Configuration for marketplace

@@ -1313,6 +1324,7 @@ class FeatureConfig(
    TriggerConfig,
    AsyncWorkflowConfig,
    PluginConfig,
    CliApiConfig,
    MarketplaceConfig,
    DataSetConfig,
    EndpointConfig,
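Because CliApiConfig is mixed into FeatureConfig, and therefore into DifyConfig, the new setting is read like any other field on the global config object. A tiny sketch, assuming the usual global `dify_config` instance:

```python
from configs import dify_config

# The callback URL handed to dify-cli running inside an external sandbox (e.g. e2b).
# Falls back to http://localhost:5001 when CLI_API_URL is not set in the environment.
callback_url = dify_config.CLI_API_URL
print(f"dify-cli will call back to: {callback_url}")
```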
api/controllers/cli_api/__init__.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from flask import Blueprint
from flask_restx import Namespace

from libs.external_api import ExternalApi

bp = Blueprint("cli_api", __name__, url_prefix="/cli/api")

api = ExternalApi(
    bp,
    version="1.0",
    title="CLI API",
    description="APIs for Dify CLI to call back from external sandbox environments (e.g., e2b)",
)

# Create namespace
cli_api_ns = Namespace("cli_api", description="CLI API operations", path="/")

from .plugin import plugin as _plugin

api.add_namespace(cli_api_ns)

__all__ = [
    "_plugin",
    "api",
    "bp",
    "cli_api_ns",
]
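A sketch of how this blueprint would be wired into the Flask app. The registration hook shown here is an assumption for illustration; this diff does not show where Dify actually registers the blueprint.

```python
from flask import Flask

from controllers.cli_api import bp as cli_api_bp


def register_cli_api(app: Flask) -> None:
    # Exposes the CLI callback endpoints under /cli/api/... once registered.
    app.register_blueprint(cli_api_bp)
```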
api/controllers/cli_api/plugin/__init__.py (new empty file)
api/controllers/cli_api/plugin/plugin.py (new file, 137 lines)
@@ -0,0 +1,137 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.cli_api import cli_api_ns
|
||||
from controllers.cli_api.plugin.wraps import get_cli_user_tenant, plugin_data
|
||||
from controllers.cli_api.wraps import cli_api_only
|
||||
from controllers.console.wraps import setup_required
|
||||
from core.file.helpers import get_signed_file_url_for_plugin
|
||||
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
|
||||
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
|
||||
from core.plugin.entities.request import (
|
||||
RequestInvokeApp,
|
||||
RequestInvokeLLM,
|
||||
RequestInvokeTool,
|
||||
RequestRequestUploadFile,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from libs.helper import length_prefixed_response
|
||||
from models import Account, Tenant
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
@cli_api_ns.route("/invoke/llm")
|
||||
class CliInvokeLLMApi(Resource):
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@cli_api_only
|
||||
@plugin_data(payload_type=RequestInvokeLLM)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM):
|
||||
def generator():
|
||||
response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
|
||||
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
|
||||
|
||||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
@cli_api_ns.route("/invoke/tool")
|
||||
class CliInvokeToolApi(Resource):
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@cli_api_only
|
||||
@plugin_data(payload_type=RequestInvokeTool)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool):
|
||||
def generator():
|
||||
return PluginToolBackwardsInvocation.convert_to_event_stream(
|
||||
PluginToolBackwardsInvocation.invoke_tool(
|
||||
tenant_id=tenant_model.id,
|
||||
user_id=user_model.id,
|
||||
tool_type=ToolProviderType.value_of(payload.tool_type),
|
||||
provider=payload.provider,
|
||||
tool_name=payload.tool,
|
||||
tool_parameters=payload.tool_parameters,
|
||||
credential_id=payload.credential_id,
|
||||
),
|
||||
)
|
||||
|
||||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
@cli_api_ns.route("/invoke/app")
|
||||
class CliInvokeAppApi(Resource):
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@cli_api_only
|
||||
@plugin_data(payload_type=RequestInvokeApp)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp):
|
||||
response = PluginAppBackwardsInvocation.invoke_app(
|
||||
app_id=payload.app_id,
|
||||
user_id=user_model.id,
|
||||
tenant_id=tenant_model.id,
|
||||
conversation_id=payload.conversation_id,
|
||||
query=payload.query,
|
||||
stream=payload.response_mode == "streaming",
|
||||
inputs=payload.inputs,
|
||||
files=payload.files,
|
||||
)
|
||||
|
||||
return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response))
|
||||
|
||||
|
||||
@cli_api_ns.route("/upload/file/request")
|
||||
class CliUploadFileRequestApi(Resource):
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@cli_api_only
|
||||
@plugin_data(payload_type=RequestRequestUploadFile)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
|
||||
# generate signed url
|
||||
url = get_signed_file_url_for_plugin(
|
||||
filename=payload.filename,
|
||||
mimetype=payload.mimetype,
|
||||
tenant_id=tenant_model.id,
|
||||
user_id=user_model.id,
|
||||
)
|
||||
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
|
||||
|
||||
|
||||
@cli_api_ns.route("/fetch/tools/list")
|
||||
class CliFetchToolsListApi(Resource):
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@cli_api_only
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant):
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
||||
|
||||
providers = []
|
||||
|
||||
# Get builtin tools
|
||||
builtin_providers = BuiltinToolManageService.list_builtin_tools(user_model.id, tenant_model.id)
|
||||
for provider in builtin_providers:
|
||||
providers.append(provider.to_dict())
|
||||
|
||||
# Get API tools
|
||||
api_providers = ApiToolManageService.list_api_tools(tenant_model.id)
|
||||
for provider in api_providers:
|
||||
providers.append(provider.to_dict())
|
||||
|
||||
# Get workflow tools
|
||||
workflow_providers = WorkflowToolManageService.list_tenant_workflow_tools(user_model.id, tenant_model.id)
|
||||
for provider in workflow_providers:
|
||||
providers.append(provider.to_dict())
|
||||
|
||||
# Get MCP tools
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session)
|
||||
mcp_providers = mcp_service.list_providers(tenant_id=tenant_model.id, for_list=True)
|
||||
for provider in mcp_providers:
|
||||
providers.append(provider.to_dict())
|
||||
|
||||
return BaseBackwardsInvocationResponse(data={"providers": providers}).model_dump()
|
||||
api/controllers/cli_api/plugin/wraps.py (new file, 146 lines)
@@ -0,0 +1,146 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.session.cli_api import CliApiSession, CliApiSessionManager
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.account import Tenant
|
||||
from models.model import DefaultEndUserSessionID, EndUser
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class TenantUserPayload(BaseModel):
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
|
||||
|
||||
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
"""
|
||||
Get current user
|
||||
|
||||
NOTE: user_id is not trusted, it could be maliciously set to any value.
|
||||
As a result, it could only be considered as an end user id.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
user_model = None
|
||||
|
||||
if is_anonymous:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not user_model:
|
||||
user_model = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
type="service_api",
|
||||
is_anonymous=is_anonymous,
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(user_model)
|
||||
session.commit()
|
||||
session.refresh(user_model)
|
||||
|
||||
except Exception:
|
||||
raise ValueError("user not found")
|
||||
|
||||
return user_model
|
||||
|
||||
|
||||
def get_cli_user_tenant(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
session_id = request.headers.get("X-Cli-Api-Session-Id")
|
||||
|
||||
if session_id:
|
||||
session: CliApiSession | None = CliApiSessionManager().get(session_id)
|
||||
if not session:
|
||||
raise ValueError("session not found")
|
||||
user_id = session.user_id
|
||||
tenant_id = session.tenant_id
|
||||
|
||||
else:
|
||||
payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
|
||||
user_id = payload.user_id
|
||||
tenant_id = payload.tenant_id
|
||||
|
||||
if not tenant_id:
|
||||
raise ValueError("tenant_id is required")
|
||||
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
try:
|
||||
tenant_model = (
|
||||
db.session.query(Tenant)
|
||||
.where(
|
||||
Tenant.id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
if not tenant_model:
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
kwargs["tenant_model"] = tenant_model
|
||||
|
||||
user = get_user(tenant_id, user_id)
|
||||
kwargs["user_model"] = user
|
||||
|
||||
current_app.login_manager._update_request_context_with_user(user) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
|
||||
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
try:
|
||||
data = request.get_json()
|
||||
except Exception:
|
||||
raise ValueError("invalid json")
|
||||
|
||||
try:
|
||||
payload = payload_type.model_validate(data)
|
||||
except Exception as e:
|
||||
raise ValueError(f"invalid payload: {str(e)}")
|
||||
|
||||
kwargs["payload"] = payload
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
api/controllers/cli_api/wraps.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import hashlib
import hmac
import time
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

from flask import abort, request

from core.session.cli_api import CliApiSessionManager

P = ParamSpec("P")
R = TypeVar("R")

SIGNATURE_TTL_SECONDS = 300


def _verify_signature(session_secret: str, timestamp: str, body: bytes, signature: str) -> bool:
    expected = hmac.new(
        session_secret.encode(),
        f"{timestamp}.".encode() + body,
        hashlib.sha256,
    ).hexdigest()
    return hmac.compare_digest(f"sha256={expected}", signature)


def cli_api_only(view: Callable[P, R]):
    @wraps(view)
    def decorated(*args: P.args, **kwargs: P.kwargs):
        session_id = request.headers.get("X-Cli-Api-Session-Id")
        timestamp = request.headers.get("X-Cli-Api-Timestamp")
        signature = request.headers.get("X-Cli-Api-Signature")

        if not session_id or not timestamp or not signature:
            abort(401)

        try:
            ts = int(timestamp)
            if abs(time.time() - ts) > SIGNATURE_TTL_SECONDS:
                abort(401)
        except ValueError:
            abort(401)

        session = CliApiSessionManager().get(session_id)
        if not session:
            abort(401)

        body = request.get_data()
        if not _verify_signature(session.secret, timestamp, body, signature):
            abort(401)

        return view(*args, **kwargs)

    return decorated
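For a client such as dify-cli, the corresponding signing side can be sketched as follows. This is a hypothetical example, not the shipped CLI code; it only mirrors the scheme that `cli_api_only` verifies above (HMAC-SHA256 over `{timestamp}.{body}` with the session secret, sent as `sha256=<hexdigest>`). The session id and secret are placeholders issued by the CLI session manager.

```python
import hashlib
import hmac
import json
import time

import httpx


def signed_post(base_url: str, path: str, session_id: str, session_secret: str, payload: dict) -> httpx.Response:
    """Illustrative only: sign a CLI API request the way cli_api_only expects to verify it."""
    body = json.dumps(payload).encode()
    timestamp = str(int(time.time()))
    digest = hmac.new(session_secret.encode(), f"{timestamp}.".encode() + body, hashlib.sha256).hexdigest()

    return httpx.post(
        f"{base_url}{path}",
        content=body,
        headers={
            "Content-Type": "application/json",
            "X-Cli-Api-Session-Id": session_id,
            "X-Cli-Api-Timestamp": timestamp,
            "X-Cli-Api-Signature": f"sha256={digest}",
        },
    )


# Example: ask the Dify API (CLI_API_URL) for the tool provider list from inside a sandbox.
# response = signed_post("http://localhost:5001", "/cli/api/fetch/tools/list", "sess-123", "secret", {})
```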
@@ -50,6 +50,7 @@ from .app import (
    agent,
    annotation,
    app,
    app_asset,
    audio,
    completion,
    conversation,

@@ -126,6 +127,7 @@ from .workspace import (
    model_providers,
    models,
    plugin,
    sandbox_providers,
    tool_providers,
    trigger_providers,
    workspace,

@@ -144,6 +146,7 @@ __all__ = [
    "api",
    "apikey",
    "app",
    "app_asset",
    "audio",
    "billing",
    "bp",

@@ -191,6 +194,7 @@ __all__ = [
    "rag_pipeline_import",
    "rag_pipeline_workflow",
    "recommended_app",
    "sandbox_providers",
    "saved_message",
    "setup",
    "site",
api/controllers/console/app/app_asset.py (new file, 274 lines)
@@ -0,0 +1,274 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppAssetFileRequiredError,
|
||||
AppAssetNodeNotFoundError,
|
||||
AppAssetPathConflictError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from services.app_asset_service import AppAssetService
|
||||
from services.errors.app_asset import (
|
||||
AppAssetNodeNotFoundError as ServiceNodeNotFoundError,
|
||||
)
|
||||
from services.errors.app_asset import (
|
||||
AppAssetParentNotFoundError,
|
||||
)
|
||||
from services.errors.app_asset import (
|
||||
AppAssetPathConflictError as ServicePathConflictError,
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class CreateFolderPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
parent_id: str | None = None
|
||||
|
||||
|
||||
class CreateFilePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
parent_id: str | None = None
|
||||
|
||||
@field_validator("name", mode="before")
|
||||
@classmethod
|
||||
def strip_name(cls, v: str) -> str:
|
||||
return v.strip() if isinstance(v, str) else v
|
||||
|
||||
@field_validator("parent_id", mode="before")
|
||||
@classmethod
|
||||
def empty_to_none(cls, v: str | None) -> str | None:
|
||||
return v or None
|
||||
|
||||
|
||||
class UpdateFileContentPayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class RenameNodePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
|
||||
|
||||
class MoveNodePayload(BaseModel):
|
||||
parent_id: str | None = None
|
||||
|
||||
|
||||
class ReorderNodePayload(BaseModel):
|
||||
after_node_id: str | None = Field(default=None, description="Place after this node, None for first position")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]) -> None:
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(CreateFolderPayload)
|
||||
reg(CreateFilePayload)
|
||||
reg(UpdateFileContentPayload)
|
||||
reg(RenameNodePayload)
|
||||
reg(MoveNodePayload)
|
||||
reg(ReorderNodePayload)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/tree")
|
||||
class AppAssetTreeResource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
tree = AppAssetService.get_asset_tree(app_model, current_user.id)
|
||||
return {"children": [view.model_dump() for view in tree.transform()]}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/folders")
|
||||
class AppAssetFolderResource(Resource):
|
||||
@console_ns.expect(console_ns.models[CreateFolderPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = CreateFolderPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
node = AppAssetService.create_folder(app_model, current_user.id, payload.name, payload.parent_id)
|
||||
return node.model_dump(), 201
|
||||
except AppAssetParentNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/files")
|
||||
class AppAssetFileResource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
file = request.files.get("file")
|
||||
if not file:
|
||||
raise AppAssetFileRequiredError()
|
||||
|
||||
payload = CreateFilePayload.model_validate(request.form.to_dict())
|
||||
content = file.read()
|
||||
|
||||
try:
|
||||
node = AppAssetService.create_file(app_model, current_user.id, payload.name, content, payload.parent_id)
|
||||
return node.model_dump(), 201
|
||||
except AppAssetParentNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/files/<string:node_id>")
|
||||
class AppAssetFileDetailResource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
try:
|
||||
content = AppAssetService.get_file_content(app_model, current_user.id, node_id)
|
||||
return {"content": content.decode("utf-8", errors="replace")}
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
|
||||
@console_ns.expect(console_ns.models[UpdateFileContentPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def put(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
file = request.files.get("file")
|
||||
if file:
|
||||
content = file.read()
|
||||
else:
|
||||
payload = UpdateFileContentPayload.model_validate(console_ns.payload or {})
|
||||
content = payload.content.encode("utf-8")
|
||||
|
||||
try:
|
||||
node = AppAssetService.update_file_content(app_model, current_user.id, node_id, content)
|
||||
return node.model_dump()
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>")
|
||||
class AppAssetNodeResource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def delete(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
try:
|
||||
AppAssetService.delete_node(app_model, current_user.id, node_id)
|
||||
return {"result": "success"}, 200
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/rename")
|
||||
class AppAssetNodeRenameResource(Resource):
|
||||
@console_ns.expect(console_ns.models[RenameNodePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = RenameNodePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
node = AppAssetService.rename_node(app_model, current_user.id, node_id, payload.name)
|
||||
return node.model_dump()
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/move")
|
||||
class AppAssetNodeMoveResource(Resource):
|
||||
@console_ns.expect(console_ns.models[MoveNodePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = MoveNodePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
node = AppAssetService.move_node(app_model, current_user.id, node_id, payload.parent_id)
|
||||
return node.model_dump()
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except AppAssetParentNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/reorder")
|
||||
class AppAssetNodeReorderResource(Resource):
|
||||
@console_ns.expect(console_ns.models[ReorderNodePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = ReorderNodePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
node = AppAssetService.reorder_node(app_model, current_user.id, node_id, payload.after_node_id)
|
||||
return node.model_dump()
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/publish")
|
||||
class AppAssetPublishResource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
published = AppAssetService.publish(app_model, current_user.id)
|
||||
return {
|
||||
"id": published.id,
|
||||
"version": published.version,
|
||||
"asset_tree": published.asset_tree.model_dump(),
|
||||
}, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/files/<string:node_id>/download-url")
|
||||
class AppAssetFileDownloadUrlResource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
try:
|
||||
download_url = AppAssetService.get_file_download_url(app_model, current_user.id, node_id)
|
||||
return {"download_url": download_url}
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
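Taken together, the routes above form a small file-tree API for app assets. A hedged usage sketch against a local console API follows; the base URL, auth header, IDs, and the assumption that a node dump contains an "id" field are placeholders, and the console normally authenticates via its session rather than a bearer token.

```python
import httpx

BASE = "http://localhost:5001/console/api"  # assumed local console API base
APP_ID = "00000000-0000-0000-0000-000000000000"  # placeholder app id
headers = {"Authorization": "Bearer <console-access-token>"}  # placeholder auth

with httpx.Client(base_url=BASE, headers=headers) as client:
    # Create a folder at the asset-tree root, then upload a file into it.
    folder = client.post(f"/apps/{APP_ID}/assets/folders", json={"name": "prompts"}).json()
    client.post(
        f"/apps/{APP_ID}/assets/files",
        data={"name": "system.md", "parent_id": folder["id"]},  # "id" field assumed
        files={"file": ("system.md", b"You are a helpful assistant.")},
    )

    # Inspect the tree and publish the current asset snapshot.
    tree = client.get(f"/apps/{APP_ID}/assets/tree").json()
    published = client.post(f"/apps/{APP_ID}/assets/publish").json()
    print(tree, published["version"])
```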
@@ -110,8 +110,24 @@ class TracingConfigCheckError(BaseHTTPException):


class InvokeRateLimitError(BaseHTTPException):
    """Raised when the Invoke returns rate limit error."""

    error_code = "rate_limit_error"
    description = "Rate Limit Error"
    code = 429


class AppAssetNodeNotFoundError(BaseHTTPException):
    error_code = "app_asset_node_not_found"
    description = "App asset node not found."
    code = 404


class AppAssetFileRequiredError(BaseHTTPException):
    error_code = "app_asset_file_required"
    description = "File is required."
    code = 400


class AppAssetPathConflictError(BaseHTTPException):
    error_code = "app_asset_path_conflict"
    description = "Path already exists."
    code = 409
@ -55,6 +55,35 @@ class InstructionTemplatePayload(BaseModel):
|
||||
type: str = Field(..., description="Instruction template type")
|
||||
|
||||
|
||||
class ContextGeneratePayload(BaseModel):
|
||||
"""Payload for generating extractor code node."""
|
||||
|
||||
workflow_id: str = Field(..., description="Workflow ID")
|
||||
node_id: str = Field(..., description="Current tool/llm node ID")
|
||||
parameter_name: str = Field(..., description="Parameter name to generate code for")
|
||||
language: str = Field(default="python3", description="Code language (python3/javascript)")
|
||||
prompt_messages: list[dict[str, Any]] = Field(
|
||||
..., description="Multi-turn conversation history, last message is the current instruction"
|
||||
)
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
|
||||
|
||||
class SuggestedQuestionsPayload(BaseModel):
|
||||
"""Payload for generating suggested questions."""
|
||||
|
||||
workflow_id: str = Field(..., description="Workflow ID")
|
||||
node_id: str = Field(..., description="Current tool/llm node ID")
|
||||
parameter_name: str = Field(..., description="Parameter name")
|
||||
language: str = Field(
|
||||
default="English", description="Language for generated questions (e.g. English, Chinese, Japanese)"
|
||||
)
|
||||
model_config_data: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
alias="model_config",
|
||||
description="Model configuration (optional, uses system default if not provided)",
|
||||
)
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
@ -64,6 +93,8 @@ reg(RuleCodeGeneratePayload)
|
||||
reg(RuleStructuredOutputPayload)
|
||||
reg(InstructionGeneratePayload)
|
||||
reg(InstructionTemplatePayload)
|
||||
reg(ContextGeneratePayload)
|
||||
reg(SuggestedQuestionsPayload)
|
||||
|
||||
|
||||
@console_ns.route("/rule-generate")
|
||||
@ -278,3 +309,74 @@ class InstructionGenerationTemplateApi(Resource):
|
||||
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
|
||||
case _:
|
||||
raise ValueError(f"Invalid type: {args.type}")
|
||||
|
||||
|
||||
@console_ns.route("/context-generate")
|
||||
class ContextGenerateApi(Resource):
|
||||
@console_ns.doc("generate_with_context")
|
||||
@console_ns.doc(description="Generate with multi-turn conversation context")
|
||||
@console_ns.expect(console_ns.models[ContextGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Content generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters or workflow not found")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
from core.llm_generator.utils import deserialize_prompt_messages
|
||||
|
||||
args = ContextGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
prompt_messages = deserialize_prompt_messages(args.prompt_messages)
|
||||
|
||||
try:
|
||||
return LLMGenerator.generate_with_context(
|
||||
tenant_id=current_tenant_id,
|
||||
workflow_id=args.workflow_id,
|
||||
node_id=args.node_id,
|
||||
parameter_name=args.parameter_name,
|
||||
language=args.language,
|
||||
prompt_messages=prompt_messages,
|
||||
model_config=args.model_config_data,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
|
||||
@console_ns.route("/context-generate/suggested-questions")
|
||||
class SuggestedQuestionsApi(Resource):
|
||||
@console_ns.doc("generate_suggested_questions")
|
||||
@console_ns.doc(description="Generate suggested questions for context generation")
|
||||
@console_ns.expect(console_ns.models[SuggestedQuestionsPayload.__name__])
|
||||
@console_ns.response(200, "Questions generated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = SuggestedQuestionsPayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return LLMGenerator.generate_suggested_questions(
|
||||
tenant_id=current_tenant_id,
|
||||
workflow_id=args.workflow_id,
|
||||
node_id=args.node_id,
|
||||
parameter_name=args.parameter_name,
|
||||
language=args.language,
|
||||
model_config=args.model_config_data,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
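A hedged example of what requests to the new context-generate endpoints might look like, based only on ContextGeneratePayload and SuggestedQuestionsPayload above. The URL prefix, auth header, IDs, and the contents of model_config are placeholders, not values taken from this diff.

```python
import httpx

BASE = "http://localhost:5001/console/api"  # assumed console API base
headers = {"Authorization": "Bearer <console-access-token>"}  # placeholder auth

payload = {
    "workflow_id": "<workflow-uuid>",
    "node_id": "<tool-or-llm-node-id>",
    "parameter_name": "query",
    "language": "python3",
    # Multi-turn history; the last message is treated as the current instruction.
    "prompt_messages": [
        {"role": "user", "content": "Extract the city name from the article variable."},
    ],
    # Aliased as `model_config` in the payload model; the shape depends on the configured provider.
    "model_config": {"provider": "<provider>", "name": "<model-name>", "completion_params": {}},
}

with httpx.Client(base_url=BASE, headers=headers) as client:
    generated = client.post("/context-generate", json=payload).json()
    questions = client.post(
        "/context-generate/suggested-questions",
        json={
            "workflow_id": payload["workflow_id"],
            "node_id": payload["node_id"],
            "parameter_name": payload["parameter_name"],
            "language": "English",
        },
    ).json()
```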
@@ -202,6 +202,7 @@ message_detail_model = console_ns.model(
        "status": fields.String,
        "error": fields.String,
        "parent_message_id": fields.String,
        "generation_detail": fields.Raw,
    },
)
@ -46,6 +46,8 @@ from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow.entities import MentionGraphRequest, MentionParameterSchema
|
||||
from services.workflow.mention_graph_service import MentionGraphService
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -188,6 +190,15 @@ class DraftWorkflowTriggerRunAllPayload(BaseModel):
|
||||
node_ids: list[str]
|
||||
|
||||
|
||||
class MentionGraphPayload(BaseModel):
|
||||
"""Request payload for generating mention graph."""
|
||||
|
||||
parent_node_id: str = Field(description="ID of the parent node that uses the extracted value")
|
||||
parameter_key: str = Field(description="Key of the parameter being extracted")
|
||||
context_source: list[str] = Field(description="Variable selector for the context source")
|
||||
parameter_schema: dict[str, Any] = Field(description="Schema of the parameter to extract")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
@ -205,6 +216,7 @@ reg(WorkflowListQuery)
|
||||
reg(WorkflowUpdatePayload)
|
||||
reg(DraftWorkflowTriggerRunPayload)
|
||||
reg(DraftWorkflowTriggerRunAllPayload)
|
||||
reg(MentionGraphPayload)
|
||||
|
||||
|
||||
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
|
||||
@ -1166,3 +1178,54 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
"status": "error",
|
||||
}
|
||||
), 400
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/mention-graph")
|
||||
class MentionGraphApi(Resource):
|
||||
"""
|
||||
API for generating Mention LLM node graph structures.
|
||||
|
||||
This endpoint creates a complete graph structure containing an LLM node
|
||||
configured to extract values from list[PromptMessage] variables.
|
||||
"""
|
||||
|
||||
@console_ns.doc("generate_mention_graph")
|
||||
@console_ns.doc(description="Generate a Mention LLM node graph structure")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[MentionGraphPayload.__name__])
|
||||
@console_ns.response(200, "Mention graph generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Generate a Mention LLM node graph structure.
|
||||
|
||||
Returns a complete graph structure containing a single LLM node
|
||||
configured for extracting values from list[PromptMessage] context.
|
||||
"""
|
||||
|
||||
payload = MentionGraphPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
parameter_schema = MentionParameterSchema(
|
||||
name=payload.parameter_schema.get("name", payload.parameter_key),
|
||||
type=payload.parameter_schema.get("type", "string"),
|
||||
description=payload.parameter_schema.get("description", ""),
|
||||
)
|
||||
|
||||
request = MentionGraphRequest(
|
||||
parent_node_id=payload.parent_node_id,
|
||||
parameter_key=payload.parameter_key,
|
||||
context_source=payload.context_source,
|
||||
parameter_schema=parameter_schema,
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
service = MentionGraphService(session)
|
||||
response = service.generate_mention_graph(tenant_id=app_model.tenant_id, request=request)
|
||||
|
||||
return response.model_dump()
|
||||
|
||||
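Based on MentionGraphPayload above, a request to the draft mention-graph endpoint could look roughly like this. The IDs, the variable selector, and the auth header are placeholders; the response shape comes from MentionGraphService and is not shown in this diff.

```python
import httpx

BASE = "http://localhost:5001/console/api"  # assumed console API base
APP_ID = "<app-uuid>"  # placeholder
headers = {"Authorization": "Bearer <console-access-token>"}  # placeholder auth

payload = {
    "parent_node_id": "tool_node_1",
    "parameter_key": "city",
    # Variable selector pointing at a list[PromptMessage] value, e.g. a node's prompt history.
    "context_source": ["llm_node_1", "prompt_messages"],
    "parameter_schema": {"name": "city", "type": "string", "description": "City mentioned in the context"},
}

with httpx.Client(base_url=BASE, headers=headers) as client:
    graph = client.post(f"/apps/{APP_ID}/workflows/draft/mention-graph", json=payload).json()
```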
@@ -17,7 +17,7 @@ from controllers.console.wraps import account_initialization_required, edit_perm
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.file import helpers as file_helpers
from core.variables.segment_group import SegmentGroup
-from core.variables.segments import ArrayFileSegment, FileSegment, Segment
+from core.variables.segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db

@@ -58,6 +58,8 @@ def _convert_values_to_json_serializable_object(value: Segment):
        return value.value.model_dump()
    elif isinstance(value, ArrayFileSegment):
        return [i.model_dump() for i in value.value]
    elif isinstance(value, ArrayPromptMessageSegment):
        return value.to_object()
    elif isinstance(value, SegmentGroup):
        return [_convert_values_to_json_serializable_object(i) for i in value.value]
    else:
api/controllers/console/workspace/dsl.py (new file, 65 lines)
@@ -0,0 +1,65 @@
|
||||
import json

import httpx
import yaml
from flask_restx import Resource, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.plugin.impl.exc import PluginPermissionDeniedError
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models.model import App
from models.workflow import Workflow
from services.app_dsl_service import AppDslService


@console_ns.route("/workspaces/current/dsl/predict")
class DSLPredictApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        user, _ = current_account_with_tenant()
        if not user.is_admin_or_owner:
            raise Forbidden()

        parser = (
            reqparse.RequestParser()
            .add_argument("app_id", type=str, required=True, location="json")
            .add_argument("current_node_id", type=str, required=True, location="json")
        )
        args = parser.parse_args()

        app_id: str = args["app_id"]
        current_node_id: str = args["current_node_id"]

        with Session(db.engine) as session:
            app = session.query(App).filter_by(id=app_id).first()
            workflow = session.query(Workflow).filter_by(app_id=app_id, version=Workflow.VERSION_DRAFT).first()

            if not app:
                raise ValueError("App not found")
            if not workflow:
                raise ValueError("Workflow not found")

            try:
                i = 0
                for node_id, _ in workflow.walk_nodes():
                    if node_id == current_node_id:
                        break
                    i += 1

                dsl = yaml.safe_load(AppDslService.export_dsl(app_model=app))

                response = httpx.post(
                    "http://spark-832c:8000/predict",
                    json={"graph_data": dsl, "source_node_index": i},
                )
                return {
                    "nodes": json.loads(response.json()),
                }
            except PluginPermissionDeniedError as e:
                raise ValueError(e.description) from e
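For orientation, a minimal sketch of how this predict endpoint could be exercised from a console session; the host, port, token, and ids below are illustrative assumptions, not values taken from this diff.

import httpx

# Hypothetical console session values for illustration only.
resp = httpx.post(
    "http://localhost:5001/console/api/workspaces/current/dsl/predict",
    headers={"Authorization": "Bearer <console-access-token>"},
    json={"app_id": "<app-uuid>", "current_node_id": "<node-id>"},
)
print(resp.json()["nodes"])  # node suggestions returned by the external predict service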
api/controllers/console/workspace/sandbox_providers.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import logging

from flask_restx import Resource, fields, reqparse

from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from services.sandbox.sandbox_provider_service import SandboxProviderService

logger = logging.getLogger(__name__)


@console_ns.route("/workspaces/current/sandbox-providers")
class SandboxProviderListApi(Resource):
    @console_ns.doc("list_sandbox_providers")
    @console_ns.doc(description="Get list of available sandbox providers with configuration status")
    @console_ns.response(200, "Success", fields.List(fields.Raw(description="Sandbox provider information")))
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        _, current_tenant_id = current_account_with_tenant()
        providers = SandboxProviderService.list_providers(current_tenant_id)
        return jsonable_encoder([p.model_dump() for p in providers])


config_parser = reqparse.RequestParser()
config_parser.add_argument("config", type=dict, required=True, location="json")
config_parser.add_argument("activate", type=bool, required=False, default=False, location="json")


@console_ns.route("/workspaces/current/sandbox-provider/<string:provider_type>/config")
class SandboxProviderConfigApi(Resource):
    @console_ns.doc("save_sandbox_provider_config")
    @console_ns.doc(description="Save or update configuration for a sandbox provider")
    @console_ns.expect(config_parser)
    @console_ns.response(200, "Success")
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_type: str):
        _, current_tenant_id = current_account_with_tenant()
        args = config_parser.parse_args()

        try:
            result = SandboxProviderService.save_config(
                tenant_id=current_tenant_id,
                provider_type=provider_type,
                config=args["config"],
                activate=args["activate"],
            )
            return result
        except ValueError as e:
            return {"message": str(e)}, 400

    @console_ns.doc("delete_sandbox_provider_config")
    @console_ns.doc(description="Delete configuration for a sandbox provider")
    @console_ns.response(200, "Success")
    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, provider_type: str):
        _, current_tenant_id = current_account_with_tenant()

        try:
            result = SandboxProviderService.delete_config(
                tenant_id=current_tenant_id,
                provider_type=provider_type,
            )
            return result
        except ValueError as e:
            return {"message": str(e)}, 400


activate_parser = reqparse.RequestParser()
activate_parser.add_argument("type", type=str, required=True, location="json")


@console_ns.route("/workspaces/current/sandbox-provider/<string:provider_type>/activate")
class SandboxProviderActivateApi(Resource):
    """Activate a sandbox provider."""

    @console_ns.doc("activate_sandbox_provider")
    @console_ns.doc(description="Activate a sandbox provider for the current workspace")
    @console_ns.response(200, "Success")
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_type: str):
        """Activate a sandbox provider."""
        _, current_tenant_id = current_account_with_tenant()

        try:
            args = activate_parser.parse_args()
            result = SandboxProviderService.activate_provider(
                tenant_id=current_tenant_id,
                provider_type=provider_type,
                type=args["type"],
            )
            return result
        except ValueError as e:
            return {"message": str(e)}, 400
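As a usage sketch, saving and activating a sandbox provider configuration could look like the following; the provider type, config keys, host, and token are assumptions for illustration only.

import httpx

# Hypothetical provider type and configuration payload.
resp = httpx.post(
    "http://localhost:5001/console/api/workspaces/current/sandbox-provider/docker/config",
    headers={"Authorization": "Bearer <console-access-token>"},
    json={"config": {"endpoint": "http://sandbox:8194", "api_key": "<key>"}, "activate": True},
)
print(resp.status_code, resp.json())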
@@ -14,7 +14,7 @@ api = ExternalApi(
files_ns = Namespace("files", description="File operations", path="/")

-from . import image_preview, tool_files, upload
+from . import image_preview, storage_download, tool_files, upload

api.add_namespace(files_ns)

@@ -23,6 +23,7 @@ __all__ = [
    "bp",
    "files_ns",
    "image_preview",
    "storage_download",
    "tool_files",
    "upload",
]
api/controllers/files/storage_download.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from urllib.parse import quote, unquote

from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound

from controllers.files import files_ns
from extensions.ext_storage import storage
from extensions.storage.file_presign_storage import FilePresignStorage

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class StorageDownloadQuery(BaseModel):
    timestamp: str = Field(..., description="Unix timestamp used in the signature")
    nonce: str = Field(..., description="Random string for signature")
    sign: str = Field(..., description="HMAC signature")


files_ns.schema_model(
    StorageDownloadQuery.__name__,
    StorageDownloadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)


@files_ns.route("/storage/<path:filename>/download")
class StorageFileDownloadApi(Resource):
    def get(self, filename: str):
        filename = unquote(filename)

        args = StorageDownloadQuery.model_validate(request.args.to_dict(flat=True))

        if not FilePresignStorage.verify_signature(
            filename=filename,
            timestamp=args.timestamp,
            nonce=args.nonce,
            sign=args.sign,
        ):
            raise Forbidden("Invalid or expired download link")

        try:
            generator = storage.load_stream(filename)
        except FileNotFoundError:
            raise NotFound("File not found")

        encoded_filename = quote(filename.split("/")[-1])

        return Response(
            generator,
            mimetype="application/octet-stream",
            direct_passthrough=True,
            headers={
                "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
            },
        )
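A short sketch of the matching client side: the three query parameters must be exactly those issued when the link was presigned, otherwise verify_signature rejects the request. Host, path, and parameter values below are placeholders, not values from this diff.

from urllib.parse import urlencode

# Placeholder values; in practice they come from the presigned URL issued by the backend.
params = {"timestamp": "1735689600", "nonce": "f3a1c2", "sign": "<hmac-signature>"}
url = "http://localhost:5001/files/storage/upload/2024/report.pdf/download?" + urlencode(params)
# A plain GET on this URL streams the file back as an attachment once the signature checks out.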
@@ -448,3 +448,53 @@ class PluginFetchAppInfoApi(Resource):
        return BaseBackwardsInvocationResponse(
            data=PluginAppBackwardsInvocation.fetch_app_info(payload.app_id, tenant_model.id)
        ).model_dump()


@inner_api_ns.route("/fetch/tools/list")
class PluginFetchToolsListApi(Resource):
    @get_user_tenant
    @setup_required
    @plugin_inner_api_only
    @inner_api_ns.doc("plugin_fetch_tools_list")
    @inner_api_ns.doc(description="Fetch all available tools through plugin interface")
    @inner_api_ns.doc(
        responses={
            200: "Tools list retrieved successfully",
            401: "Unauthorized - invalid API key",
            404: "Service not available",
        }
    )
    def post(self, user_model: Account | EndUser, tenant_model: Tenant):
        from sqlalchemy.orm import Session

        from extensions.ext_database import db
        from services.tools.api_tools_manage_service import ApiToolManageService
        from services.tools.builtin_tools_manage_service import BuiltinToolManageService
        from services.tools.mcp_tools_manage_service import MCPToolManageService
        from services.tools.workflow_tools_manage_service import WorkflowToolManageService

        providers = []

        # Get builtin tools
        builtin_providers = BuiltinToolManageService.list_builtin_tools(user_model.id, tenant_model.id)
        for provider in builtin_providers:
            providers.append(provider.to_dict())

        # Get API tools
        api_providers = ApiToolManageService.list_api_tools(tenant_model.id)
        for provider in api_providers:
            providers.append(provider.to_dict())

        # Get workflow tools
        workflow_providers = WorkflowToolManageService.list_tenant_workflow_tools(user_model.id, tenant_model.id)
        for provider in workflow_providers:
            providers.append(provider.to_dict())

        # Get MCP tools
        with Session(db.engine) as session:
            mcp_service = MCPToolManageService(session)
            mcp_providers = mcp_service.list_providers(tenant_id=tenant_model.id, for_list=True)
            for provider in mcp_providers:
                providers.append(provider.to_dict())

        return BaseBackwardsInvocationResponse(data={"providers": providers}).model_dump()
@@ -75,7 +75,6 @@ def get_user_tenant(view_func: Callable[P, R]):
    @wraps(view_func)
    def decorated_view(*args: P.args, **kwargs: P.kwargs):
        payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})

        user_id = payload.user_id
        tenant_id = payload.tenant_id
@@ -5,14 +5,15 @@ from hashlib import sha1
from hmac import new as hmac_new
from typing import ParamSpec, TypeVar

-P = ParamSpec("P")
-R = TypeVar("R")
from flask import abort, request

from configs import dify_config
from extensions.ext_database import db
from models.model import EndUser

+P = ParamSpec("P")
+R = TypeVar("R")


def billing_inner_api_only(view: Callable[P, R]):
    @wraps(view)
@@ -88,11 +89,11 @@ def plugin_inner_api_only(view: Callable[P, R]):
        if not dify_config.PLUGIN_DAEMON_KEY:
            abort(404)

-        # get header 'X-Inner-Api-Key'
+        # validate using inner api key
        inner_api_key = request.headers.get("X-Inner-Api-Key")
-        if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
-            abort(404)
+        if inner_api_key and inner_api_key == dify_config.INNER_API_KEY_FOR_PLUGIN:
+            return view(*args, **kwargs)

-        return view(*args, **kwargs)
+        abort(401)

    return decorated
api/core/agent/agent_app_runner.py (new file, 380 lines)
@@ -0,0 +1,380 @@
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any

from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentEntity, AgentLog, AgentResult
from core.agent.patterns.strategy_factory import StrategyFactory
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMResult,
    LLMResultChunk,
    LLMUsage,
    PromptMessage,
    PromptMessageContentType,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)


class AgentAppRunner(BaseAgentRunner):
    def _create_tool_invoke_hook(self, message: Message):
        """
        Create a tool invoke hook that uses ToolEngine.agent_invoke.
        This hook handles file creation and returns proper meta information.
        """
        # Get trace manager from app generate entity
        trace_manager = self.application_generate_entity.trace_manager

        def tool_invoke_hook(
            tool: Tool, tool_args: dict[str, Any], tool_name: str
        ) -> tuple[str, list[str], ToolInvokeMeta]:
            """Hook that uses agent_invoke for proper file and meta handling."""
            tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                tool=tool,
                tool_parameters=tool_args,
                user_id=self.user_id,
                tenant_id=self.tenant_id,
                message=message,
                invoke_from=self.application_generate_entity.invoke_from,
                agent_tool_callback=self.agent_callback,
                trace_manager=trace_manager,
                app_id=self.application_generate_entity.app_config.app_id,
                message_id=message.id,
                conversation_id=self.conversation.id,
            )

            # Publish files and track IDs
            for message_file_id in message_files:
                self.queue_manager.publish(
                    QueueMessageFileEvent(message_file_id=message_file_id),
                    PublishFrom.APPLICATION_MANAGER,
                )
                self._current_message_file_ids.append(message_file_id)

            return tool_invoke_response, message_files, tool_invoke_meta

        return tool_invoke_hook

    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
        """
        Run Agent application
        """
        self.query = query
        app_generate_entity = self.application_generate_entity

        app_config = self.app_config
        assert app_config is not None, "app_config is required"
        assert app_config.agent is not None, "app_config.agent is required"

        # convert tools into ModelRuntime Tool format
        tool_instances, _ = self._init_prompt_tools()

        assert app_config.agent

        # Create tool invoke hook for agent_invoke
        tool_invoke_hook = self._create_tool_invoke_hook(message)

        # Get instruction for ReAct strategy
        instruction = self.app_config.prompt_template.simple_prompt_template or ""

        # Use factory to create appropriate strategy
        strategy = StrategyFactory.create_strategy(
            model_features=self.model_features,
            model_instance=self.model_instance,
            tools=list(tool_instances.values()),
            files=list(self.files),
            max_iterations=app_config.agent.max_iteration,
            context=self.build_execution_context(),
            agent_strategy=self.config.strategy,
            tool_invoke_hook=tool_invoke_hook,
            instruction=instruction,
        )

        # Initialize state variables
        current_agent_thought_id = None
        has_published_thought = False
        current_tool_name: str | None = None
        self._current_message_file_ids: list[str] = []

        # organize prompt messages
        prompt_messages = self._organize_prompt_messages()

        # Run strategy
        generator = strategy.run(
            prompt_messages=prompt_messages,
            model_parameters=app_generate_entity.model_conf.parameters,
            stop=app_generate_entity.model_conf.stop,
            stream=True,
        )

        # Consume generator and collect result
        result: AgentResult | None = None
        try:
            while True:
                try:
                    output = next(generator)
                except StopIteration as e:
                    # Generator finished, get the return value
                    result = e.value
                    break

                if isinstance(output, LLMResultChunk):
                    # Handle LLM chunk
                    if current_agent_thought_id and not has_published_thought:
                        self.queue_manager.publish(
                            QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
                            PublishFrom.APPLICATION_MANAGER,
                        )
                        has_published_thought = True

                    yield output

                elif isinstance(output, AgentLog):
                    # Handle Agent Log using log_type for type-safe dispatch
                    if output.status == AgentLog.LogStatus.START:
                        if output.log_type == AgentLog.LogType.ROUND:
                            # Start of a new round
                            message_file_ids: list[str] = []
                            current_agent_thought_id = self.create_agent_thought(
                                message_id=message.id,
                                message="",
                                tool_name="",
                                tool_input="",
                                messages_ids=message_file_ids,
                            )
                            has_published_thought = False

                        elif output.log_type == AgentLog.LogType.TOOL_CALL:
                            if current_agent_thought_id is None:
                                continue

                            # Tool call start - extract data from structured fields
                            current_tool_name = output.data.get("tool_name", "")
                            tool_input = output.data.get("tool_args", {})

                            self.save_agent_thought(
                                agent_thought_id=current_agent_thought_id,
                                tool_name=current_tool_name,
                                tool_input=tool_input,
                                thought=None,
                                observation=None,
                                tool_invoke_meta=None,
                                answer=None,
                                messages_ids=[],
                            )
                            self.queue_manager.publish(
                                QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
                                PublishFrom.APPLICATION_MANAGER,
                            )

                    elif output.status == AgentLog.LogStatus.SUCCESS:
                        if output.log_type == AgentLog.LogType.THOUGHT:
                            if current_agent_thought_id is None:
                                continue

                            thought_text = output.data.get("thought")
                            self.save_agent_thought(
                                agent_thought_id=current_agent_thought_id,
                                tool_name=None,
                                tool_input=None,
                                thought=thought_text,
                                observation=None,
                                tool_invoke_meta=None,
                                answer=None,
                                messages_ids=[],
                            )
                            self.queue_manager.publish(
                                QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
                                PublishFrom.APPLICATION_MANAGER,
                            )

                        elif output.log_type == AgentLog.LogType.TOOL_CALL:
                            if current_agent_thought_id is None:
                                continue

                            # Tool call finished
                            tool_output = output.data.get("output")
                            # Get meta from strategy output (now properly populated)
                            tool_meta = output.data.get("meta")

                            # Wrap tool_meta with tool_name as key (required by agent_service)
                            if tool_meta and current_tool_name:
                                tool_meta = {current_tool_name: tool_meta}

                            self.save_agent_thought(
                                agent_thought_id=current_agent_thought_id,
                                tool_name=None,
                                tool_input=None,
                                thought=None,
                                observation=tool_output,
                                tool_invoke_meta=tool_meta,
                                answer=None,
                                messages_ids=self._current_message_file_ids,
                            )
                            # Clear message file ids after saving
                            self._current_message_file_ids = []
                            current_tool_name = None

                            self.queue_manager.publish(
                                QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
                                PublishFrom.APPLICATION_MANAGER,
                            )

                        elif output.log_type == AgentLog.LogType.ROUND:
                            if current_agent_thought_id is None:
                                continue

                            # Round finished - save LLM usage and answer
                            llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE)
                            llm_result = output.data.get("llm_result")
                            final_answer = output.data.get("final_answer")

                            self.save_agent_thought(
                                agent_thought_id=current_agent_thought_id,
                                tool_name=None,
                                tool_input=None,
                                thought=llm_result,
                                observation=None,
                                tool_invoke_meta=None,
                                answer=final_answer,
                                messages_ids=[],
                                llm_usage=llm_usage,
                            )
                            self.queue_manager.publish(
                                QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
                                PublishFrom.APPLICATION_MANAGER,
                            )

        except Exception:
            # Re-raise any other exceptions
            raise

        # Process final result
        if isinstance(result, AgentResult):
            final_answer = result.text
            usage = result.usage or LLMUsage.empty_usage()

            # Publish end event
            self.queue_manager.publish(
                QueueMessageEndEvent(
                    llm_result=LLMResult(
                        model=self.model_instance.model,
                        prompt_messages=prompt_messages,
                        message=AssistantPromptMessage(content=final_answer),
                        usage=usage,
                        system_fingerprint="",
                    )
                ),
                PublishFrom.APPLICATION_MANAGER,
            )

    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Initialize system message
        """
        if not prompt_template:
            return prompt_messages or []

        prompt_messages = prompt_messages or []

        if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
            prompt_messages[0] = SystemPromptMessage(content=prompt_template)
            return prompt_messages

        if not prompt_messages:
            return [SystemPromptMessage(content=prompt_template)]

        prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
        return prompt_messages

    def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            # get image detail config
            image_detail_config = (
                self.application_generate_entity.file_upload_config.image_config.detail
                if (
                    self.application_generate_entity.file_upload_config
                    and self.application_generate_entity.file_upload_config.image_config
                )
                else None
            )
            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
            for file in self.files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(
                        file,
                        image_detail_config=image_detail_config,
                    )
                )
            prompt_message_contents.append(TextPromptMessageContent(data=query))

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        As for now, gpt supports both fc and vision at the first iteration.
        We need to remove the image messages from the prompt messages at the first iteration.
        """
        prompt_messages = deepcopy(prompt_messages)

        for prompt_message in prompt_messages:
            if isinstance(prompt_message, UserPromptMessage):
                if isinstance(prompt_message.content, list):
                    prompt_message.content = "\n".join(
                        [
                            content.data
                            if content.type == PromptMessageContentType.TEXT
                            else "[image]"
                            if content.type == PromptMessageContentType.IMAGE
                            else "[file]"
                            for content in prompt_message.content
                        ]
                    )

        return prompt_messages

    def _organize_prompt_messages(self):
        # For ReAct strategy, use the agent prompt template
        if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt:
            prompt_template = self.config.prompt.first_prompt
        else:
            prompt_template = self.app_config.prompt_template.simple_prompt_template or ""

        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
        query_prompt_messages = self._organize_user_query(self.query or "", [])

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
            history_messages=self.history_prompt_messages,
            memory=self.memory,
        ).get_prompt()

        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
        if len(self._current_thoughts) != 0:
            # clear messages after the first iteration
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
        return prompt_messages
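The tool_invoke_hook contract used above is just a callable that receives the tool, its arguments, and its name, and returns the textual observation, any produced message file ids, and the invoke meta. A minimal stand-in for tests might look like the sketch below; it bypasses ToolEngine entirely, and ToolInvokeMeta.empty() is an assumed constructor, not something shown in this diff.

from typing import Any

from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta


def fake_tool_invoke_hook(tool: Tool, tool_args: dict[str, Any], tool_name: str) -> tuple[str, list[str], ToolInvokeMeta]:
    # Canned observation, no files, empty meta; substitute whatever constructor your ToolInvokeMeta exposes.
    return f"{tool_name} called with {tool_args}", [], ToolInvokeMeta.empty()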
@@ -6,7 +6,7 @@ from typing import Union, cast
from sqlalchemy import select

-from core.agent.entities import AgentEntity, AgentToolEntity
+from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -116,9 +116,20 @@ class BaseAgentRunner(AppRunner):
        features = model_schema.features if model_schema and model_schema.features else []
        self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
        self.files = application_generate_entity.files if ModelFeature.VISION in features else []
        self.model_features = features
        self.query: str | None = ""
        self._current_thoughts: list[PromptMessage] = []

    def build_execution_context(self) -> ExecutionContext:
        """Build execution context."""
        return ExecutionContext(
            user_id=self.user_id,
            app_id=self.app_config.app_id,
            conversation_id=self.conversation.id,
            message_id=self.message.id,
            tenant_id=self.tenant_id,
        )

    def _repack_app_generate_entity(
        self, app_generate_entity: AgentChatAppGenerateEntity
    ) -> AgentChatAppGenerateEntity:
@ -1,437 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.workflow.nodes.agent.exc import AgentMaxIterationError
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
_is_first_iteration = True
|
||||
_ignore_observation_providers = ["wenxin"]
|
||||
_historic_prompt_messages: list[PromptMessage]
|
||||
_agent_scratchpad: list[AgentScratchpadUnit]
|
||||
_instruction: str
|
||||
_query: str
|
||||
_prompt_messages_tools: Sequence[PromptMessageTool]
|
||||
|
||||
def run(
|
||||
self,
|
||||
message: Message,
|
||||
query: str,
|
||||
inputs: Mapping[str, str],
|
||||
) -> Generator:
|
||||
"""
|
||||
Run Cot agent application
|
||||
"""
|
||||
|
||||
app_generate_entity = self.application_generate_entity
|
||||
self._repack_app_generate_entity(app_generate_entity)
|
||||
self._init_react_state(query)
|
||||
|
||||
trace_manager = app_generate_entity.trace_manager
|
||||
|
||||
# check model mode
|
||||
if "Observation" not in app_generate_entity.model_conf.stop:
|
||||
if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
|
||||
app_generate_entity.model_conf.stop.append("Observation")
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config.agent
|
||||
|
||||
# init instruction
|
||||
inputs = inputs or {}
|
||||
instruction = app_config.prompt_template.simple_prompt_template or ""
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
self._prompt_messages_tools = prompt_messages_tools
|
||||
|
||||
function_call_state = True
|
||||
llm_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
final_answer = ""
|
||||
prompt_messages: list = [] # Initialize prompt_messages
|
||||
agent_thought_id = "" # Initialize agent_thought_id
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
llm_usage = final_llm_usage_dict["usage"]
|
||||
llm_usage.prompt_tokens += usage.prompt_tokens
|
||||
llm_usage.completion_tokens += usage.completion_tokens
|
||||
llm_usage.total_tokens += usage.total_tokens
|
||||
llm_usage.prompt_price += usage.prompt_price
|
||||
llm_usage.completion_price += usage.completion_price
|
||||
llm_usage.total_price += usage.total_price
|
||||
|
||||
model_instance = self.model_instance
|
||||
|
||||
while function_call_state and iteration_step <= max_iteration_steps:
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = False
|
||||
|
||||
if iteration_step == max_iteration_steps:
|
||||
# the last iteration, remove all tools
|
||||
self._prompt_messages_tools = []
|
||||
|
||||
message_file_ids: list[str] = []
|
||||
|
||||
agent_thought_id = self.create_agent_thought(
|
||||
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
|
||||
)
|
||||
|
||||
if iteration_step > 1:
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
# recalc llm max tokens
|
||||
prompt_messages = self._organize_prompt_messages()
|
||||
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
|
||||
# invoke model
|
||||
chunks = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
tools=[],
|
||||
stop=app_generate_entity.model_conf.stop,
|
||||
stream=True,
|
||||
user=self.user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
usage_dict: dict[str, LLMUsage | None] = {}
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
thought="",
|
||||
action_str="",
|
||||
observation="",
|
||||
action=None,
|
||||
)
|
||||
|
||||
# publish agent thought if it's first iteration
|
||||
if iteration_step == 1:
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
for chunk in react_chunks:
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
action = chunk
|
||||
# detect action
|
||||
assert scratchpad.agent_response is not None
|
||||
scratchpad.agent_response += json.dumps(chunk.model_dump())
|
||||
scratchpad.action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.action = action
|
||||
else:
|
||||
assert scratchpad.agent_response is not None
|
||||
scratchpad.agent_response += chunk
|
||||
assert scratchpad.thought is not None
|
||||
scratchpad.thought += chunk
|
||||
yield LLMResultChunk(
|
||||
model=self.model_config.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint="",
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
|
||||
)
|
||||
|
||||
assert scratchpad.thought is not None
|
||||
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
|
||||
self._agent_scratchpad.append(scratchpad)
|
||||
|
||||
# Check if max iteration is reached and model still wants to call tools
|
||||
if iteration_step == max_iteration_steps and scratchpad.action:
|
||||
if scratchpad.action.action_name.lower() != "final answer":
|
||||
raise AgentMaxIterationError(app_config.agent.max_iteration)
|
||||
|
||||
# get llm usage
|
||||
if "usage" in usage_dict:
|
||||
if usage_dict["usage"] is not None:
|
||||
increase_usage(llm_usage, usage_dict["usage"])
|
||||
else:
|
||||
usage_dict["usage"] = LLMUsage.empty_usage()
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=agent_thought_id,
|
||||
tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
|
||||
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
|
||||
tool_invoke_meta={},
|
||||
thought=scratchpad.thought or "",
|
||||
observation="",
|
||||
answer=scratchpad.agent_response or "",
|
||||
messages_ids=[],
|
||||
llm_usage=usage_dict["usage"],
|
||||
)
|
||||
|
||||
if not scratchpad.is_final():
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
if not scratchpad.action:
|
||||
# failed to extract action, return final answer directly
|
||||
final_answer = ""
|
||||
else:
|
||||
if scratchpad.action.action_name.lower() == "final answer":
|
||||
# action is final answer, return final answer directly
|
||||
try:
|
||||
if isinstance(scratchpad.action.action_input, dict):
|
||||
final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
|
||||
elif isinstance(scratchpad.action.action_input, str):
|
||||
final_answer = scratchpad.action.action_input
|
||||
else:
|
||||
final_answer = f"{scratchpad.action.action_input}"
|
||||
except TypeError:
|
||||
final_answer = f"{scratchpad.action.action_input}"
|
||||
else:
|
||||
function_call_state = True
|
||||
# action is tool call, invoke tool
|
||||
tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
|
||||
action=scratchpad.action,
|
||||
tool_instances=tool_instances,
|
||||
message_file_ids=message_file_ids,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
scratchpad.observation = tool_invoke_response
|
||||
scratchpad.agent_response = tool_invoke_response
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=agent_thought_id,
|
||||
tool_name=scratchpad.action.action_name,
|
||||
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
|
||||
thought=scratchpad.thought or "",
|
||||
observation={scratchpad.action.action_name: tool_invoke_response},
|
||||
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
|
||||
answer=scratchpad.agent_response,
|
||||
messages_ids=message_file_ids,
|
||||
llm_usage=usage_dict["usage"],
|
||||
)
|
||||
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
# update prompt tool message
|
||||
for prompt_tool in self._prompt_messages_tools:
|
||||
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
|
||||
),
|
||||
system_fingerprint="",
|
||||
)
|
||||
|
||||
# save agent thought
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=agent_thought_id,
|
||||
tool_name="",
|
||||
tool_input={},
|
||||
tool_invoke_meta={},
|
||||
thought=final_answer,
|
||||
observation={},
|
||||
answer=final_answer,
|
||||
messages_ids=[],
|
||||
)
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _handle_invoke_action(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
tool_instances: Mapping[str, Tool],
|
||||
message_file_ids: list[str],
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
) -> tuple[str, ToolInvokeMeta]:
|
||||
"""
|
||||
handle invoke action
|
||||
:param action: action
|
||||
:param tool_instances: tool instances
|
||||
:param message_file_ids: message file ids
|
||||
:param trace_manager: trace manager
|
||||
:return: observation, meta
|
||||
"""
|
||||
# action is tool call, invoke tool
|
||||
tool_call_name = action.action_name
|
||||
tool_call_args = action.action_input
|
||||
tool_instance = tool_instances.get(tool_call_name)
|
||||
|
||||
if not tool_instance:
|
||||
answer = f"there is not a tool named {tool_call_name}"
|
||||
return answer, ToolInvokeMeta.error_instance(answer)
|
||||
|
||||
if isinstance(tool_call_args, str):
|
||||
try:
|
||||
tool_call_args = json.loads(tool_call_args)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# invoke tool
|
||||
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_call_args,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
message=self.message,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
agent_tool_callback=self.agent_callback,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# publish files
|
||||
for message_file_id in message_files:
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file_id)
|
||||
|
||||
return tool_invoke_response, tool_invoke_meta
|
||||
|
||||
def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
|
||||
"""
|
||||
convert dict to action
|
||||
"""
|
||||
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
|
||||
|
||||
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
fill in inputs from external data tools
|
||||
"""
|
||||
for key, value in inputs.items():
|
||||
try:
|
||||
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return instruction
|
||||
|
||||
def _init_react_state(self, query):
|
||||
"""
|
||||
init agent scratchpad
|
||||
"""
|
||||
self._query = query
|
||||
self._agent_scratchpad = []
|
||||
self._historic_prompt_messages = self._organize_historic_prompt_messages()
|
||||
|
||||
@abstractmethod
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
organize prompt messages
|
||||
"""
|
||||
|
||||
def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
|
||||
"""
|
||||
format assistant message
|
||||
"""
|
||||
message = ""
|
||||
for scratchpad in agent_scratchpad:
|
||||
if scratchpad.is_final():
|
||||
message += f"Final Answer: {scratchpad.agent_response}"
|
||||
else:
|
||||
message += f"Thought: {scratchpad.thought}\n\n"
|
||||
if scratchpad.action_str:
|
||||
message += f"Action: {scratchpad.action_str}\n\n"
|
||||
if scratchpad.observation:
|
||||
message += f"Observation: {scratchpad.observation}\n\n"
|
||||
|
||||
return message
|
||||
|
||||
def _organize_historic_prompt_messages(
|
||||
self, current_session_messages: list[PromptMessage] | None = None
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
organize historic prompt messages
|
||||
"""
|
||||
result: list[PromptMessage] = []
|
||||
scratchpads: list[AgentScratchpadUnit] = []
|
||||
current_scratchpad: AgentScratchpadUnit | None = None
|
||||
|
||||
for message in self.history_prompt_messages:
|
||||
if isinstance(message, AssistantPromptMessage):
|
||||
if not current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content or "I am thinking about how to help you",
|
||||
action_str="",
|
||||
action=None,
|
||||
observation=None,
|
||||
)
|
||||
scratchpads.append(current_scratchpad)
|
||||
if message.tool_calls:
|
||||
try:
|
||||
current_scratchpad.action = AgentScratchpadUnit.Action(
|
||||
action_name=message.tool_calls[0].function.name,
|
||||
action_input=json.loads(message.tool_calls[0].function.arguments),
|
||||
)
|
||||
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
|
||||
except Exception:
|
||||
logger.exception("Failed to parse tool call from assistant message")
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
if current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad.observation = message.content
|
||||
else:
|
||||
raise NotImplementedError("expected str type")
|
||||
elif isinstance(message, UserPromptMessage):
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
scratchpads = []
|
||||
current_scratchpad = None
|
||||
|
||||
result.append(message)
|
||||
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
|
||||
historic_prompts = AgentHistoryPromptTransform(
|
||||
model_config=self.model_config,
|
||||
prompt_messages=current_session_messages or [],
|
||||
history_messages=result,
|
||||
memory=self.memory,
|
||||
).get_prompt()
|
||||
return historic_prompts
|
||||
@ -1,118 +0,0 @@
|
||||
import json
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.file import file_manager
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
|
||||
class CotChatAgentRunner(CotAgentRunner):
|
||||
def _organize_system_prompt(self) -> SystemPromptMessage:
|
||||
"""
|
||||
Organize system prompt
|
||||
"""
|
||||
assert self.app_config.agent
|
||||
assert self.app_config.agent.prompt
|
||||
|
||||
prompt_entity = self.app_config.agent.prompt
|
||||
if not prompt_entity:
|
||||
raise ValueError("Agent prompt configuration is not set")
|
||||
first_prompt = prompt_entity.first_prompt
|
||||
|
||||
system_prompt = (
|
||||
first_prompt.replace("{{instruction}}", self._instruction)
|
||||
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
|
||||
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
|
||||
)
|
||||
|
||||
return SystemPromptMessage(content=system_prompt)
|
||||
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
# get image detail config
|
||||
image_detail_config = (
|
||||
self.application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
self.application_generate_entity.file_upload_config
|
||||
and self.application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in self.files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize
|
||||
"""
|
||||
# organize system prompt
|
||||
system_message = self._organize_system_prompt()
|
||||
|
||||
# organize current assistant messages
|
||||
agent_scratchpad = self._agent_scratchpad
|
||||
if not agent_scratchpad:
|
||||
assistant_messages = []
|
||||
else:
|
||||
assistant_message = AssistantPromptMessage(content="")
|
||||
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
|
||||
for unit in agent_scratchpad:
|
||||
if unit.is_final():
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Final Answer: {unit.agent_response}"
|
||||
else:
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Thought: {unit.thought}\n\n"
|
||||
if unit.action_str:
|
||||
assistant_message.content += f"Action: {unit.action_str}\n\n"
|
||||
if unit.observation:
|
||||
assistant_message.content += f"Observation: {unit.observation}\n\n"
|
||||
|
||||
assistant_messages = [assistant_message]
|
||||
|
||||
# query messages
|
||||
query_messages = self._organize_user_query(self._query, [])
|
||||
|
||||
if assistant_messages:
|
||||
# organize historic prompt messages
|
||||
historic_messages = self._organize_historic_prompt_messages(
|
||||
[system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
|
||||
)
|
||||
messages = [
|
||||
system_message,
|
||||
*historic_messages,
|
||||
*query_messages,
|
||||
*assistant_messages,
|
||||
UserPromptMessage(content="continue"),
|
||||
]
|
||||
else:
|
||||
# organize historic prompt messages
|
||||
historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
|
||||
messages = [system_message, *historic_messages, *query_messages]
|
||||
|
||||
# join all messages
|
||||
return messages
|
||||
@ -1,87 +0,0 @@
|
||||
import json
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
|
||||
class CotCompletionAgentRunner(CotAgentRunner):
|
||||
def _organize_instruction_prompt(self) -> str:
|
||||
"""
|
||||
Organize instruction prompt
|
||||
"""
|
||||
if self.app_config.agent is None:
|
||||
raise ValueError("Agent configuration is not set")
|
||||
prompt_entity = self.app_config.agent.prompt
|
||||
if prompt_entity is None:
|
||||
raise ValueError("prompt entity is not set")
|
||||
first_prompt = prompt_entity.first_prompt
|
||||
|
||||
system_prompt = (
|
||||
first_prompt.replace("{{instruction}}", self._instruction)
|
||||
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
|
||||
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
|
||||
)
|
||||
|
||||
return system_prompt
|
||||
|
||||
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] | None = None) -> str:
|
||||
"""
|
||||
Organize historic prompt
|
||||
"""
|
||||
historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
|
||||
historic_prompt = ""
|
||||
|
||||
for message in historic_prompt_messages:
|
||||
if isinstance(message, UserPromptMessage):
|
||||
historic_prompt += f"Question: {message.content}\n\n"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
if isinstance(message.content, str):
|
||||
historic_prompt += message.content + "\n\n"
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if not isinstance(content, TextPromptMessageContent):
|
||||
continue
|
||||
historic_prompt += content.data
|
||||
|
||||
return historic_prompt
|
||||
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
"""
|
||||
# organize system prompt
|
||||
system_prompt = self._organize_instruction_prompt()
|
||||
|
||||
# organize historic prompt messages
|
||||
historic_prompt = self._organize_historic_prompt()
|
||||
|
||||
# organize current assistant messages
|
||||
agent_scratchpad = self._agent_scratchpad
|
||||
assistant_prompt = ""
|
||||
for unit in agent_scratchpad or []:
|
||||
if unit.is_final():
|
||||
assistant_prompt += f"Final Answer: {unit.agent_response}"
|
||||
else:
|
||||
assistant_prompt += f"Thought: {unit.thought}\n\n"
|
||||
if unit.action_str:
|
||||
assistant_prompt += f"Action: {unit.action_str}\n\n"
|
||||
if unit.observation:
|
||||
assistant_prompt += f"Observation: {unit.observation}\n\n"
|
||||
|
||||
# query messages
|
||||
query_prompt = f"Question: {self._query}"
|
||||
|
||||
# join all messages
|
||||
prompt = (
|
||||
system_prompt.replace("{{historic_messages}}", historic_prompt)
|
||||
.replace("{{agent_scratchpad}}", assistant_prompt)
|
||||
.replace("{{query}}", query_prompt)
|
||||
)
|
||||
|
||||
return [UserPromptMessage(content=prompt)]
|
||||
@@ -1,3 +1,5 @@
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Union

@@ -92,3 +94,96 @@ class AgentInvokeMessage(ToolInvokeMessage):
    """

    pass


class ExecutionContext(BaseModel):
    """Execution context containing trace and audit information.

    This context carries all the IDs and metadata that are not part of
    the core business logic but needed for tracing, auditing, and
    correlation purposes.
    """

    user_id: str | None = None
    app_id: str | None = None
    conversation_id: str | None = None
    message_id: str | None = None
    tenant_id: str | None = None

    @classmethod
    def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext":
        """Create a minimal context with only essential fields."""
        return cls(user_id=user_id)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for passing to legacy code."""
        return {
            "user_id": self.user_id,
            "app_id": self.app_id,
            "conversation_id": self.conversation_id,
            "message_id": self.message_id,
            "tenant_id": self.tenant_id,
        }

    def with_updates(self, **kwargs) -> "ExecutionContext":
        """Create a new context with updated fields."""
        data = self.to_dict()
        data.update(kwargs)

        return ExecutionContext(
            user_id=data.get("user_id"),
            app_id=data.get("app_id"),
            conversation_id=data.get("conversation_id"),
            message_id=data.get("message_id"),
            tenant_id=data.get("tenant_id"),
        )


class AgentLog(BaseModel):
    """
    Agent Log.
    """

    class LogType(StrEnum):
        """Type of agent log entry."""

        ROUND = "round"  # A complete iteration round
        THOUGHT = "thought"  # LLM thinking/reasoning
        TOOL_CALL = "tool_call"  # Tool invocation

    class LogMetadata(StrEnum):
        STARTED_AT = "started_at"
        FINISHED_AT = "finished_at"
        ELAPSED_TIME = "elapsed_time"
        TOTAL_PRICE = "total_price"
        TOTAL_TOKENS = "total_tokens"
        PROVIDER = "provider"
        CURRENCY = "currency"
        LLM_USAGE = "llm_usage"
        ICON = "icon"
        ICON_DARK = "icon_dark"

    class LogStatus(StrEnum):
        START = "start"
        ERROR = "error"
        SUCCESS = "success"

    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="The id of the log")
    label: str = Field(..., description="The label of the log")
    log_type: LogType = Field(..., description="The type of the log")
    parent_id: str | None = Field(default=None, description="Leave empty for root log")
    error: str | None = Field(default=None, description="The error message")
    status: LogStatus = Field(..., description="The status of the log")
    data: Mapping[str, Any] = Field(..., description="Detailed log data")
    metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log")


class AgentResult(BaseModel):
    """
    Agent execution result.
    """

    text: str = Field(default="", description="The generated text")
    files: list[Any] = Field(default_factory=list, description="Files produced during execution")
    usage: Any | None = Field(default=None, description="LLM usage statistics")
    finish_reason: str | None = Field(default=None, description="Reason for completion")
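A small usage sketch of the new ExecutionContext; the id values are placeholders, not taken from the diff.

# Build a minimal context, then derive a per-message variant without mutating the original.
ctx = ExecutionContext.create_minimal(user_id="user-123")
msg_ctx = ctx.with_updates(app_id="app-456", message_id="msg-789")
assert ctx.message_id is None and msg_ctx.message_id == "msg-789"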
@ -1,468 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.file import file_manager
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.workflow.nodes.agent.exc import AgentMaxIterationError
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Run FunctionCall agent application
|
||||
"""
|
||||
self.query = query
|
||||
app_generate_entity = self.application_generate_entity
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config is not None, "app_config is required"
|
||||
assert app_config.agent is not None, "app_config.agent is required"
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
|
||||
assert app_config.agent
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
llm_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
final_answer = ""
|
||||
prompt_messages: list = [] # Initialize prompt_messages
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = app_generate_entity.trace_manager
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
llm_usage = final_llm_usage_dict["usage"]
|
||||
llm_usage.prompt_tokens += usage.prompt_tokens
|
||||
llm_usage.completion_tokens += usage.completion_tokens
|
||||
llm_usage.total_tokens += usage.total_tokens
|
||||
llm_usage.prompt_price += usage.prompt_price
|
||||
llm_usage.completion_price += usage.completion_price
|
||||
llm_usage.total_price += usage.total_price
|
||||
|
||||
model_instance = self.model_instance
|
||||
|
||||
while function_call_state and iteration_step <= max_iteration_steps:
|
||||
function_call_state = False
|
||||
|
||||
if iteration_step == max_iteration_steps:
|
||||
# the last iteration, remove all tools
|
||||
prompt_messages_tools = []
|
||||
|
||||
message_file_ids: list[str] = []
|
||||
agent_thought_id = self.create_agent_thought(
|
||||
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
|
||||
)
|
||||
|
||||
# recalc llm max tokens
|
||||
prompt_messages = self._organize_prompt_messages()
|
||||
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
|
||||
# invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
tools=prompt_messages_tools,
|
||||
stop=app_generate_entity.model_conf.stop,
|
||||
stream=self.stream_tool_call,
|
||||
user=self.user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
|
||||
# save full response
|
||||
response = ""
|
||||
|
||||
# save tool call names and inputs
|
||||
tool_call_names = ""
|
||||
tool_call_inputs = ""
|
||||
|
||||
current_llm_usage = None
|
||||
|
||||
if isinstance(chunks, Generator):
|
||||
is_first_chunk = True
|
||||
for chunk in chunks:
|
||||
if is_first_chunk:
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
is_first_chunk = False
|
||||
# check if there is any tool call
|
||||
if self.check_tool_calls(chunk):
|
||||
function_call_state = True
|
||||
tool_calls.extend(self.extract_tool_calls(chunk) or [])
|
||||
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
|
||||
try:
|
||||
tool_call_inputs = json.dumps(
|
||||
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
|
||||
)
|
||||
except TypeError:
|
||||
# fallback: retry with json's default ASCII escaping if the non-ASCII dump fails
|
||||
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
|
||||
|
||||
if chunk.delta.message and chunk.delta.message.content:
|
||||
if isinstance(chunk.delta.message.content, list):
|
||||
for content in chunk.delta.message.content:
|
||||
response += content.data
|
||||
else:
|
||||
response += str(chunk.delta.message.content)
|
||||
|
||||
if chunk.delta.usage:
|
||||
increase_usage(llm_usage, chunk.delta.usage)
|
||||
current_llm_usage = chunk.delta.usage
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
result = chunks
|
||||
# check if there is any tool call
|
||||
if self.check_blocking_tool_calls(result):
|
||||
function_call_state = True
|
||||
tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
|
||||
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
|
||||
try:
|
||||
tool_call_inputs = json.dumps(
|
||||
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
|
||||
)
|
||||
except TypeError:
|
||||
# fallback: retry with json's default ASCII escaping if the non-ASCII dump fails
|
||||
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
|
||||
|
||||
if result.usage:
|
||||
increase_usage(llm_usage, result.usage)
|
||||
current_llm_usage = result.usage
|
||||
|
||||
if result.message and result.message.content:
|
||||
if isinstance(result.message.content, list):
|
||||
for content in result.message.content:
|
||||
response += content.data
|
||||
else:
|
||||
response += str(result.message.content)
|
||||
|
||||
if not result.message.content:
|
||||
result.message.content = ""
|
||||
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model_instance.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
system_fingerprint=result.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=result.message,
|
||||
usage=result.usage,
|
||||
),
|
||||
)
|
||||
|
||||
assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
|
||||
if tool_calls:
|
||||
assistant_message.tool_calls = [
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tool_call[0],
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
|
||||
),
|
||||
)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
|
||||
self._current_thoughts.append(assistant_message)
|
||||
|
||||
# save thought
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=agent_thought_id,
|
||||
tool_name=tool_call_names,
|
||||
tool_input=tool_call_inputs,
|
||||
thought=response,
|
||||
tool_invoke_meta=None,
|
||||
observation=None,
|
||||
answer=response,
|
||||
messages_ids=[],
|
||||
llm_usage=current_llm_usage,
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
final_answer += response + "\n"
|
||||
|
||||
# Check if max iteration is reached and model still wants to call tools
|
||||
if iteration_step == max_iteration_steps and tool_calls:
|
||||
raise AgentMaxIterationError(app_config.agent.max_iteration)
|
||||
|
||||
# call tools
|
||||
tool_responses = []
|
||||
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
|
||||
tool_instance = tool_instances.get(tool_call_name)
|
||||
if not tool_instance:
|
||||
tool_response = {
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": f"there is not a tool named {tool_call_name}",
|
||||
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
|
||||
}
|
||||
else:
|
||||
# invoke tool
|
||||
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_call_args,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
message=self.message,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
agent_tool_callback=self.agent_callback,
|
||||
trace_manager=trace_manager,
|
||||
app_id=self.application_generate_entity.app_config.app_id,
|
||||
message_id=self.message.id,
|
||||
conversation_id=self.conversation.id,
|
||||
)
|
||||
# publish files
|
||||
for message_file_id in message_files:
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file_id)
|
||||
|
||||
tool_response = {
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": tool_invoke_response,
|
||||
"meta": tool_invoke_meta.to_dict(),
|
||||
}
|
||||
|
||||
tool_responses.append(tool_response)
|
||||
if tool_response["tool_response"] is not None:
|
||||
self._current_thoughts.append(
|
||||
ToolPromptMessage(
|
||||
content=str(tool_response["tool_response"]),
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_call_name,
|
||||
)
|
||||
)
|
||||
|
||||
if len(tool_responses) > 0:
|
||||
# save agent thought
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=agent_thought_id,
|
||||
tool_name="",
|
||||
tool_input="",
|
||||
thought="",
|
||||
tool_invoke_meta={
|
||||
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
|
||||
},
|
||||
observation={
|
||||
tool_response["tool_call_name"]: tool_response["tool_response"]
|
||||
for tool_response in tool_responses
|
||||
},
|
||||
answer="",
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
# update prompt tool
|
||||
for prompt_tool in prompt_messages_tools:
|
||||
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
|
||||
"""
|
||||
Check if there is any tool call in llm result chunk
|
||||
"""
|
||||
if llm_result_chunk.delta.message.tool_calls:
|
||||
return True
|
||||
return False
|
||||
|
||||
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
|
||||
"""
|
||||
Check if there is any blocking tool call in llm result
|
||||
"""
|
||||
if llm_result.message.tool_calls:
|
||||
return True
|
||||
return False
|
||||
|
||||
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""
|
||||
Extract tool calls from llm result chunk
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
|
||||
"""
|
||||
tool_calls = []
|
||||
for prompt_message in llm_result_chunk.delta.message.tool_calls:
|
||||
args = {}
|
||||
if prompt_message.function.arguments != "":
|
||||
args = json.loads(prompt_message.function.arguments)
|
||||
|
||||
tool_calls.append(
|
||||
(
|
||||
prompt_message.id,
|
||||
prompt_message.function.name,
|
||||
args,
|
||||
)
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""
|
||||
Extract blocking tool calls from llm result
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
|
||||
"""
|
||||
tool_calls = []
|
||||
for prompt_message in llm_result.message.tool_calls:
|
||||
args = {}
|
||||
if prompt_message.function.arguments != "":
|
||||
args = json.loads(prompt_message.function.arguments)
|
||||
|
||||
tool_calls.append(
|
||||
(
|
||||
prompt_message.id,
|
||||
prompt_message.function.name,
|
||||
args,
|
||||
)
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
"""
|
||||
if not prompt_messages and prompt_template:
|
||||
return [
|
||||
SystemPromptMessage(content=prompt_template),
|
||||
]
|
||||
|
||||
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
|
||||
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
||||
|
||||
return prompt_messages or []
|
||||
|
||||
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
# get image detail config
|
||||
image_detail_config = (
|
||||
self.application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
self.application_generate_entity.file_upload_config
|
||||
and self.application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in self.files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
For now, GPT supports both function calling and vision at the first iteration,
so we remove the image messages from the prompt messages after the first iteration.
|
||||
"""
|
||||
prompt_messages = deepcopy(prompt_messages)
|
||||
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message.content = "\n".join(
|
||||
[
|
||||
content.data
|
||||
if content.type == PromptMessageContentType.TEXT
|
||||
else "[image]"
|
||||
if content.type == PromptMessageContentType.IMAGE
|
||||
else "[file]"
|
||||
for content in prompt_message.content
|
||||
]
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_prompt_messages(self):
|
||||
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
|
||||
query_prompt_messages = self._organize_user_query(self.query or "", [])
|
||||
|
||||
self.history_prompt_messages = AgentHistoryPromptTransform(
|
||||
model_config=self.model_config,
|
||||
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
|
||||
history_messages=self.history_prompt_messages,
|
||||
memory=self.memory,
|
||||
).get_prompt()
|
||||
|
||||
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
|
||||
if len(self._current_thoughts) != 0:
|
||||
# clear messages after the first iteration
|
||||
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
|
||||
return prompt_messages
|
||||
api/core/agent/patterns/README.md (new file, 55 lines)
@@ -0,0 +1,55 @@
# Agent Patterns

A unified agent pattern module that powers both Agent V2 workflow nodes and agent applications. Strategies share a common execution contract while adapting to model capabilities and tool availability.

## Overview

The module applies a strategy pattern around LLM/tool orchestration. `StrategyFactory` auto-selects the best implementation based on model features or an explicit agent strategy, and each strategy streams logs and usage consistently.

## Key Features

- **Dual strategies**
  - `FunctionCallStrategy`: uses native LLM function/tool calling when the model exposes `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL`.
  - `ReActStrategy`: ReAct (reasoning + acting) flow driven by `CotAgentOutputParser`, used when function calling is unavailable or explicitly requested.
- **Explicit or auto selection**
  - `StrategyFactory.create_strategy` prefers an explicit `AgentEntity.Strategy` (FUNCTION_CALLING or CHAIN_OF_THOUGHT).
  - Otherwise it falls back to function calling when tool-call features exist, or ReAct when they do not.
- **Unified execution contract**
  - `AgentPattern.run` yields streaming `AgentLog` entries and `LLMResultChunk` data, returning an `AgentResult` with text, files, usage, and `finish_reason`.
  - Iterations are configurable and hard-capped at 99 rounds; the last round forces a final answer by withholding tools.
- **Tool handling and hooks**
  - Tools convert to `PromptMessageTool` objects before invocation.
  - Optional `tool_invoke_hook` lets callers override tool execution (e.g., agent apps) while workflow runs use `ToolEngine.generic_invoke`.
  - Tool outputs support text, links, JSON, variables, blobs, retriever resources, and file attachments; `target=="self"` files are reloaded into model context, others are returned as outputs.
- **File-aware arguments**
  - Tool args accept `[File: <id>]` or `[Files: <id1, id2>]` placeholders that resolve to `File` objects before invocation, enabling models to reference uploaded files safely (see the sketch after this list).
- **ReAct prompt shaping**
  - System prompts replace `{{instruction}}`, `{{tools}}`, and `{{tool_names}}` placeholders.
  - Adds `Observation` to stop sequences and appends scratchpad text so the model sees prior Thought/Action/Observation history.
- **Observability and accounting**
  - Standardized `AgentLog` entries for rounds, model thoughts, and tool calls, including usage aggregation (`LLMUsage`) across streaming and non-streaming paths.
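
The file-aware argument convention is easiest to see with a concrete payload. This is a minimal illustration only; the parameter names and file IDs below are made up, and resolution simply follows the `[File: ...]` / `[Files: ...]` placeholders described above.

```python
# Hypothetical arguments as a model might emit them for a tool call. Before the tool
# is invoked, the strategy walks the payload and swaps each placeholder for the
# matching File object(s) attached to the run; unknown IDs are left as plain strings.
tool_args = {
    "query": "Summarize the attached report",
    "document": "[File: 3f9c1a2e]",                # resolves to a single File
    "attachments": "[Files: 3f9c1a2e, 8d2b44c0]",  # resolves to a list of Files
}
```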

## Architecture

```
agent/patterns/
├── base.py              # Shared utilities: logging, usage, tool invocation, file handling
├── function_call.py     # Native function-calling loop with tool execution
├── react.py             # ReAct loop with CoT parsing and scratchpad wiring
└── strategy_factory.py  # Strategy selection by model features or explicit override
```

## Usage

- For auto-selection:
  - Call `StrategyFactory.create_strategy(model_features, model_instance, context, tools, files, ...)` and run the returned strategy with prompt messages and model parameters.
- For explicit behavior:
  - Pass `agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING` to force native calls (falls back to ReAct if unsupported), or `CHAIN_OF_THOUGHT` to force ReAct.
- Both strategies stream chunks and logs; collect the generator output until it returns an `AgentResult` (see the sketch below).
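
A rough end-to-end sketch of the usage above. It assumes `model_features`, `model_instance`, `context`, `tools`, `files`, `prompt_messages`, and `model_parameters` already exist, and `handle_log` / `emit_chunk` are hypothetical stand-ins for whatever the caller does with logs and chunks; the exact factory keyword names may differ.

```python
from core.agent.entities import AgentLog, AgentResult
from core.agent.patterns import StrategyFactory
from core.model_runtime.entities import LLMResultChunk

# Assumed to be constructed elsewhere: model_features, model_instance, context,
# tools, files, prompt_messages, model_parameters.
strategy = StrategyFactory.create_strategy(
    model_features=model_features,
    model_instance=model_instance,
    context=context,
    tools=tools,
    files=files,
)

generator = strategy.run(
    prompt_messages=prompt_messages,
    model_parameters=model_parameters,
    stream=True,
)

result: AgentResult | None = None
while True:
    try:
        item = next(generator)
    except StopIteration as stop:
        result = stop.value  # run() returns an AgentResult when the loop finishes
        break
    if isinstance(item, AgentLog):
        handle_log(item)   # stand-in: forward structured logs to the caller
    elif isinstance(item, LLMResultChunk):
        emit_chunk(item)   # stand-in: stream text chunks to the client
```

Driving the generator with `next()` and catching `StopIteration` is only one way to read the return value; delegating with `yield from` inside another generator works equally well.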

## Integration Points

- **Model runtime**: delegates to `ModelInstance.invoke_llm` for both streaming and non-streaming calls.
- **Tool system**: defaults to `ToolEngine.generic_invoke`, with `tool_invoke_hook` for custom callers.
- **Files**: flow through `File` objects for tool inputs/outputs and model-context attachments.
- **Execution context**: `ExecutionContext` fields (user/app/conversation/message) propagate to tool invocations and logging.

api/core/agent/patterns/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
"""Agent patterns module.

This module provides different strategies for agent execution:
- FunctionCallStrategy: Uses native function/tool calling
- ReActStrategy: Uses ReAct (Reasoning + Acting) approach
- StrategyFactory: Factory for creating strategies based on model features
"""

from .base import AgentPattern
from .function_call import FunctionCallStrategy
from .react import ReActStrategy
from .strategy_factory import StrategyFactory

__all__ = [
    "AgentPattern",
    "FunctionCallStrategy",
    "ReActStrategy",
    "StrategyFactory",
]

api/core/agent/patterns/base.py (new file, 474 lines)
@@ -0,0 +1,474 @@
|
||||
"""Base class for agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
|
||||
from core.file import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
# Type alias for tool invoke hook
|
||||
# Returns: (response_content, message_file_ids, tool_invoke_meta)
|
||||
ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]]
|
||||
|
||||
|
||||
class AgentPattern(ABC):
|
||||
"""Base class for agent execution strategies."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] | None = None,
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
):
|
||||
"""Initialize the agent strategy."""
|
||||
self.model_instance = model_instance
|
||||
self.tools = tools
|
||||
self.context = context
|
||||
self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.files: list[File] = files if files is not None else []  # avoid sharing a mutable default across instances
|
||||
self.tool_invoke_hook = tool_invoke_hook
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the agent strategy."""
|
||||
pass
|
||||
|
||||
def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None:
|
||||
"""Accumulate LLM usage statistics."""
|
||||
if not total_usage.get("usage"):
|
||||
# Create a copy to avoid modifying the original
|
||||
total_usage["usage"] = LLMUsage(
|
||||
prompt_tokens=delta_usage.prompt_tokens,
|
||||
prompt_unit_price=delta_usage.prompt_unit_price,
|
||||
prompt_price_unit=delta_usage.prompt_price_unit,
|
||||
prompt_price=delta_usage.prompt_price,
|
||||
completion_tokens=delta_usage.completion_tokens,
|
||||
completion_unit_price=delta_usage.completion_unit_price,
|
||||
completion_price_unit=delta_usage.completion_price_unit,
|
||||
completion_price=delta_usage.completion_price,
|
||||
total_tokens=delta_usage.total_tokens,
|
||||
total_price=delta_usage.total_price,
|
||||
currency=delta_usage.currency,
|
||||
latency=delta_usage.latency,
|
||||
)
|
||||
else:
|
||||
current: LLMUsage = total_usage["usage"]
|
||||
current.prompt_tokens += delta_usage.prompt_tokens
|
||||
current.completion_tokens += delta_usage.completion_tokens
|
||||
current.total_tokens += delta_usage.total_tokens
|
||||
current.prompt_price += delta_usage.prompt_price
|
||||
current.completion_price += delta_usage.completion_price
|
||||
current.total_price += delta_usage.total_price
|
||||
|
||||
def _extract_content(self, content: Any) -> str:
|
||||
"""Extract text content from message content."""
|
||||
if isinstance(content, list):
|
||||
# Content items are PromptMessageContentUnionTypes
|
||||
text_parts = []
|
||||
for c in content:
|
||||
# Check if it's a TextPromptMessageContent (which has data attribute)
|
||||
if isinstance(c, TextPromptMessageContent):
|
||||
text_parts.append(c.data)
|
||||
return "".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
def _has_tool_calls(self, chunk: LLMResultChunk) -> bool:
|
||||
"""Check if chunk contains tool calls."""
|
||||
# LLMResultChunk always has delta attribute
|
||||
return bool(chunk.delta.message and chunk.delta.message.tool_calls)
|
||||
|
||||
def _has_tool_calls_result(self, result: LLMResult) -> bool:
|
||||
"""Check if result contains tool calls (non-streaming)."""
|
||||
# LLMResult always has message attribute
|
||||
return bool(result.message and result.message.tool_calls)
|
||||
|
||||
def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from streaming chunk."""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
if chunk.delta.message and chunk.delta.message.tool_calls:
|
||||
for tool_call in chunk.delta.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from non-streaming result."""
|
||||
tool_calls = []
|
||||
if result.message and result.message.tool_calls:
|
||||
for tool_call in result.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_text_from_message(self, message: PromptMessage) -> str:
|
||||
"""Extract text content from a prompt message."""
|
||||
# PromptMessage always has content attribute
|
||||
content = message.content
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
# Extract text from content list
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
text_parts.append(item.data)
|
||||
return " ".join(text_parts)
|
||||
return ""
|
||||
|
||||
def _get_tool_metadata(self, tool_instance: Tool) -> dict[AgentLog.LogMetadata, Any]:
|
||||
"""Get metadata for a tool including provider and icon info."""
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
metadata: dict[AgentLog.LogMetadata, Any] = {}
|
||||
if tool_instance.entity and tool_instance.entity.identity:
|
||||
identity = tool_instance.entity.identity
|
||||
if identity.provider:
|
||||
metadata[AgentLog.LogMetadata.PROVIDER] = identity.provider
|
||||
|
||||
# Get icon using ToolManager for proper URL generation
|
||||
tenant_id = self.context.tenant_id
|
||||
if tenant_id and identity.provider:
|
||||
try:
|
||||
provider_type = tool_instance.tool_provider_type()
|
||||
icon = ToolManager.get_tool_icon(tenant_id, provider_type, identity.provider)
|
||||
if isinstance(icon, str):
|
||||
metadata[AgentLog.LogMetadata.ICON] = icon
|
||||
elif isinstance(icon, dict):
|
||||
# Handle icon dict with background/content or light/dark variants
|
||||
metadata[AgentLog.LogMetadata.ICON] = icon
|
||||
except Exception:
|
||||
# Fallback to identity.icon if ToolManager fails
|
||||
if identity.icon:
|
||||
metadata[AgentLog.LogMetadata.ICON] = identity.icon
|
||||
elif identity.icon:
|
||||
metadata[AgentLog.LogMetadata.ICON] = identity.icon
|
||||
return metadata
|
||||
|
||||
def _create_log(
|
||||
self,
|
||||
label: str,
|
||||
log_type: AgentLog.LogType,
|
||||
status: AgentLog.LogStatus,
|
||||
data: dict[str, Any] | None = None,
|
||||
parent_id: str | None = None,
|
||||
extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None,
|
||||
) -> AgentLog:
|
||||
"""Create a new AgentLog with standard metadata."""
|
||||
metadata: dict[AgentLog.LogMetadata, Any] = {
|
||||
AgentLog.LogMetadata.STARTED_AT: time.perf_counter(),
|
||||
}
|
||||
if extra_metadata:
|
||||
metadata.update(extra_metadata)
|
||||
|
||||
return AgentLog(
|
||||
label=label,
|
||||
log_type=log_type,
|
||||
status=status,
|
||||
data=data or {},
|
||||
parent_id=parent_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _finish_log(
|
||||
self,
|
||||
log: AgentLog,
|
||||
data: dict[str, Any] | None = None,
|
||||
usage: LLMUsage | None = None,
|
||||
) -> AgentLog:
|
||||
"""Finish an AgentLog by updating its status and metadata."""
|
||||
log.status = AgentLog.LogStatus.SUCCESS
|
||||
|
||||
if data is not None:
|
||||
log.data = data
|
||||
|
||||
# Calculate elapsed time
|
||||
started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter())
|
||||
finished_at = time.perf_counter()
|
||||
|
||||
# Update metadata
|
||||
log.metadata = {
|
||||
**log.metadata,
|
||||
AgentLog.LogMetadata.FINISHED_AT: finished_at,
|
||||
# Calculate elapsed time in seconds
|
||||
AgentLog.LogMetadata.ELAPSED_TIME: round(finished_at - started_at, 4),
|
||||
}
|
||||
|
||||
# Add usage information if provided
|
||||
if usage:
|
||||
log.metadata.update(
|
||||
{
|
||||
AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price,
|
||||
AgentLog.LogMetadata.CURRENCY: usage.currency,
|
||||
AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens,
|
||||
AgentLog.LogMetadata.LLM_USAGE: usage,
|
||||
}
|
||||
)
|
||||
|
||||
return log
|
||||
|
||||
def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Replace file references in tool arguments with actual File objects.
|
||||
|
||||
Args:
|
||||
tool_args: Dictionary of tool arguments
|
||||
|
||||
Returns:
|
||||
Updated tool arguments with file references replaced
|
||||
"""
|
||||
# Process each argument in the dictionary
|
||||
processed_args: dict[str, Any] = {}
|
||||
for key, value in tool_args.items():
|
||||
processed_args[key] = self._process_file_reference(value)
|
||||
return processed_args
|
||||
|
||||
def _process_file_reference(self, data: Any) -> Any:
|
||||
"""
|
||||
Recursively process data to replace file references.
|
||||
Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...].
|
||||
|
||||
Args:
|
||||
data: The data to process (can be dict, list, str, or other types)
|
||||
|
||||
Returns:
|
||||
Processed data with file references replaced
|
||||
"""
|
||||
single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$")
|
||||
multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$")
|
||||
|
||||
if isinstance(data, dict):
|
||||
# Process dictionary recursively
|
||||
return {key: self._process_file_reference(value) for key, value in data.items()}
|
||||
elif isinstance(data, list):
|
||||
# Process list recursively
|
||||
return [self._process_file_reference(item) for item in data]
|
||||
elif isinstance(data, str):
|
||||
# Check for single file pattern [File: file_id]
|
||||
single_match = single_file_pattern.match(data.strip())
|
||||
if single_match:
|
||||
file_id = single_match.group(1).strip()
|
||||
# Find the file in self.files
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
return file
|
||||
# If file not found, return original value
|
||||
return data
|
||||
|
||||
# Check for multiple files pattern [Files: file_id1, file_id2, ...]
|
||||
multiple_match = multiple_files_pattern.match(data.strip())
|
||||
if multiple_match:
|
||||
file_ids_str = multiple_match.group(1).strip()
|
||||
# Split by comma and strip whitespace
|
||||
file_ids = [fid.strip() for fid in file_ids_str.split(",")]
|
||||
|
||||
# Find all matching files
|
||||
matched_files: list[File] = []
|
||||
for file_id in file_ids:
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
matched_files.append(file)
|
||||
break
|
||||
|
||||
# Return list of files if any were found, otherwise return original
|
||||
return matched_files or data
|
||||
|
||||
return data
|
||||
else:
|
||||
# Return other types as-is
|
||||
return data
|
||||
|
||||
def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk:
|
||||
"""Create a text chunk for streaming."""
|
||||
return LLMResultChunk(
|
||||
model=self.model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
usage=None,
|
||||
),
|
||||
system_fingerprint="",
|
||||
)
|
||||
|
||||
def _invoke_tool(
|
||||
self,
|
||||
tool_instance: Tool,
|
||||
tool_args: dict[str, Any],
|
||||
tool_name: str,
|
||||
) -> tuple[str, list[File], ToolInvokeMeta | None]:
|
||||
"""
|
||||
Invoke a tool and collect its response.
|
||||
|
||||
Args:
|
||||
tool_instance: The tool instance to invoke
|
||||
tool_args: Tool arguments
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
Tuple of (response_content, tool_files, tool_invoke_meta)
|
||||
"""
|
||||
# Process tool_args to replace file references with actual File objects
|
||||
tool_args = self._replace_file_references(tool_args)
|
||||
|
||||
# If a tool invoke hook is set, use it instead of generic_invoke
|
||||
if self.tool_invoke_hook:
|
||||
response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name)
|
||||
# Note: message_file_ids are stored in DB, we don't convert them to File objects here
|
||||
# The caller (AgentAppRunner) handles file publishing
|
||||
return response_content, [], tool_invoke_meta
|
||||
|
||||
# Default: use generic_invoke for workflow scenarios
|
||||
# Import here to avoid circular import
|
||||
from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine
|
||||
|
||||
tool_response = ToolEngine().generic_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_args,
|
||||
user_id=self.context.user_id or "",
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
app_id=self.context.app_id,
|
||||
conversation_id=self.context.conversation_id,
|
||||
message_id=self.context.message_id,
|
||||
)
|
||||
|
||||
# Collect response and files
|
||||
response_content = ""
|
||||
tool_files: list[File] = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
|
||||
response_content += response.message.text
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# Handle link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Link: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
# Handle image URL messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK:
|
||||
# Handle image link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK:
|
||||
# Handle binary file link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
filename = response.meta.get("filename", "file") if response.meta else "file"
|
||||
response_content += f"[File: {filename} - {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
# Handle JSON messages
|
||||
if isinstance(response.message, ToolInvokeMessage.JsonMessage):
|
||||
response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2)
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# Handle blob messages - convert to text representation
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobMessage):
|
||||
mime_type = (
|
||||
response.meta.get("mime_type", "application/octet-stream")
|
||||
if response.meta
|
||||
else "application/octet-stream"
|
||||
)
|
||||
size = len(response.message.blob)
|
||||
response_content += f"[Binary data: {mime_type}, size: {size} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
# Handle variable messages
|
||||
if isinstance(response.message, ToolInvokeMessage.VariableMessage):
|
||||
var_name = response.message.variable_name
|
||||
var_value = response.message.variable_value
|
||||
if isinstance(var_value, str):
|
||||
response_content += var_value
|
||||
else:
|
||||
response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
|
||||
# Handle blob chunk messages - these are parts of a larger blob
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage):
|
||||
response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
|
||||
# Handle retriever resources messages
|
||||
if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage):
|
||||
response_content += response.message.context
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.FILE:
|
||||
# Extract file from meta
|
||||
if response.meta and "file" in response.meta:
|
||||
file = response.meta["file"]
|
||||
if isinstance(file, File):
|
||||
# Check if file is for model or tool output
|
||||
if response.meta.get("target") == "self":
|
||||
# File is for model - add to files for next prompt
|
||||
self.files.append(file)
|
||||
response_content += f"File '{file.filename}' has been loaded into your context."
|
||||
else:
|
||||
# File is tool output
|
||||
tool_files.append(file)
|
||||
|
||||
return response_content, tool_files, None
|
||||
|
||||
def _find_tool_by_name(self, tool_name: str) -> Tool | None:
|
||||
"""Find a tool instance by its name."""
|
||||
for tool in self.tools:
|
||||
if tool.entity.identity.name == tool_name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]:
|
||||
"""Convert tools to prompt message format."""
|
||||
prompt_tools: list[PromptMessageTool] = []
|
||||
for tool in self.tools:
|
||||
prompt_tools.append(tool.to_prompt_message_tool())
|
||||
return prompt_tools
|
||||
|
||||
def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None:
|
||||
"""Initialize usage tracking with empty usage if not set."""
|
||||
if "usage" not in llm_usage or llm_usage["usage"] is None:
|
||||
llm_usage["usage"] = LLMUsage.empty_usage()
|
||||
api/core/agent/patterns/function_call.py (new file, 299 lines)
@@ -0,0 +1,299 @@
|
||||
"""Function Call strategy implementation."""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
from .base import AgentPattern
|
||||
|
||||
|
||||
class FunctionCallStrategy(AgentPattern):
|
||||
"""Function Call strategy using model's native tool calling capability."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the function call agent strategy."""
|
||||
# Convert tools to prompt format
|
||||
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
|
||||
|
||||
# Initialize tracking
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
function_call_state: bool = True
|
||||
total_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
|
||||
while function_call_state and iteration_step <= max_iterations:
|
||||
function_call_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
# On last iteration, remove tools to force final answer
|
||||
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=current_tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=self.context.user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log
|
||||
)
|
||||
messages.append(self._create_assistant_message(response_content, tool_calls))
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update final text if no tool calls (this is likely the final answer)
|
||||
if not tool_calls:
|
||||
final_text = response_content
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Process tool calls
|
||||
tool_outputs: dict[str, str] = {}
|
||||
if tool_calls:
|
||||
function_call_state = True
|
||||
# Execute tools
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
tool_response, tool_files, _ = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"llm_result": response_content,
|
||||
"tool_calls": [
|
||||
{"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls
|
||||
]
|
||||
if tool_calls
|
||||
else [],
|
||||
"final_answer": final_text if not function_call_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
return AgentResult(
|
||||
text=final_text,
|
||||
files=output_files,
|
||||
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, LLMUsage | None],
|
||||
start_log: AgentLog,
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[list[tuple[str, str, dict[str, Any]]], str, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract tool calls and content.
|
||||
|
||||
Returns a tuple of (tool_calls, response_content, finish_reason).
|
||||
"""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
response_content: str = ""
|
||||
finish_reason: str | None = None
|
||||
if isinstance(chunks, Generator):
|
||||
# Streaming response
|
||||
for chunk in chunks:
|
||||
# Extract tool calls
|
||||
if self._has_tool_calls(chunk):
|
||||
tool_calls.extend(self._extract_tool_calls(chunk))
|
||||
|
||||
# Extract content
|
||||
if chunk.delta.message and chunk.delta.message.content:
|
||||
response_content += self._extract_content(chunk.delta.message.content)
|
||||
|
||||
# Track usage
|
||||
if chunk.delta.usage:
|
||||
self._accumulate_usage(llm_usage, chunk.delta.usage)
|
||||
|
||||
# Capture finish reason
|
||||
if chunk.delta.finish_reason:
|
||||
finish_reason = chunk.delta.finish_reason
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
# Non-streaming response
|
||||
result: LLMResult = chunks
|
||||
|
||||
if self._has_tool_calls_result(result):
|
||||
tool_calls.extend(self._extract_tool_calls_result(result))
|
||||
|
||||
if result.message and result.message.content:
|
||||
response_content += self._extract_content(result.message.content)
|
||||
|
||||
if result.usage:
|
||||
self._accumulate_usage(llm_usage, result.usage)
|
||||
|
||||
# Convert to streaming format
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
||||
)
|
||||
yield self._finish_log(
|
||||
start_log,
|
||||
data={
|
||||
"result": response_content,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
return tool_calls, response_content, finish_reason
|
||||
|
||||
def _create_assistant_message(
|
||||
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
|
||||
) -> AssistantPromptMessage:
|
||||
"""Create assistant message with tool calls."""
|
||||
if tool_calls is None:
|
||||
return AssistantPromptMessage(content=content)
|
||||
return AssistantPromptMessage(
|
||||
content=content or "",
|
||||
tool_calls=[
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tc[0],
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])),
|
||||
)
|
||||
for tc in tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any],
|
||||
tool_call_id: str,
|
||||
messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None]]:
|
||||
"""Handle a single tool call and return response with files and meta."""
|
||||
# Find tool
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
if not tool_instance:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
|
||||
# Get tool metadata (provider, icon, etc.)
|
||||
tool_metadata = self._get_tool_metadata(tool_instance)
|
||||
|
||||
# Create tool call log
|
||||
tool_call_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata=tool_metadata,
|
||||
)
|
||||
yield tool_call_log
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
|
||||
|
||||
yield self._finish_log(
|
||||
tool_call_log,
|
||||
data={
|
||||
**tool_call_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
final_content = response_content or "Tool executed successfully"
|
||||
# Add tool response to messages
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=final_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return response_content, tool_files, tool_invoke_meta
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_call_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_call_log.error = error_message
|
||||
tool_call_log.data = {
|
||||
**tool_call_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_call_log
|
||||
|
||||
# Add error message to conversation
|
||||
error_content = f"Tool execution failed: {error_message}"
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=error_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return error_content, [], None
|
||||
api/core/agent/patterns/react.py (new file, 418 lines)
@@ -0,0 +1,418 @@
|
||||
"""ReAct strategy implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.file import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
)
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class ReActStrategy(AgentPattern):
|
||||
"""ReAct strategy using reasoning and acting approach."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] | None = None,
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
):
|
||||
"""Initialize the ReAct strategy with instruction support."""
|
||||
super().__init__(
|
||||
model_instance=model_instance,
|
||||
tools=tools,
|
||||
context=context,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
files=files,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
self.instruction = instruction
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the ReAct agent strategy."""
|
||||
# Initialize tracking
|
||||
agent_scratchpad: list[AgentScratchpadUnit] = []
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
react_state: bool = True
|
||||
total_usage: dict[str, Any] = {"usage": None}
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Add "Observation" to stop sequences
|
||||
if "Observation" not in stop:
|
||||
stop = stop.copy()
|
||||
stop.append("Observation")
|
||||
|
||||
while react_state and iteration_step <= max_iterations:
|
||||
react_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
|
||||
# Build prompt with/without tools based on iteration
|
||||
include_tools = iteration_step < max_iterations
|
||||
current_messages = self._build_prompt_with_react_format(
|
||||
prompt_messages, agent_scratchpad, include_tools, self.instruction
|
||||
)
|
||||
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, Any] = {"usage": None}
|
||||
|
||||
# Use current messages directly (files are handled by base class if needed)
|
||||
messages_to_use = current_messages
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=self.context.user_id or "",
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log, current_messages
|
||||
)
|
||||
agent_scratchpad.append(scratchpad)
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Check if we have an action to execute
|
||||
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
||||
react_state = True
|
||||
# Execute tool
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
# Add observation to scratchpad for display
|
||||
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
|
||||
else:
|
||||
# Extract final answer
|
||||
if scratchpad.action and scratchpad.action.action_input:
|
||||
final_answer = scratchpad.action.action_input
|
||||
if isinstance(final_answer, dict):
|
||||
final_answer = json.dumps(final_answer, ensure_ascii=False)
|
||||
final_text = str(final_answer)
|
||||
elif scratchpad.thought:
|
||||
# If no action but we have thought, use thought as final answer
|
||||
final_text = scratchpad.thought
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
"observation": scratchpad.observation or None,
|
||||
"final_answer": final_text if not react_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
return AgentResult(
|
||||
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
|
||||
)
|
||||
|
||||
def _build_prompt_with_react_format(
|
||||
self,
|
||||
original_messages: list[PromptMessage],
|
||||
agent_scratchpad: list[AgentScratchpadUnit],
|
||||
include_tools: bool = True,
|
||||
instruction: str = "",
|
||||
) -> list[PromptMessage]:
|
||||
"""Build prompt messages with ReAct format."""
|
||||
# Copy messages to avoid modifying original
|
||||
messages = list(original_messages)
|
||||
|
||||
# Find and update the system prompt that should already exist
|
||||
system_prompt_found = False
|
||||
for i, msg in enumerate(messages):
|
||||
if isinstance(msg, SystemPromptMessage):
|
||||
system_prompt_found = True
|
||||
# The system prompt from frontend already has the template, just replace placeholders
|
||||
|
||||
# Format tools
|
||||
tools_str = ""
|
||||
tool_names = []
|
||||
if include_tools and self.tools:
|
||||
# Convert tools to prompt message tools format
|
||||
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
|
||||
tool_names = [tool.name for tool in prompt_tools]
|
||||
|
||||
# Format tools as JSON for comprehensive information
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2)
|
||||
tool_names_str = ", ".join(f'"{name}"' for name in tool_names)
|
||||
else:
|
||||
tools_str = "No tools available"
|
||||
tool_names_str = ""
|
||||
|
||||
# Replace placeholders in the existing system prompt
|
||||
updated_content = msg.content
|
||||
assert isinstance(updated_content, str)
|
||||
updated_content = updated_content.replace("{{instruction}}", instruction)
|
||||
updated_content = updated_content.replace("{{tools}}", tools_str)
|
||||
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
|
||||
|
||||
# Create new SystemPromptMessage with updated content
|
||||
messages[i] = SystemPromptMessage(content=updated_content)
|
||||
break
|
||||
|
||||
# If no system prompt was found, that's unexpected; the scratchpad is still appended below
|
||||
if not system_prompt_found:
|
||||
# This shouldn't happen if the frontend is working correctly
|
||||
pass
|
||||
|
||||
# Format agent scratchpad
|
||||
scratchpad_str = ""
|
||||
if agent_scratchpad:
|
||||
scratchpad_parts: list[str] = []
|
||||
for unit in agent_scratchpad:
|
||||
if unit.thought:
|
||||
scratchpad_parts.append(f"Thought: {unit.thought}")
|
||||
if unit.action_str:
|
||||
scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```")
|
||||
if unit.observation:
|
||||
scratchpad_parts.append(f"Observation: {unit.observation}")
|
||||
scratchpad_str = "\n".join(scratchpad_parts)
|
||||
|
||||
# If there's a scratchpad, append it to the last message
|
||||
if scratchpad_str:
|
||||
messages.append(AssistantPromptMessage(content=scratchpad_str))
|
||||
|
||||
return messages
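# Illustrative note: for one completed round, the scratchpad string appended above
# looks roughly like (tool name and values are hypothetical):
#   Thought: I should look up the weather first
#   Action:
#   ```
#   {"action_name": "get_weather", "action_input": {"city": "Berlin"}}
#   ```
#   Observation: 18°C, clear sky
# The JSON keys follow AgentScratchpadUnit.Action.model_dump(), as serialized in
# _handle_chunks.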
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, Any],
|
||||
model_log: AgentLog,
|
||||
current_messages: list[PromptMessage],
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[AgentScratchpadUnit, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract action/thought.
|
||||
|
||||
Returns a tuple of (scratchpad_unit, finish_reason).
|
||||
"""
|
||||
usage_dict: dict[str, Any] = {}
|
||||
|
||||
# Convert non-streaming to streaming format if needed
|
||||
if isinstance(chunks, LLMResult):
|
||||
# Create a generator from the LLMResult
|
||||
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
model=chunks.model,
|
||||
prompt_messages=chunks.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=chunks.message,
|
||||
usage=chunks.usage,
|
||||
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
|
||||
),
|
||||
system_fingerprint=chunks.system_fingerprint or "",
|
||||
)
|
||||
|
||||
streaming_chunks = result_to_chunks()
|
||||
else:
|
||||
streaming_chunks = chunks
|
||||
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
|
||||
|
||||
# Initialize scratchpad unit
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
thought="",
|
||||
action_str="",
|
||||
observation="",
|
||||
action=None,
|
||||
)
|
||||
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Process chunks
|
||||
for chunk in react_chunks:
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
# Action detected
|
||||
action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + action_str
|
||||
scratchpad.action_str = action_str
|
||||
scratchpad.action = chunk
|
||||
|
||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||
else:
|
||||
# Text chunk
|
||||
chunk_text = str(chunk)
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
|
||||
scratchpad.thought = (scratchpad.thought or "") + chunk_text
|
||||
|
||||
yield self._create_text_chunk(chunk_text, current_messages)
|
||||
|
||||
# Update usage
|
||||
if usage_dict.get("usage"):
|
||||
if llm_usage.get("usage"):
|
||||
self._accumulate_usage(llm_usage, usage_dict["usage"])
|
||||
else:
|
||||
llm_usage["usage"] = usage_dict["usage"]
|
||||
|
||||
# Clean up thought
|
||||
scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you"
|
||||
|
||||
# Finish model log
|
||||
yield self._finish_log(
|
||||
model_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
|
||||
return scratchpad, finish_reason
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
prompt_messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File]]]:
|
||||
"""Handle tool call and return observation with files."""
|
||||
tool_name = action.action_name
|
||||
tool_args: dict[str, Any] | str = action.action_input
|
||||
|
||||
# Find tool instance first to get metadata
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
tool_metadata = self._get_tool_metadata(tool_instance) if tool_instance else {}
|
||||
|
||||
# Start tool log with tool metadata
|
||||
tool_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata=tool_metadata,
|
||||
)
|
||||
yield tool_log
|
||||
|
||||
if not tool_instance:
|
||||
# Finish tool log with error
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"error": f"Tool {tool_name} not found",
|
||||
},
|
||||
)
|
||||
return f"Tool {tool_name} not found", []
|
||||
|
||||
# Ensure tool_args is a dict
|
||||
tool_args_dict: dict[str, Any]
|
||||
if isinstance(tool_args, str):
|
||||
try:
|
||||
tool_args_dict = json.loads(tool_args)
|
||||
except json.JSONDecodeError:
|
||||
tool_args_dict = {"input": tool_args}
|
||||
elif not isinstance(tool_args, dict):
|
||||
tool_args_dict = {"input": str(tool_args)}
|
||||
else:
|
||||
tool_args_dict = tool_args
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
|
||||
|
||||
# Finish tool log
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
|
||||
return response_content or "Tool executed successfully", tool_files
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_log.error = error_message
|
||||
tool_log.data = {
|
||||
**tool_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_log
|
||||
|
||||
return f"Tool execution failed: {error_message}", []
|
||||
107
api/core/agent/patterns/strategy_factory.py
Normal file
@ -0,0 +1,107 @@
|
||||
"""Strategy factory for creating agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.agent.entities import AgentEntity, ExecutionContext
|
||||
from core.file.models import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class StrategyFactory:
|
||||
"""Factory for creating agent strategies based on model features."""
|
||||
|
||||
# Tool calling related features
|
||||
TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
|
||||
|
||||
@staticmethod
|
||||
def create_strategy(
|
||||
model_features: list[ModelFeature],
|
||||
model_instance: ModelInstance,
|
||||
context: ExecutionContext,
|
||||
tools: list[Tool],
|
||||
files: list[File],
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
agent_strategy: AgentEntity.Strategy | None = None,
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
) -> AgentPattern:
|
||||
"""
|
||||
Create an appropriate strategy based on model features.
|
||||
|
||||
Args:
|
||||
model_features: List of model features/capabilities
|
||||
model_instance: Model instance to use
|
||||
context: Execution context containing trace/audit information
|
||||
tools: Available tools
|
||||
files: Available files
|
||||
max_iterations: Maximum iterations for the strategy
|
||||
workflow_call_depth: Depth of workflow calls
|
||||
agent_strategy: Optional explicit strategy override
|
||||
tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke)
|
||||
instruction: Optional instruction for ReAct strategy
|
||||
|
||||
Returns:
|
||||
AgentPattern instance
|
||||
"""
|
||||
# If explicit strategy is provided and it's Function Calling, try to use it if supported
|
||||
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
# Fallback to ReAct if FC is requested but not supported
|
||||
|
||||
# If explicit strategy is Chain of Thought (ReAct)
|
||||
if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
# Default auto-selection logic
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
# Model supports native function calling
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
else:
|
||||
# Use ReAct strategy for models without function calling
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
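# A minimal usage sketch for the factory above (illustrative only; model_instance,
# context, and tools are assumed to be constructed elsewhere, and ModelFeature is
# already imported at the top of this module):
strategy = StrategyFactory.create_strategy(
    model_features=[ModelFeature.TOOL_CALL],  # any TOOL_CALL-family feature selects FunctionCallStrategy
    model_instance=model_instance,
    context=context,
    tools=tools,
    files=[],
    max_iterations=5,
    instruction="You are a helpful assistant.",
)
# Models without tool-call support fall back to ReActStrategy, which also receives
# the instruction text.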
|
||||
@ -24,11 +24,13 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.app.layers.sandbox_layer import SandboxLayer
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.sandbox import Sandbox, SandboxManager
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaverFactory,
|
||||
)
|
||||
@ -40,7 +42,9 @@ from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow_features import WorkflowFeatures
|
||||
from services.conversation_service import ConversationService
|
||||
from services.sandbox.sandbox_provider_service import SandboxProviderService
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVarLoader,
|
||||
WorkflowDraftVariableService,
|
||||
@ -512,6 +516,31 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
if workflow is None:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
sandbox: Sandbox | None = None
|
||||
graph_engine_layers: tuple = ()
|
||||
if workflow.get_feature(WorkflowFeatures.SANDBOX).enabled:
|
||||
sandbox_provider = SandboxProviderService.get_sandbox_provider(
|
||||
application_generate_entity.app_config.tenant_id
|
||||
)
|
||||
if workflow.version == Workflow.VERSION_DRAFT:
|
||||
sandbox = SandboxManager.create_draft(
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
sandbox_provider=sandbox_provider,
|
||||
)
|
||||
else:
|
||||
if application_generate_entity.workflow_run_id is None:
|
||||
raise ValueError("workflow_run_id is required when sandbox is enabled")
|
||||
sandbox = SandboxManager.create(
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
workflow_execution_id=application_generate_entity.workflow_run_id,
|
||||
sandbox_provider=sandbox_provider,
|
||||
)
|
||||
graph_engine_layers = (SandboxLayer(sandbox=sandbox),)
|
||||
|
||||
# Determine system_user_id based on invocation source
|
||||
is_external_api_call = application_generate_entity.invoke_from in {
|
||||
InvokeFrom.WEB_APP,
|
||||
@ -542,6 +571,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
app=app,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
sandbox=sandbox,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -24,6 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari
|
||||
from core.db.session_factory import session_factory
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.sandbox import Sandbox
|
||||
from core.variables.variables import Variable
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
@ -66,6 +67,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
sandbox: Sandbox | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@ -82,6 +84,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
self._app = app
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
self._sandbox = sandbox
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
@ -156,6 +159,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
# init graph
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
|
||||
|
||||
if self._sandbox:
|
||||
graph_runtime_state.set_sandbox(self._sandbox)
|
||||
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
@ -120,6 +120,6 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
|
||||
|
||||
yield response_chunk
|
||||
|
||||
@ -4,6 +4,7 @@ import re
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Thread
|
||||
from typing import Any, Union
|
||||
|
||||
@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
ChunkType,
|
||||
MessageQueueMessage,
|
||||
QueueAdvancedChatMessageEndEvent,
|
||||
QueueAgentLogEvent,
|
||||
@ -70,13 +72,131 @@ from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models import Account, Conversation, EndUser, LLMGenerationDetail, Message, MessageFile
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamEventBuffer:
|
||||
"""
|
||||
Buffer for recording stream events in order to reconstruct the generation sequence.
|
||||
Records the exact order of text chunks, thoughts, and tool calls as they stream.
|
||||
"""
|
||||
|
||||
# Accumulated reasoning content (each thought block is a separate element)
|
||||
reasoning_content: list[str] = field(default_factory=list)
|
||||
# Current reasoning buffer (accumulates until we see a different event type)
|
||||
_current_reasoning: str = ""
|
||||
# Tool calls with their details
|
||||
tool_calls: list[dict] = field(default_factory=list)
|
||||
# Tool call ID to index mapping for updating results
|
||||
_tool_call_id_map: dict[str, int] = field(default_factory=dict)
|
||||
# Sequence of events in stream order
|
||||
sequence: list[dict] = field(default_factory=list)
|
||||
# Current position in answer text
|
||||
_content_position: int = 0
|
||||
# Track last event type to detect transitions
|
||||
_last_event_type: str | None = None
|
||||
|
||||
def _flush_current_reasoning(self) -> None:
|
||||
"""Flush accumulated reasoning to the list and add to sequence."""
|
||||
if self._current_reasoning.strip():
|
||||
self.reasoning_content.append(self._current_reasoning.strip())
|
||||
self.sequence.append({"type": "reasoning", "index": len(self.reasoning_content) - 1})
|
||||
self._current_reasoning = ""
|
||||
|
||||
def record_text_chunk(self, text: str) -> None:
|
||||
"""Record a text chunk event."""
|
||||
if not text:
|
||||
return
|
||||
|
||||
# Flush any pending reasoning first
|
||||
if self._last_event_type == "thought":
|
||||
self._flush_current_reasoning()
|
||||
|
||||
text_len = len(text)
|
||||
start_pos = self._content_position
|
||||
|
||||
# If last event was also content, extend it; otherwise create new
|
||||
if self.sequence and self.sequence[-1].get("type") == "content":
|
||||
self.sequence[-1]["end"] = start_pos + text_len
|
||||
else:
|
||||
self.sequence.append({"type": "content", "start": start_pos, "end": start_pos + text_len})
|
||||
|
||||
self._content_position += text_len
|
||||
self._last_event_type = "content"
|
||||
|
||||
def record_thought_chunk(self, text: str) -> None:
|
||||
"""Record a thought/reasoning chunk event."""
|
||||
if not text:
|
||||
return
|
||||
|
||||
# Accumulate thought content
|
||||
self._current_reasoning += text
|
||||
self._last_event_type = "thought"
|
||||
|
||||
def record_tool_call(
|
||||
self,
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
tool_arguments: str,
|
||||
tool_icon: str | dict | None = None,
|
||||
tool_icon_dark: str | dict | None = None,
|
||||
) -> None:
|
||||
"""Record a tool call event."""
|
||||
if not tool_call_id:
|
||||
return
|
||||
|
||||
# Flush any pending reasoning first
|
||||
if self._last_event_type == "thought":
|
||||
self._flush_current_reasoning()
|
||||
|
||||
# Check if this tool call already exists (we might get multiple chunks)
|
||||
if tool_call_id in self._tool_call_id_map:
|
||||
idx = self._tool_call_id_map[tool_call_id]
|
||||
# Update arguments if provided
|
||||
if tool_arguments:
|
||||
self.tool_calls[idx]["arguments"] = tool_arguments
|
||||
else:
|
||||
# New tool call
|
||||
tool_call = {
|
||||
"id": tool_call_id or "",
|
||||
"name": tool_name or "",
|
||||
"arguments": tool_arguments or "",
|
||||
"result": "",
|
||||
"elapsed_time": None,
|
||||
"icon": tool_icon,
|
||||
"icon_dark": tool_icon_dark,
|
||||
}
|
||||
self.tool_calls.append(tool_call)
|
||||
idx = len(self.tool_calls) - 1
|
||||
self._tool_call_id_map[tool_call_id] = idx
|
||||
self.sequence.append({"type": "tool_call", "index": idx})
|
||||
|
||||
self._last_event_type = "tool_call"
|
||||
|
||||
def record_tool_result(self, tool_call_id: str, result: str, tool_elapsed_time: float | None = None) -> None:
|
||||
"""Record a tool result event (update existing tool call)."""
|
||||
if not tool_call_id:
|
||||
return
|
||||
if tool_call_id in self._tool_call_id_map:
|
||||
idx = self._tool_call_id_map[tool_call_id]
|
||||
self.tool_calls[idx]["result"] = result
|
||||
self.tool_calls[idx]["elapsed_time"] = tool_elapsed_time
|
||||
|
||||
def finalize(self) -> None:
|
||||
"""Finalize the buffer, flushing any pending data."""
|
||||
if self._last_event_type == "thought":
|
||||
self._flush_current_reasoning()
|
||||
|
||||
def has_data(self) -> bool:
|
||||
"""Check if there's any meaningful data recorded."""
|
||||
return bool(self.reasoning_content or self.tool_calls or self.sequence)
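# A short sketch of how StreamEventBuffer is driven (illustrative; in practice the
# task pipeline feeds it from queue events):
buffer = StreamEventBuffer()
buffer.record_thought_chunk("Let me check the weather. ")
buffer.record_tool_call(tool_call_id="call-1", tool_name="get_weather", tool_arguments='{"city": "Berlin"}')
buffer.record_tool_result(tool_call_id="call-1", result="18°C", tool_elapsed_time=0.4)
buffer.record_text_chunk("It is sunny in Berlin.")
buffer.finalize()
# buffer.sequence now reads, in stream order:
#   [{"type": "reasoning", "index": 0},
#    {"type": "tool_call", "index": 0},
#    {"type": "content", "start": 0, "end": 22}]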
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
"""
|
||||
AdvancedChatAppGenerateTaskPipeline generates stream output and manages state for the application.
|
||||
@ -144,6 +264,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self._workflow_run_id: str = ""
|
||||
self._draft_var_saver_factory = draft_var_saver_factory
|
||||
self._graph_runtime_state: GraphRuntimeState | None = None
|
||||
# Stream event buffer for recording generation sequence
|
||||
self._stream_buffer = StreamEventBuffer()
|
||||
self._seed_graph_runtime_state_from_queue_manager()
|
||||
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
@ -383,7 +505,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle text chunk events."""
|
||||
"""Handle text chunk events and record to stream buffer for sequence reconstruction."""
|
||||
delta_text = event.text
|
||||
if delta_text is None:
|
||||
return
|
||||
@ -405,9 +527,53 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if tts_publisher and queue_message:
|
||||
tts_publisher.publish(queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
tool_call = event.tool_call
|
||||
tool_result = event.tool_result
|
||||
tool_payload = tool_call or tool_result
|
||||
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else ""
|
||||
tool_name = tool_payload.name if tool_payload and tool_payload.name else ""
|
||||
tool_arguments = tool_call.arguments if tool_call and tool_call.arguments else ""
|
||||
tool_files = tool_result.files if tool_result else []
|
||||
tool_elapsed_time = tool_result.elapsed_time if tool_result else None
|
||||
tool_icon = tool_payload.icon if tool_payload else None
|
||||
tool_icon_dark = tool_payload.icon_dark if tool_payload else None
|
||||
# Record stream event based on chunk type
|
||||
chunk_type = event.chunk_type or ChunkType.TEXT
|
||||
match chunk_type:
|
||||
case ChunkType.TEXT:
|
||||
self._stream_buffer.record_text_chunk(delta_text)
|
||||
self._task_state.answer += delta_text
|
||||
case ChunkType.THOUGHT:
|
||||
# Reasoning should not be part of final answer text
|
||||
self._stream_buffer.record_thought_chunk(delta_text)
|
||||
case ChunkType.TOOL_CALL:
|
||||
self._stream_buffer.record_tool_call(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
tool_arguments=tool_arguments,
|
||||
tool_icon=tool_icon,
|
||||
tool_icon_dark=tool_icon_dark,
|
||||
)
|
||||
case ChunkType.TOOL_RESULT:
|
||||
self._stream_buffer.record_tool_result(
|
||||
tool_call_id=tool_call_id,
|
||||
result=delta_text,
|
||||
tool_elapsed_time=tool_elapsed_time,
|
||||
)
|
||||
case _:
|
||||
pass
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||
answer=delta_text,
|
||||
message_id=self._message_id,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
chunk_type=event.chunk_type.value if event.chunk_type else None,
|
||||
tool_call_id=tool_call_id or None,
|
||||
tool_name=tool_name or None,
|
||||
tool_arguments=tool_arguments or None,
|
||||
tool_files=tool_files,
|
||||
tool_elapsed_time=tool_elapsed_time,
|
||||
tool_icon=tool_icon,
|
||||
tool_icon_dark=tool_icon_dark,
|
||||
)
|
||||
|
||||
def _handle_iteration_start_event(
|
||||
@ -775,6 +941,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
|
||||
# If there are assistant files, remove markdown image links from answer
|
||||
answer_text = self._task_state.answer
|
||||
answer_text = self._strip_think_blocks(answer_text)
|
||||
if self._recorded_files:
|
||||
# Remove markdown image links since we're storing files separately
|
||||
answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip()
|
||||
@ -826,6 +993,54 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
]
|
||||
session.add_all(message_files)
|
||||
|
||||
# Save generation detail (reasoning/tool calls/sequence) from stream buffer
|
||||
self._save_generation_detail(session=session, message=message)
|
||||
|
||||
@staticmethod
|
||||
def _strip_think_blocks(text: str) -> str:
|
||||
"""Remove <think>...</think> blocks (including their content) from text."""
|
||||
if not text or "<think" not in text.lower():
|
||||
return text
|
||||
|
||||
clean_text = re.sub(r"<think[^>]*>.*?</think>", "", text, flags=re.IGNORECASE | re.DOTALL)
|
||||
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
|
||||
return clean_text
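# Illustrative behaviour (hypothetical input):
#   _strip_think_blocks("<think>internal reasoning</think>\n\nFinal answer") -> "Final answer"
# Matching is case-insensitive and DOTALL, so multi-line <THINK ...> blocks are removed too.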
|
||||
|
||||
def _save_generation_detail(self, *, session: Session, message: Message) -> None:
|
||||
"""
|
||||
Save LLM generation detail for Chatflow using the stream event buffer.
|
||||
The buffer records the exact order of events as they streamed,
|
||||
allowing accurate reconstruction of the generation sequence.
|
||||
"""
|
||||
# Finalize the stream buffer to flush any pending data
|
||||
self._stream_buffer.finalize()
|
||||
|
||||
# Only save if there's meaningful data
|
||||
if not self._stream_buffer.has_data():
|
||||
return
|
||||
|
||||
reasoning_content = self._stream_buffer.reasoning_content
|
||||
tool_calls = self._stream_buffer.tool_calls
|
||||
sequence = self._stream_buffer.sequence
|
||||
|
||||
# Check if generation detail already exists for this message
|
||||
existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first()
|
||||
|
||||
if existing:
|
||||
existing.reasoning_content = json.dumps(reasoning_content) if reasoning_content else None
|
||||
existing.tool_calls = json.dumps(tool_calls) if tool_calls else None
|
||||
existing.sequence = json.dumps(sequence) if sequence else None
|
||||
else:
|
||||
generation_detail = LLMGenerationDetail(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
message_id=message.id,
|
||||
reasoning_content=json.dumps(reasoning_content) if reasoning_content else None,
|
||||
tool_calls=json.dumps(tool_calls) if tool_calls else None,
|
||||
sequence=json.dumps(sequence) if sequence else None,
|
||||
)
|
||||
session.add(generation_detail)
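# Stored column shapes (illustrative values): reasoning_content is a JSON list of
# thought blocks, tool_calls a JSON list of dicts with the keys built in
# StreamEventBuffer.record_tool_call ("id", "name", "arguments", "result",
# "elapsed_time", "icon", "icon_dark"), and sequence a JSON list such as
# [{"type": "reasoning", "index": 0}, {"type": "content", "start": 0, "end": 42}].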
|
||||
|
||||
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
|
||||
"""Bootstrap the cached runtime state from the queue manager when present."""
|
||||
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
|
||||
|
||||
@ -3,10 +3,8 @@ from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
from core.agent.entities import AgentEntity
|
||||
from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
@ -14,8 +12,7 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.moderation.base import ModerationError
|
||||
from extensions.ext_database import db
|
||||
@ -194,22 +191,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
raise ValueError("Message not found")
|
||||
db.session.close()
|
||||
|
||||
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
|
||||
# start agent runner
|
||||
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
# check LLM mode
|
||||
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
|
||||
runner_cls = CotChatAgentRunner
|
||||
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
|
||||
runner_cls = CotCompletionAgentRunner
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
|
||||
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
runner_cls = FunctionCallAgentRunner
|
||||
else:
|
||||
raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")
|
||||
|
||||
runner = runner_cls(
|
||||
runner = AgentAppRunner(
|
||||
tenant_id=app_config.tenant_id,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation_result,
|
||||
|
||||
@ -81,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@ -109,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
@ -117,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
|
||||
|
||||
yield response_chunk
|
||||
|
||||
@ -81,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@ -109,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
@ -117,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
|
||||
|
||||
yield response_chunk
|
||||
|
||||
@ -70,6 +70,8 @@ class _NodeSnapshot:
|
||||
"""Empty string means the node is not executing inside an iteration."""
|
||||
loop_id: str = ""
|
||||
"""Empty string means the node is not executing inside a loop."""
|
||||
mention_parent_id: str = ""
|
||||
"""Empty string means the node is not an extractor node."""
|
||||
|
||||
|
||||
class WorkflowResponseConverter:
|
||||
@ -131,6 +133,7 @@ class WorkflowResponseConverter:
|
||||
start_at=event.start_at,
|
||||
iteration_id=event.in_iteration_id or "",
|
||||
loop_id=event.in_loop_id or "",
|
||||
mention_parent_id=event.in_mention_parent_id or "",
|
||||
)
|
||||
node_execution_id = NodeExecutionId(event.node_execution_id)
|
||||
self._node_snapshots[node_execution_id] = snapshot
|
||||
@ -287,6 +290,7 @@ class WorkflowResponseConverter:
|
||||
created_at=int(snapshot.start_at.timestamp()),
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
mention_parent_id=event.in_mention_parent_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
),
|
||||
)
|
||||
@ -373,6 +377,7 @@ class WorkflowResponseConverter:
|
||||
files=self.fetch_files_from_node_outputs(event.outputs or {}),
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
mention_parent_id=event.in_mention_parent_id,
|
||||
),
|
||||
)
|
||||
|
||||
@ -422,6 +427,7 @@ class WorkflowResponseConverter:
|
||||
files=self.fetch_files_from_node_outputs(event.outputs or {}),
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
mention_parent_id=event.in_mention_parent_id,
|
||||
retry_index=event.retry_index,
|
||||
),
|
||||
)
|
||||
@ -671,7 +677,7 @@ class WorkflowResponseConverter:
|
||||
task_id=task_id,
|
||||
data=AgentLogStreamResponse.Data(
|
||||
node_execution_id=event.node_execution_id,
|
||||
id=event.id,
|
||||
message_id=event.id,
|
||||
parent_id=event.parent_id,
|
||||
label=event.label,
|
||||
error=event.error,
|
||||
|
||||
@ -79,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@ -106,7 +106,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
@ -116,6 +116,6 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
|
||||
|
||||
yield response_chunk
|
||||
|
||||
@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(cast(dict, data))
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump())
|
||||
response_chunk.update(sub_stream_response.model_dump(exclude_none=True))
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump())
|
||||
response_chunk.update(sub_stream_response.model_dump(exclude_none=True))
|
||||
yield response_chunk
|
||||
|
||||
@ -23,11 +23,13 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||
from core.app.layers.sandbox_layer import SandboxLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.sandbox import Sandbox, SandboxManager
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
@ -38,6 +40,8 @@ from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow_features import WorkflowFeatures
|
||||
from services.sandbox.sandbox_provider_service import SandboxProviderService
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
|
||||
@ -488,6 +492,31 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
if workflow is None:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
sandbox: Sandbox | None = None
|
||||
if workflow.get_feature(WorkflowFeatures.SANDBOX).enabled:
|
||||
sandbox_provider = SandboxProviderService.get_sandbox_provider(
|
||||
application_generate_entity.app_config.tenant_id
|
||||
)
|
||||
if workflow.version == Workflow.VERSION_DRAFT:
|
||||
sandbox = SandboxManager.create_draft(
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
sandbox_provider=sandbox_provider,
|
||||
)
|
||||
else:
|
||||
sandbox = SandboxManager.create(
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
workflow_execution_id=application_generate_entity.workflow_execution_id,
|
||||
sandbox_provider=sandbox_provider,
|
||||
)
|
||||
graph_engine_layers = (
|
||||
*graph_engine_layers,
|
||||
SandboxLayer(sandbox=sandbox),
|
||||
)
|
||||
|
||||
# Determine system_user_id based on invocation source
|
||||
is_external_api_call = application_generate_entity.invoke_from in {
|
||||
InvokeFrom.WEB_APP,
|
||||
@ -512,6 +541,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
root_node_id=root_node_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
sandbox=sandbox,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -7,6 +7,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.sandbox import Sandbox
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
@ -42,6 +43,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
sandbox: Sandbox | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@ -55,6 +57,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
self._root_node_id = root_node_id
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
self._sandbox = sandbox
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
@ -99,6 +102,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
if self._sandbox:
|
||||
graph_runtime_state.set_sandbox(self._sandbox)
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
|
||||
@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
|
||||
yield response_chunk
|
||||
|
||||
@ -13,6 +13,7 @@ from core.app.apps.common.workflow_response_converter import WorkflowResponseCon
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
ChunkType,
|
||||
MessageQueueMessage,
|
||||
QueueAgentLogEvent,
|
||||
QueueErrorEvent,
|
||||
@ -483,11 +484,33 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if delta_text is None:
|
||||
return
|
||||
|
||||
tool_call = event.tool_call
|
||||
tool_result = event.tool_result
|
||||
tool_payload = tool_call or tool_result
|
||||
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else None
|
||||
tool_name = tool_payload.name if tool_payload and tool_payload.name else None
|
||||
tool_arguments = tool_call.arguments if tool_call else None
|
||||
tool_elapsed_time = tool_result.elapsed_time if tool_result else None
|
||||
tool_files = tool_result.files if tool_result else []
|
||||
tool_icon = tool_payload.icon if tool_payload else None
|
||||
tool_icon_dark = tool_payload.icon_dark if tool_payload else None
|
||||
|
||||
# only publish tts message at text chunk streaming
|
||||
if tts_publisher and queue_message:
|
||||
tts_publisher.publish(queue_message)
|
||||
|
||||
yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector)
|
||||
yield self._text_chunk_to_stream_response(
|
||||
text=delta_text,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
chunk_type=event.chunk_type,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
tool_arguments=tool_arguments,
|
||||
tool_files=tool_files,
|
||||
tool_elapsed_time=tool_elapsed_time,
|
||||
tool_icon=tool_icon,
|
||||
tool_icon_dark=tool_icon_dark,
|
||||
)
|
||||
|
||||
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle agent log events."""
|
||||
@ -650,16 +673,61 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
session.add(workflow_app_log)
|
||||
|
||||
def _text_chunk_to_stream_response(
|
||||
self, text: str, from_variable_selector: list[str] | None = None
|
||||
self,
|
||||
text: str,
|
||||
from_variable_selector: list[str] | None = None,
|
||||
chunk_type: ChunkType | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
tool_arguments: str | None = None,
|
||||
tool_files: list[str] | None = None,
|
||||
tool_error: str | None = None,
|
||||
tool_elapsed_time: float | None = None,
|
||||
tool_icon: str | dict | None = None,
|
||||
tool_icon_dark: str | dict | None = None,
|
||||
) -> TextChunkStreamResponse:
|
||||
"""
|
||||
Convert a text chunk event into a TextChunkStreamResponse, attaching tool metadata for TOOL_CALL and TOOL_RESULT chunks.
|
||||
:param text: text
|
||||
:return:
|
||||
"""
|
||||
from core.app.entities.task_entities import ChunkType as ResponseChunkType
|
||||
|
||||
response_chunk_type = ResponseChunkType(chunk_type.value) if chunk_type else ResponseChunkType.TEXT
|
||||
|
||||
data = TextChunkStreamResponse.Data(
|
||||
text=text,
|
||||
from_variable_selector=from_variable_selector,
|
||||
chunk_type=response_chunk_type,
|
||||
)
|
||||
|
||||
if response_chunk_type == ResponseChunkType.TOOL_CALL:
|
||||
data = data.model_copy(
|
||||
update={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_arguments": tool_arguments,
|
||||
"tool_icon": tool_icon,
|
||||
"tool_icon_dark": tool_icon_dark,
|
||||
}
|
||||
)
|
||||
elif response_chunk_type == ResponseChunkType.TOOL_RESULT:
|
||||
data = data.model_copy(
|
||||
update={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_arguments": tool_arguments,
|
||||
"tool_files": tool_files,
|
||||
"tool_error": tool_error,
|
||||
"tool_elapsed_time": tool_elapsed_time,
|
||||
"tool_icon": tool_icon,
|
||||
"tool_icon_dark": tool_icon_dark,
|
||||
}
|
||||
)
|
||||
|
||||
response = TextChunkStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
|
||||
data=data,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@ -385,6 +385,7 @@ class WorkflowBasedAppRunner:
|
||||
start_at=event.start_at,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
in_mention_parent_id=event.in_mention_parent_id,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
@ -405,6 +406,7 @@ class WorkflowBasedAppRunner:
|
||||
start_at=event.start_at,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
in_mention_parent_id=event.in_mention_parent_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
provider_type=event.provider_type,
|
||||
provider_id=event.provider_id,
|
||||
@ -428,6 +430,7 @@ class WorkflowBasedAppRunner:
|
||||
execution_metadata=execution_metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
in_mention_parent_id=event.in_mention_parent_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
@ -444,6 +447,7 @@ class WorkflowBasedAppRunner:
|
||||
execution_metadata=event.node_run_result.metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
in_mention_parent_id=event.in_mention_parent_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunExceptionEvent):
|
||||
@ -460,15 +464,25 @@ class WorkflowBasedAppRunner:
|
||||
execution_metadata=event.node_run_result.metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
in_mention_parent_id=event.in_mention_parent_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
from core.app.entities.queue_entities import ChunkType as QueueChunkType
|
||||
|
||||
if event.is_final and not event.chunk:
|
||||
return
|
||||
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
text=event.chunk,
|
||||
from_variable_selector=list(event.selector),
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
chunk_type=QueueChunkType(event.chunk_type.value),
|
||||
tool_call=event.tool_call,
|
||||
tool_result=event.tool_result,
|
||||
in_mention_parent_id=event.in_mention_parent_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetrieverResourceEvent):
|
||||
@ -477,6 +491,7 @@ class WorkflowBasedAppRunner:
|
||||
retriever_resources=event.retriever_resources,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
in_mention_parent_id=event.in_mention_parent_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunAgentLogEvent):
|
||||
|
||||
294
api/core/app/entities/app_asset_entities.py
Normal file
@ -0,0 +1,294 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AssetNodeType(StrEnum):
|
||||
FILE = "file"
|
||||
FOLDER = "folder"
|
||||
|
||||
|
||||
class AppAssetNode(BaseModel):
|
||||
id: str = Field(description="Unique identifier for the node")
|
||||
node_type: AssetNodeType = Field(description="Type of node: file or folder")
|
||||
name: str = Field(description="Name of the file or folder")
|
||||
parent_id: str | None = Field(default=None, description="Parent folder ID, None for root level")
|
||||
order: int = Field(default=0, description="Sort order within parent folder, lower values first")
|
||||
extension: str = Field(default="", description="File extension without dot, empty for folders")
|
||||
size: int = Field(default=0, description="File size in bytes, 0 for folders")
|
||||
checksum: str = Field(default="", description="SHA-256 checksum of file content, empty for folders")
|
||||
|
||||
@classmethod
|
||||
def create_folder(cls, node_id: str, name: str, parent_id: str | None = None) -> AppAssetNode:
|
||||
return cls(id=node_id, node_type=AssetNodeType.FOLDER, name=name, parent_id=parent_id)
|
||||
|
||||
@classmethod
|
||||
def create_file(
|
||||
cls, node_id: str, name: str, parent_id: str | None = None, size: int = 0, checksum: str = ""
|
||||
) -> AppAssetNode:
|
||||
return cls(
|
||||
id=node_id,
|
||||
node_type=AssetNodeType.FILE,
|
||||
name=name,
|
||||
parent_id=parent_id,
|
||||
extension=name.rsplit(".", 1)[-1] if "." in name else "",
|
||||
size=size,
|
||||
checksum=checksum,
|
||||
)
|
||||
|
||||
|
||||
class AppAssetNodeView(BaseModel):
|
||||
id: str = Field(description="Unique identifier for the node")
|
||||
node_type: str = Field(description="Type of node: 'file' or 'folder'")
|
||||
name: str = Field(description="Name of the file or folder")
|
||||
path: str = Field(description="Full path from root, e.g. '/folder/file.txt'")
|
||||
extension: str = Field(default="", description="File extension without dot")
|
||||
size: int = Field(default=0, description="File size in bytes")
|
||||
checksum: str = Field(default="", description="SHA-256 checksum of file content")
|
||||
children: list[AppAssetNodeView] = Field(default_factory=list, description="Child nodes for folders")
|
||||
|
||||
|
||||
class TreeNodeNotFoundError(Exception):
|
||||
"""Tree internal: node not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TreeParentNotFoundError(Exception):
|
||||
"""Tree internal: parent folder not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TreePathConflictError(Exception):
|
||||
"""Tree internal: path already exists"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AppAssetFileTree(BaseModel):
|
||||
"""
|
||||
File tree structure for app assets using adjacency list pattern.
|
||||
|
||||
Design:
|
||||
- Storage: Flat list with parent_id references (adjacency list)
|
||||
- Path: Computed dynamically via get_path(), not stored
|
||||
- Order: Integer field for user-defined sorting within each folder
|
||||
- API response: transform() builds nested tree with computed paths
|
||||
|
||||
Why adjacency list over nested tree or materialized path:
|
||||
- Simpler CRUD: move/rename only updates one node's parent_id
|
||||
- No path cascade: renaming parent doesn't require updating all descendants
|
||||
- JSON-friendly: flat list serializes cleanly to database JSON column
|
||||
- Trade-off: path lookup is O(depth), acceptable for typical file trees
|
||||
"""
|
||||
|
||||
nodes: list[AppAssetNode] = Field(default_factory=list, description="Flat list of all nodes in the tree")
|
||||
|
||||
def get(self, node_id: str) -> AppAssetNode | None:
|
||||
return next((n for n in self.nodes if n.id == node_id), None)
|
||||
|
||||
def get_children(self, parent_id: str | None) -> list[AppAssetNode]:
|
||||
return [n for n in self.nodes if n.parent_id == parent_id]
|
||||
|
||||
def has_child_named(self, parent_id: str | None, name: str) -> bool:
|
||||
return any(n.name == name and n.parent_id == parent_id for n in self.nodes)
|
||||
|
||||
def get_path(self, node_id: str) -> str:
|
||||
node = self.get(node_id)
|
||||
if not node:
|
||||
raise TreeNodeNotFoundError(node_id)
|
||||
parts: list[str] = []
|
||||
current: AppAssetNode | None = node
|
||||
while current:
|
||||
parts.append(current.name)
|
||||
current = self.get(current.parent_id) if current.parent_id else None
|
||||
return "/" + "/".join(reversed(parts))
|
||||
|
||||
def relative_path(self, a: AppAssetNode, b: AppAssetNode) -> str:
|
||||
"""
|
||||
Calculate relative path from node a to node b for Markdown references.
|
||||
Path is computed from a's parent directory (where the file resides).
|
||||
|
||||
Examples:
|
||||
/foo/a.md -> /foo/b.md => ./b.md
|
||||
/foo/a.md -> /foo/sub/b.md => ./sub/b.md
|
||||
/foo/sub/a.md -> /foo/b.md => ../b.md
|
||||
/foo/sub/deep/a.md -> /foo/b.md => ../../b.md
|
||||
"""
|
||||
|
||||
def get_ancestor_ids(node_id: str | None) -> list[str]:
|
||||
chain: list[str] = []
|
||||
current_id = node_id
|
||||
while current_id:
|
||||
chain.append(current_id)
|
||||
node = self.get(current_id)
|
||||
current_id = node.parent_id if node else None
|
||||
return chain
|
||||
|
||||
a_dir_ancestors = get_ancestor_ids(a.parent_id)
|
||||
b_ancestors = [b.id] + get_ancestor_ids(b.parent_id)
|
||||
a_dir_set = set(a_dir_ancestors)
|
||||
|
||||
lca_id: str | None = None
|
||||
lca_index_in_b = -1
|
||||
for idx, ancestor_id in enumerate(b_ancestors):
|
||||
if ancestor_id in a_dir_set or (a.parent_id is None and b_ancestors[idx:] == []):
|
||||
lca_id = ancestor_id
|
||||
lca_index_in_b = idx
|
||||
break
|
||||
|
||||
if a.parent_id is None:
|
||||
steps_up = 0
|
||||
lca_index_in_b = len(b_ancestors)
|
||||
elif lca_id is None:
|
||||
steps_up = len(a_dir_ancestors)
|
||||
lca_index_in_b = len(b_ancestors)
|
||||
else:
|
||||
steps_up = 0
|
||||
for ancestor_id in a_dir_ancestors:
|
||||
if ancestor_id == lca_id:
|
||||
break
|
||||
steps_up += 1
|
||||
|
||||
path_down: list[str] = []
|
||||
for i in range(lca_index_in_b - 1, -1, -1):
|
||||
node = self.get(b_ancestors[i])
|
||||
if node:
|
||||
path_down.append(node.name)
|
||||
|
||||
if steps_up == 0:
|
||||
return "./" + "/".join(path_down)
|
||||
|
||||
parts: list[str] = [".."] * steps_up + path_down
|
||||
return "/".join(parts)
|
||||
|
||||
def get_descendant_ids(self, node_id: str) -> list[str]:
|
||||
result: list[str] = []
|
||||
stack = [node_id]
|
||||
while stack:
|
||||
current_id = stack.pop()
|
||||
for child in self.nodes:
|
||||
if child.parent_id == current_id:
|
||||
result.append(child.id)
|
||||
stack.append(child.id)
|
||||
return result
|
||||
|
||||
def add(self, node: AppAssetNode) -> AppAssetNode:
|
||||
if self.get(node.id):
|
||||
raise TreePathConflictError(node.id)
|
||||
if self.has_child_named(node.parent_id, node.name):
|
||||
raise TreePathConflictError(node.name)
|
||||
if node.parent_id:
|
||||
parent = self.get(node.parent_id)
|
||||
if not parent or parent.node_type != AssetNodeType.FOLDER:
|
||||
raise TreeParentNotFoundError(node.parent_id)
|
||||
siblings = self.get_children(node.parent_id)
|
||||
node.order = max((s.order for s in siblings), default=-1) + 1
|
||||
self.nodes.append(node)
|
||||
return node
|
||||
|
||||
def update(self, node_id: str, size: int, checksum: str) -> AppAssetNode:
|
||||
node = self.get(node_id)
|
||||
if not node or node.node_type != AssetNodeType.FILE:
|
||||
raise TreeNodeNotFoundError(node_id)
|
||||
node.size = size
|
||||
node.checksum = checksum
|
||||
return node
|
||||
|
||||
def rename(self, node_id: str, new_name: str) -> AppAssetNode:
|
||||
node = self.get(node_id)
|
||||
if not node:
|
||||
raise TreeNodeNotFoundError(node_id)
|
||||
if node.name != new_name and self.has_child_named(node.parent_id, new_name):
|
||||
raise TreePathConflictError(new_name)
|
||||
node.name = new_name
|
||||
if node.node_type == AssetNodeType.FILE:
|
||||
node.extension = new_name.rsplit(".", 1)[-1] if "." in new_name else ""
|
||||
return node
|
||||
|
||||
def move(self, node_id: str, new_parent_id: str | None) -> AppAssetNode:
|
||||
node = self.get(node_id)
|
||||
if not node:
|
||||
raise TreeNodeNotFoundError(node_id)
|
||||
if new_parent_id:
|
||||
parent = self.get(new_parent_id)
|
||||
if not parent or parent.node_type != AssetNodeType.FOLDER:
|
||||
raise TreeParentNotFoundError(new_parent_id)
|
||||
if self.has_child_named(new_parent_id, node.name):
|
||||
raise TreePathConflictError(node.name)
|
||||
node.parent_id = new_parent_id
|
||||
siblings = self.get_children(new_parent_id)
|
||||
node.order = max((s.order for s in siblings if s.id != node_id), default=-1) + 1
|
||||
return node
|
||||
|
||||
def reorder(self, node_id: str, after_node_id: str | None) -> AppAssetNode:
|
||||
node = self.get(node_id)
|
||||
if not node:
|
||||
raise TreeNodeNotFoundError(node_id)
|
||||
|
||||
siblings = sorted(self.get_children(node.parent_id), key=lambda x: x.order)
|
||||
siblings = [s for s in siblings if s.id != node_id]
|
||||
|
||||
if after_node_id is None:
|
||||
insert_idx = 0
|
||||
else:
|
||||
after_node = self.get(after_node_id)
|
||||
if not after_node or after_node.parent_id != node.parent_id:
|
||||
raise TreeNodeNotFoundError(after_node_id)
|
||||
insert_idx = next((i for i, s in enumerate(siblings) if s.id == after_node_id), -1) + 1
|
||||
|
||||
siblings.insert(insert_idx, node)
|
||||
for idx, sibling in enumerate(siblings):
|
||||
sibling.order = idx
|
||||
|
||||
return node
|
||||
|
||||
def remove(self, node_id: str) -> list[str]:
|
||||
node = self.get(node_id)
|
||||
if not node:
|
||||
raise TreeNodeNotFoundError(node_id)
|
||||
ids_to_remove = [node_id] + self.get_descendant_ids(node_id)
|
||||
self.nodes = [n for n in self.nodes if n.id not in ids_to_remove]
|
||||
return ids_to_remove
|
||||
|
||||
def walk_files(self) -> Generator[AppAssetNode, None, None]:
|
||||
return (n for n in self.nodes if n.node_type == AssetNodeType.FILE)
|
||||
|
||||
def transform(self) -> list[AppAssetNodeView]:
|
||||
by_parent: dict[str | None, list[AppAssetNode]] = defaultdict(list)
|
||||
for n in self.nodes:
|
||||
by_parent[n.parent_id].append(n)
|
||||
|
||||
for children in by_parent.values():
|
||||
children.sort(key=lambda x: x.order)
|
||||
|
||||
paths: dict[str, str] = {}
|
||||
tree_views: dict[str, AppAssetNodeView] = {}
|
||||
|
||||
def build_view(node: AppAssetNode, parent_path: str) -> None:
|
||||
path = f"{parent_path}/{node.name}"
|
||||
paths[node.id] = path
|
||||
child_views: list[AppAssetNodeView] = []
|
||||
for child in by_parent.get(node.id, []):
|
||||
build_view(child, path)
|
||||
child_views.append(tree_views[child.id])
|
||||
tree_views[node.id] = AppAssetNodeView(
|
||||
id=node.id,
|
||||
node_type=node.node_type.value,
|
||||
name=node.name,
|
||||
path=path,
|
||||
extension=node.extension,
|
||||
size=node.size,
|
||||
checksum=node.checksum,
|
||||
children=child_views,
|
||||
)
|
||||
|
||||
for root_node in by_parent.get(None, []):
|
||||
build_view(root_node, "")
|
||||
|
||||
return [tree_views[n.id] for n in by_parent.get(None, [])]
|
||||
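A minimal usage sketch of the AppAssetFileTree API added above; the node IDs and file names are illustrative, everything else follows the entities defined in app_asset_entities.py.

from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode

# Build a small tree: /docs/guide.md and /docs/img/logo.png
tree = AppAssetFileTree()
docs = tree.add(AppAssetNode.create_folder(node_id="n1", name="docs"))
img = tree.add(AppAssetNode.create_folder(node_id="n2", name="img", parent_id=docs.id))
guide = tree.add(AppAssetNode.create_file(node_id="n3", name="guide.md", parent_id=docs.id, size=42))
logo = tree.add(AppAssetNode.create_file(node_id="n4", name="logo.png", parent_id=img.id, size=1024))

# Paths are computed from parent_id links, not stored
assert tree.get_path(guide.id) == "/docs/guide.md"
# Relative path from guide.md's directory to logo.png, as used for Markdown references
assert tree.relative_path(guide, logo) == "./img/logo.png"

# transform() builds the nested API view with computed paths, ordered by `order`
views = tree.transform()
print(views[0].path, [child.path for child in views[0].children])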
@ -36,6 +36,9 @@ class InvokeFrom(StrEnum):
|
||||
# this is used for plugin triggers and webhook triggers.
|
||||
TRIGGER = "trigger"
|
||||
|
||||
# AGENT indicates that this invocation is from an agent.
|
||||
AGENT = "agent"
|
||||
|
||||
# EXPLORE indicates that this invocation is from
|
||||
# the workflow (or chatflow) explore page.
|
||||
EXPLORE = "explore"
|
||||
|
||||
72 api/core/app/entities/llm_generation_entities.py (new file)
@ -0,0 +1,72 @@
|
||||
"""
|
||||
LLM Generation Detail entities.
|
||||
|
||||
Defines the structure for storing and transmitting LLM generation details
|
||||
including reasoning content, tool calls, and their sequence.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ContentSegment(BaseModel):
|
||||
"""Represents a content segment in the generation sequence."""
|
||||
|
||||
type: Literal["content"] = "content"
|
||||
start: int = Field(..., description="Start position in the text")
|
||||
end: int = Field(..., description="End position in the text")
|
||||
|
||||
|
||||
class ReasoningSegment(BaseModel):
|
||||
"""Represents a reasoning segment in the generation sequence."""
|
||||
|
||||
type: Literal["reasoning"] = "reasoning"
|
||||
index: int = Field(..., description="Index into reasoning_content array")
|
||||
|
||||
|
||||
class ToolCallSegment(BaseModel):
|
||||
"""Represents a tool call segment in the generation sequence."""
|
||||
|
||||
type: Literal["tool_call"] = "tool_call"
|
||||
index: int = Field(..., description="Index into tool_calls array")
|
||||
|
||||
|
||||
SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment
|
||||
|
||||
|
||||
class ToolCallDetail(BaseModel):
|
||||
"""Represents a tool call with its arguments and result."""
|
||||
|
||||
id: str = Field(default="", description="Unique identifier for the tool call")
|
||||
name: str = Field(..., description="Name of the tool")
|
||||
arguments: str = Field(default="", description="JSON string of tool arguments")
|
||||
result: str = Field(default="", description="Result from the tool execution")
|
||||
elapsed_time: float | None = Field(default=None, description="Elapsed time in seconds")
|
||||
icon: str | dict | None = Field(default=None, description="Icon of the tool")
|
||||
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
|
||||
|
||||
|
||||
class LLMGenerationDetailData(BaseModel):
|
||||
"""
|
||||
Domain model for LLM generation detail.
|
||||
|
||||
Contains the structured data for reasoning content, tool calls,
|
||||
and their display sequence.
|
||||
"""
|
||||
|
||||
reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments")
|
||||
tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details")
|
||||
sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments")
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if there's any meaningful generation detail."""
|
||||
return not self.reasoning_content and not self.tool_calls
|
||||
|
||||
def to_response_dict(self) -> dict:
|
||||
"""Convert to dictionary for API response."""
|
||||
return {
|
||||
"reasoning_content": self.reasoning_content,
|
||||
"tool_calls": [tc.model_dump() for tc in self.tool_calls],
|
||||
"sequence": [seg.model_dump() for seg in self.sequence],
|
||||
}
|
||||
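A small sketch showing how the generation-detail entities above fit together; the reasoning text and tool call values are invented for illustration.

from core.app.entities.llm_generation_entities import (
    ContentSegment,
    LLMGenerationDetailData,
    ReasoningSegment,
    ToolCallDetail,
    ToolCallSegment,
)

detail = LLMGenerationDetailData(
    reasoning_content=["Need the current weather before answering."],
    tool_calls=[ToolCallDetail(name="weather", arguments='{"city": "Berlin"}', result="12°C, cloudy")],
    sequence=[
        ReasoningSegment(index=0),        # show the reasoning first
        ToolCallSegment(index=0),         # then the tool call
        ContentSegment(start=0, end=24),  # finally the answer text span
    ],
)

assert not detail.is_empty()
payload = detail.to_response_dict()  # JSON-serializable dict for the API response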
@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes import NodeType
|
||||
|
||||
@ -177,6 +177,17 @@ class QueueLoopCompletedEvent(AppQueueEvent):
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ChunkType(StrEnum):
|
||||
"""Stream chunk type for LLM-related events."""
|
||||
|
||||
TEXT = "text" # Normal text streaming
|
||||
TOOL_CALL = "tool_call" # Tool call arguments streaming
|
||||
TOOL_RESULT = "tool_result" # Tool execution result
|
||||
THOUGHT = "thought" # Agent thinking process (ReAct)
|
||||
THOUGHT_START = "thought_start" # Agent thought start
|
||||
THOUGHT_END = "thought_end" # Agent thought end
|
||||
|
||||
|
||||
class QueueTextChunkEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueTextChunkEvent entity
|
||||
@ -190,6 +201,18 @@ class QueueTextChunkEvent(AppQueueEvent):
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
in_mention_parent_id: str | None = None
|
||||
"""parent node id if this is an extractor node event"""
|
||||
|
||||
# Extended fields for Agent/Tool streaming
|
||||
chunk_type: ChunkType = ChunkType.TEXT
|
||||
"""type of the chunk"""
|
||||
|
||||
# Tool streaming payloads
|
||||
tool_call: ToolCall | None = None
|
||||
"""structured tool call info"""
|
||||
tool_result: ToolResult | None = None
|
||||
"""structured tool result info"""
|
||||
|
||||
|
||||
class QueueAgentMessageEvent(AppQueueEvent):
|
||||
@ -229,6 +252,8 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
in_mention_parent_id: str | None = None
|
||||
"""parent node id if this is an extractor node event"""
|
||||
|
||||
|
||||
class QueueAnnotationReplyEvent(AppQueueEvent):
|
||||
@ -306,6 +331,8 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
node_run_index: int = 1  # FIXME(-LAN-): may not be used
|
||||
in_iteration_id: str | None = None
|
||||
in_loop_id: str | None = None
|
||||
in_mention_parent_id: str | None = None
|
||||
"""parent node id if this is an extractor node event"""
|
||||
start_at: datetime
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
|
||||
@ -328,6 +355,8 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
in_mention_parent_id: str | None = None
|
||||
"""parent node id if this is an extractor node event"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
@ -383,6 +412,8 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
in_mention_parent_id: str | None = None
|
||||
"""parent node id if this is an extractor node event"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
@ -407,6 +438,8 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
in_mention_parent_id: str | None = None
|
||||
"""parent node id if this is an extractor node event"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
@ -113,6 +113,38 @@ class MessageStreamResponse(StreamResponse):
|
||||
answer: str
|
||||
from_variable_selector: list[str] | None = None
|
||||
|
||||
# Extended fields for Agent/Tool streaming (imported at runtime to avoid circular import)
|
||||
chunk_type: str | None = None
|
||||
"""type of the chunk: text, tool_call, tool_result, thought"""
|
||||
|
||||
# Tool call fields (when chunk_type == "tool_call")
|
||||
tool_call_id: str | None = None
|
||||
"""unique identifier for this tool call"""
|
||||
tool_name: str | None = None
|
||||
"""name of the tool being called"""
|
||||
tool_arguments: str | None = None
|
||||
"""accumulated tool arguments JSON"""
|
||||
|
||||
# Tool result fields (when chunk_type == "tool_result")
|
||||
tool_files: list[str] | None = None
|
||||
"""file IDs produced by tool"""
|
||||
tool_error: str | None = None
|
||||
"""error message if tool failed"""
|
||||
tool_elapsed_time: float | None = None
|
||||
"""elapsed time spent executing the tool"""
|
||||
tool_icon: str | dict | None = None
|
||||
"""icon of the tool"""
|
||||
tool_icon_dark: str | dict | None = None
|
||||
"""dark theme icon of the tool"""
|
||||
|
||||
def model_dump(self, *args, **kwargs) -> dict[str, object]:
|
||||
kwargs.setdefault("exclude_none", True)
|
||||
return super().model_dump(*args, **kwargs)
|
||||
|
||||
def model_dump_json(self, *args, **kwargs) -> str:
|
||||
kwargs.setdefault("exclude_none", True)
|
||||
return super().model_dump_json(*args, **kwargs)
|
||||
|
||||
|
||||
class MessageAudioStreamResponse(StreamResponse):
|
||||
"""
|
||||
@ -262,6 +294,7 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
extras: dict[str, object] = Field(default_factory=dict)
|
||||
iteration_id: str | None = None
|
||||
loop_id: str | None = None
|
||||
mention_parent_id: str | None = None
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_STARTED
|
||||
@ -285,6 +318,7 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
"extras": {},
|
||||
"iteration_id": self.data.iteration_id,
|
||||
"loop_id": self.data.loop_id,
|
||||
"mention_parent_id": self.data.mention_parent_id,
|
||||
},
|
||||
}
|
||||
|
||||
@ -320,6 +354,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
files: Sequence[Mapping[str, Any]] | None = []
|
||||
iteration_id: str | None = None
|
||||
loop_id: str | None = None
|
||||
mention_parent_id: str | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_FINISHED
|
||||
workflow_run_id: str
|
||||
@ -349,6 +384,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
"files": [],
|
||||
"iteration_id": self.data.iteration_id,
|
||||
"loop_id": self.data.loop_id,
|
||||
"mention_parent_id": self.data.mention_parent_id,
|
||||
},
|
||||
}
|
||||
|
||||
@ -384,6 +420,7 @@ class NodeRetryStreamResponse(StreamResponse):
|
||||
files: Sequence[Mapping[str, Any]] | None = []
|
||||
iteration_id: str | None = None
|
||||
loop_id: str | None = None
|
||||
mention_parent_id: str | None = None
|
||||
retry_index: int = 0
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_RETRY
|
||||
@ -414,6 +451,7 @@ class NodeRetryStreamResponse(StreamResponse):
|
||||
"files": [],
|
||||
"iteration_id": self.data.iteration_id,
|
||||
"loop_id": self.data.loop_id,
|
||||
"mention_parent_id": self.data.mention_parent_id,
|
||||
"retry_index": self.data.retry_index,
|
||||
},
|
||||
}
|
||||
@ -582,6 +620,17 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class ChunkType(StrEnum):
|
||||
"""Stream chunk type for LLM-related events."""
|
||||
|
||||
TEXT = "text" # Normal text streaming
|
||||
TOOL_CALL = "tool_call" # Tool call arguments streaming
|
||||
TOOL_RESULT = "tool_result" # Tool execution result
|
||||
THOUGHT = "thought" # Agent thinking process (ReAct)
|
||||
THOUGHT_START = "thought_start" # Agent thought start
|
||||
THOUGHT_END = "thought_end" # Agent thought end
|
||||
|
||||
|
||||
class TextChunkStreamResponse(StreamResponse):
|
||||
"""
|
||||
TextChunkStreamResponse entity
|
||||
@ -595,6 +644,36 @@ class TextChunkStreamResponse(StreamResponse):
|
||||
text: str
|
||||
from_variable_selector: list[str] | None = None
|
||||
|
||||
# Extended fields for Agent/Tool streaming
|
||||
chunk_type: ChunkType = ChunkType.TEXT
|
||||
"""type of the chunk"""
|
||||
|
||||
# Tool call fields (when chunk_type == TOOL_CALL)
|
||||
tool_call_id: str | None = None
|
||||
"""unique identifier for this tool call"""
|
||||
tool_name: str | None = None
|
||||
"""name of the tool being called"""
|
||||
tool_arguments: str | None = None
|
||||
"""accumulated tool arguments JSON"""
|
||||
|
||||
# Tool result fields (when chunk_type == TOOL_RESULT)
|
||||
tool_files: list[str] | None = None
|
||||
"""file IDs produced by tool"""
|
||||
tool_error: str | None = None
|
||||
"""error message if tool failed"""
|
||||
|
||||
# Tool elapsed time fields (when chunk_type == TOOL_RESULT)
|
||||
tool_elapsed_time: float | None = None
|
||||
"""elapsed time spent executing the tool"""
|
||||
|
||||
def model_dump(self, *args, **kwargs) -> dict[str, object]:
|
||||
kwargs.setdefault("exclude_none", True)
|
||||
return super().model_dump(*args, **kwargs)
|
||||
|
||||
def model_dump_json(self, *args, **kwargs) -> str:
|
||||
kwargs.setdefault("exclude_none", True)
|
||||
return super().model_dump_json(*args, **kwargs)
|
||||
|
||||
event: StreamEvent = StreamEvent.TEXT_CHUNK
|
||||
data: Data
|
||||
|
||||
@ -743,7 +822,7 @@ class AgentLogStreamResponse(StreamResponse):
|
||||
"""
|
||||
|
||||
node_execution_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
label: str
|
||||
parent_id: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
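The model_dump overrides added above default exclude_none to True, so the tool fields only show up in the SSE payload for tool chunks. A standalone miniature of that pattern (the field names mirror MessageStreamResponse and TextChunkStreamResponse.Data; it is a sketch, not the real classes):

from pydantic import BaseModel

class ChunkPayload(BaseModel):
    """Miniature of the exclude_none pattern used by the stream response entities above."""

    answer: str
    chunk_type: str | None = None
    tool_name: str | None = None
    tool_error: str | None = None

    def model_dump(self, *args, **kwargs) -> dict[str, object]:
        # Same trick as above: drop unset tool fields from the payload by default
        kwargs.setdefault("exclude_none", True)
        return super().model_dump(*args, **kwargs)

print(ChunkPayload(answer="hello").model_dump())
# -> {'answer': 'hello'}  (tool fields omitted)
print(ChunkPayload(answer="", chunk_type="tool_call", tool_name="weather").model_dump())
# -> {'answer': '', 'chunk_type': 'tool_call', 'tool_name': 'weather'}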
22 api/core/app/layers/sandbox_layer.py (new file)
@ -0,0 +1,22 @@
|
||||
import logging
|
||||
|
||||
from core.sandbox import Sandbox
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxLayer(GraphEngineLayer):
|
||||
def __init__(self, sandbox: Sandbox) -> None:
|
||||
super().__init__()
|
||||
self._sandbox = sandbox
|
||||
|
||||
def on_graph_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
pass
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
self._sandbox.release()
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from threading import Thread
|
||||
@ -58,7 +59,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import AppMode, Conversation, Message, MessageAgentThought
|
||||
from models.model import AppMode, Conversation, LLMGenerationDetail, Message, MessageAgentThought
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -68,6 +69,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
EasyUIBasedGenerateTaskPipeline generates streamed output and manages state for the application.
|
||||
"""
|
||||
|
||||
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
_task_state: EasyUITaskState
|
||||
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
|
||||
|
||||
@ -409,11 +412,136 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
)
|
||||
)
|
||||
|
||||
# Save LLM generation detail (reasoning content and/or tool calls) if present
|
||||
self._save_generation_detail(session=session, message=message, llm_result=llm_result)
|
||||
|
||||
message_was_created.send(
|
||||
message,
|
||||
application_generate_entity=self._application_generate_entity,
|
||||
)
|
||||
|
||||
def _save_generation_detail(self, *, session: Session, message: Message, llm_result: LLMResult) -> None:
|
||||
"""
|
||||
Save LLM generation detail for Completion/Chat/Agent-Chat applications.
|
||||
For Agent-Chat, also merges MessageAgentThought records.
|
||||
"""
|
||||
import json
|
||||
|
||||
reasoning_list: list[str] = []
|
||||
tool_calls_list: list[dict] = []
|
||||
sequence: list[dict] = []
|
||||
answer = message.answer or ""
|
||||
|
||||
# Check if this is Agent-Chat mode by looking for agent thoughts
|
||||
agent_thoughts = (
|
||||
session.query(MessageAgentThought)
|
||||
.filter_by(message_id=message.id)
|
||||
.order_by(MessageAgentThought.position.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
if agent_thoughts:
|
||||
# Agent-Chat mode: merge MessageAgentThought records
|
||||
content_pos = 0
|
||||
cleaned_answer_parts: list[str] = []
|
||||
for thought in agent_thoughts:
|
||||
# Add thought/reasoning
|
||||
if thought.thought:
|
||||
reasoning_text = thought.thought
|
||||
if "<think" in reasoning_text.lower():
|
||||
clean_text, extracted_reasoning = self._split_reasoning_from_answer(reasoning_text)
|
||||
if extracted_reasoning:
|
||||
reasoning_text = extracted_reasoning
|
||||
thought.thought = clean_text or extracted_reasoning
|
||||
reasoning_list.append(reasoning_text)
|
||||
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})
|
||||
|
||||
# Add tool calls
|
||||
if thought.tool:
|
||||
tool_calls_list.append(
|
||||
{
|
||||
"name": thought.tool,
|
||||
"arguments": thought.tool_input or "",
|
||||
"result": thought.observation or "",
|
||||
}
|
||||
)
|
||||
sequence.append({"type": "tool_call", "index": len(tool_calls_list) - 1})
|
||||
|
||||
# Add answer content if present
|
||||
if thought.answer:
|
||||
content_text = thought.answer
|
||||
if "<think" in content_text.lower():
|
||||
clean_answer, extracted_reasoning = self._split_reasoning_from_answer(content_text)
|
||||
if extracted_reasoning:
|
||||
reasoning_list.append(extracted_reasoning)
|
||||
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})
|
||||
content_text = clean_answer
|
||||
thought.answer = clean_answer or content_text
|
||||
|
||||
if content_text:
|
||||
start = content_pos
|
||||
end = content_pos + len(content_text)
|
||||
sequence.append({"type": "content", "start": start, "end": end})
|
||||
content_pos = end
|
||||
cleaned_answer_parts.append(content_text)
|
||||
|
||||
if cleaned_answer_parts:
|
||||
merged_answer = "".join(cleaned_answer_parts)
|
||||
message.answer = merged_answer
|
||||
llm_result.message.content = merged_answer
|
||||
else:
|
||||
# Completion/Chat mode: use reasoning_content from llm_result
|
||||
reasoning_content = llm_result.reasoning_content
|
||||
if not reasoning_content and answer:
|
||||
# Extract reasoning from <think> blocks and clean the final answer
|
||||
clean_answer, reasoning_content = self._split_reasoning_from_answer(answer)
|
||||
if reasoning_content:
|
||||
answer = clean_answer
|
||||
llm_result.message.content = clean_answer
|
||||
llm_result.reasoning_content = reasoning_content
|
||||
message.answer = clean_answer
|
||||
if reasoning_content:
|
||||
reasoning_list = [reasoning_content]
|
||||
# Content comes first, then reasoning
|
||||
if answer:
|
||||
sequence.append({"type": "content", "start": 0, "end": len(answer)})
|
||||
sequence.append({"type": "reasoning", "index": 0})
|
||||
|
||||
# Only save if there's meaningful generation detail
|
||||
if not reasoning_list and not tool_calls_list:
|
||||
return
|
||||
|
||||
# Check if generation detail already exists
|
||||
existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first()
|
||||
|
||||
if existing:
|
||||
existing.reasoning_content = json.dumps(reasoning_list) if reasoning_list else None
|
||||
existing.tool_calls = json.dumps(tool_calls_list) if tool_calls_list else None
|
||||
existing.sequence = json.dumps(sequence) if sequence else None
|
||||
else:
|
||||
generation_detail = LLMGenerationDetail(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
message_id=message.id,
|
||||
reasoning_content=json.dumps(reasoning_list) if reasoning_list else None,
|
||||
tool_calls=json.dumps(tool_calls_list) if tool_calls_list else None,
|
||||
sequence=json.dumps(sequence) if sequence else None,
|
||||
)
|
||||
session.add(generation_detail)
|
||||
|
||||
@classmethod
|
||||
def _split_reasoning_from_answer(cls, text: str) -> tuple[str, str]:
|
||||
"""
|
||||
Extract reasoning segments from <think> blocks and return (clean_text, reasoning).
|
||||
"""
|
||||
matches = cls._THINK_PATTERN.findall(text)
|
||||
reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""
|
||||
|
||||
clean_text = cls._THINK_PATTERN.sub("", text)
|
||||
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
|
||||
|
||||
return clean_text, reasoning_content or ""
|
||||
|
||||
def _handle_stop(self, event: QueueStopEvent):
|
||||
"""
|
||||
Handle stop.
|
||||
|
||||
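A standalone sketch of what the <think>-block splitting above does; the regex is copied from _THINK_PATTERN and the sample text is invented.

import re

THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)

def split_reasoning_from_answer(text: str) -> tuple[str, str]:
    # Mirrors EasyUIBasedGenerateTaskPipeline._split_reasoning_from_answer
    matches = THINK_PATTERN.findall(text)
    reasoning = "\n".join(match.strip() for match in matches) if matches else ""
    clean = THINK_PATTERN.sub("", text)
    clean = re.sub(r"\n\s*\n", "\n\n", clean).strip()
    return clean, reasoning

answer = "<think>User wants a haiku about rain.</think>Rain taps the window..."
clean, reasoning = split_reasoning_from_answer(answer)
assert clean == "Rain taps the window..."
assert reasoning == "User wants a haiku about rain."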
@ -232,15 +232,31 @@ class MessageCycleManager:
|
||||
answer: str,
|
||||
message_id: str,
|
||||
from_variable_selector: list[str] | None = None,
|
||||
chunk_type: str | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
tool_arguments: str | None = None,
|
||||
tool_files: list[str] | None = None,
|
||||
tool_error: str | None = None,
|
||||
tool_elapsed_time: float | None = None,
|
||||
tool_icon: str | dict | None = None,
|
||||
tool_icon_dark: str | dict | None = None,
|
||||
event_type: StreamEvent | None = None,
|
||||
) -> MessageStreamResponse:
|
||||
"""
|
||||
Message to stream response.
|
||||
:param answer: answer
|
||||
:param message_id: message id
|
||||
:param from_variable_selector: from variable selector
|
||||
:param chunk_type: type of the chunk (text, tool_call, tool_result, thought)
|
||||
:param tool_call_id: unique identifier for this tool call
|
||||
:param tool_name: name of the tool being called
|
||||
:param tool_arguments: accumulated tool arguments JSON
|
||||
:param tool_files: file IDs produced by tool
|
||||
:param tool_error: error message if tool failed
|
||||
:return:
|
||||
"""
|
||||
return MessageStreamResponse(
|
||||
response = MessageStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_id,
|
||||
answer=answer,
|
||||
@ -248,6 +264,35 @@ class MessageCycleManager:
|
||||
event=event_type or StreamEvent.MESSAGE,
|
||||
)
|
||||
|
||||
if chunk_type:
|
||||
response = response.model_copy(update={"chunk_type": chunk_type})
|
||||
|
||||
if chunk_type == "tool_call":
|
||||
response = response.model_copy(
|
||||
update={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_arguments": tool_arguments,
|
||||
"tool_icon": tool_icon,
|
||||
"tool_icon_dark": tool_icon_dark,
|
||||
}
|
||||
)
|
||||
elif chunk_type == "tool_result":
|
||||
response = response.model_copy(
|
||||
update={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_arguments": tool_arguments,
|
||||
"tool_files": tool_files,
|
||||
"tool_error": tool_error,
|
||||
"tool_elapsed_time": tool_elapsed_time,
|
||||
"tool_icon": tool_icon,
|
||||
"tool_icon_dark": tool_icon_dark,
|
||||
}
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
|
||||
"""
|
||||
Message replace to stream response.
|
||||
|
||||
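A hedged sketch of how a caller might emit a tool-call chunk through the extended message_to_stream_response above; the manager instance, IDs, and tool values are placeholders.

def emit_tool_call_chunk(cycle_manager, message_id: str) -> dict:
    # `cycle_manager` is an already-constructed MessageCycleManager
    response = cycle_manager.message_to_stream_response(
        answer="",                      # no text for a pure tool-call chunk
        message_id=message_id,
        chunk_type="tool_call",
        tool_call_id="call-1",          # placeholder values
        tool_name="google_search",
        tool_arguments='{"query": "dify"}',
    )
    # exclude_none in MessageStreamResponse.model_dump keeps the SSE payload small
    return response.model_dump()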
31 api/core/app_assets/__init__.py (new file)
@ -0,0 +1,31 @@
|
||||
from .entities import (
|
||||
AssetItem,
|
||||
FileAsset,
|
||||
FileReference,
|
||||
SkillAsset,
|
||||
SkillMetadata,
|
||||
ToolConfiguration,
|
||||
ToolFieldConfig,
|
||||
ToolReference,
|
||||
)
|
||||
from .packager import AssetPackager, ZipPackager
|
||||
from .parser import AssetItemParser, AssetParser, FileAssetParser, SkillAssetParser
|
||||
from .paths import AssetPaths
|
||||
|
||||
__all__ = [
|
||||
"AssetItem",
|
||||
"AssetItemParser",
|
||||
"AssetPackager",
|
||||
"AssetParser",
|
||||
"AssetPaths",
|
||||
"FileAsset",
|
||||
"FileAssetParser",
|
||||
"FileReference",
|
||||
"SkillAsset",
|
||||
"SkillAssetParser",
|
||||
"SkillMetadata",
|
||||
"ToolConfiguration",
|
||||
"ToolFieldConfig",
|
||||
"ToolReference",
|
||||
"ZipPackager",
|
||||
]
|
||||
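A rough sketch of how the pieces exported above are meant to compose: walk a draft file tree, parse each file into assets, and package them into a zip artifact. The tenant/app/assets IDs are caller-supplied placeholders; everything else uses the APIs defined in this package.

from core.app.entities.app_asset_entities import AppAssetFileTree
from core.app_assets import AssetParser, AssetPaths, ZipPackager
from extensions.ext_storage import storage

def build_assets_zip(tree: AppAssetFileTree, tenant_id: str, app_id: str, assets_id: str) -> str:
    # Parse every file node in the draft tree into AssetItem objects
    parser = AssetParser(tree=tree, tenant_id=tenant_id, app_id=app_id)
    assets = parser.parse()

    # Package the assets into a single zip and persist it under the artifacts path
    zip_bytes = ZipPackager(storage).package(assets)
    zip_key = AssetPaths.build_zip(tenant_id, app_id, assets_id)
    storage.save(zip_key, zip_bytes)
    return zip_key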
20 api/core/app_assets/entities/__init__.py (new file)
@ -0,0 +1,20 @@
|
||||
from .assets import AssetItem, FileAsset
|
||||
from .skill import (
|
||||
FileReference,
|
||||
SkillAsset,
|
||||
SkillMetadata,
|
||||
ToolConfiguration,
|
||||
ToolFieldConfig,
|
||||
ToolReference,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AssetItem",
|
||||
"FileAsset",
|
||||
"FileReference",
|
||||
"SkillAsset",
|
||||
"SkillMetadata",
|
||||
"ToolConfiguration",
|
||||
"ToolFieldConfig",
|
||||
"ToolReference",
|
||||
]
|
||||
22 api/core/app_assets/entities/assets.py (new file)
@ -0,0 +1,22 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssetItem(ABC):
|
||||
node_id: str
|
||||
path: str
|
||||
file_name: str
|
||||
extension: str
|
||||
|
||||
@abstractmethod
|
||||
def get_storage_key(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileAsset(AssetItem):
|
||||
storage_key: str
|
||||
|
||||
def get_storage_key(self) -> str:
|
||||
return self.storage_key
|
||||
59 api/core/app_assets/entities/skill.py (new file)
@ -0,0 +1,59 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
|
||||
from .assets import AssetItem
|
||||
|
||||
|
||||
class ToolFieldConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
id: str
|
||||
value: Any
|
||||
auto: bool = False
|
||||
|
||||
|
||||
class ToolConfiguration(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
fields: list[ToolFieldConfig] = Field(default_factory=list)
|
||||
|
||||
def default_values(self) -> dict[str, Any]:
|
||||
return {field.id: field.value for field in self.fields if field.value is not None}
|
||||
|
||||
|
||||
class ToolReference(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
uuid: str = Field(description="Unique identifier for this tool reference")
|
||||
type: ToolProviderType = Field(description="Tool provider type")
|
||||
provider: str = Field(description="Tool provider")
|
||||
tool_name: str = Field(description="Tool name")
|
||||
credential_id: str | None = Field(default=None, description="Credential ID")
|
||||
configuration: ToolConfiguration | None = Field(default=None, description="Tool configuration")
|
||||
|
||||
|
||||
class FileReference(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
source: str = Field(description="Source location or identifier of the file")
|
||||
uuid: str = Field(description="Unique identifier for this file reference")
|
||||
|
||||
|
||||
class SkillMetadata(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
tools: dict[str, ToolReference] = Field(default_factory=dict, description="Map of tool references by UUID")
|
||||
files: list[FileReference] = Field(default_factory=list, description="List of file references")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillAsset(AssetItem):
|
||||
storage_key: str
|
||||
metadata: SkillMetadata
|
||||
|
||||
def get_storage_key(self) -> str:
|
||||
return self.storage_key
|
||||
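A small sketch of the skill metadata entities above; the UUID, provider, and tool values are invented, and the ToolProviderType member name is an assumption (use whatever provider type actually applies).

from core.app_assets.entities import SkillMetadata, ToolConfiguration, ToolFieldConfig, ToolReference
from core.tools.entities.tool_entities import ToolProviderType

tool = ToolReference(
    uuid="3f2c0d8e",                     # placeholder UUID
    type=ToolProviderType.BUILT_IN,      # assumed enum member for illustration
    provider="google",
    tool_name="google_search",
    configuration=ToolConfiguration(fields=[ToolFieldConfig(id="top_k", value=5)]),
)

metadata = SkillMetadata(tools={tool.uuid: tool})
assert tool.configuration is not None
assert tool.configuration.default_values() == {"top_k": 5}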
7 api/core/app_assets/packager/__init__.py (new file)
@ -0,0 +1,7 @@
|
||||
from .base import AssetPackager
|
||||
from .zip_packager import ZipPackager
|
||||
|
||||
__all__ = [
|
||||
"AssetPackager",
|
||||
"ZipPackager",
|
||||
]
|
||||
9 api/core/app_assets/packager/base.py (new file)
@ -0,0 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.app_assets.entities import AssetItem
|
||||
|
||||
|
||||
class AssetPackager(ABC):
|
||||
@abstractmethod
|
||||
def package(self, assets: list[AssetItem]) -> bytes:
|
||||
raise NotImplementedError
|
||||
42 api/core/app_assets/packager/zip_packager.py (new file)
@ -0,0 +1,42 @@
|
||||
import io
|
||||
import zipfile
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.app_assets.entities import AssetItem
|
||||
|
||||
from .base import AssetPackager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_storage import Storage
|
||||
|
||||
|
||||
class ZipPackager(AssetPackager):
|
||||
_storage: "Storage"
|
||||
|
||||
def __init__(self, storage: "Storage") -> None:
|
||||
self._storage = storage
|
||||
|
||||
def package(self, assets: list[AssetItem]) -> bytes:
|
||||
zip_buffer = io.BytesIO()
|
||||
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
lock = Lock()
|
||||
# FOR DEVELOPMENT AND TESTING ONLY, TODO: optimize
|
||||
with ThreadPoolExecutor(max_workers=8) as executor:
|
||||
futures: list[Future[None]] = []
|
||||
for asset in assets:
|
||||
|
||||
def _write_asset(a: AssetItem) -> None:
|
||||
content = self._storage.load_once(a.get_storage_key())
|
||||
with lock:
|
||||
zf.writestr(a.path, content)
|
||||
|
||||
futures.append(executor.submit(_write_asset, asset))
|
||||
|
||||
# Wait for all futures to complete
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
return zip_buffer.getvalue()
|
||||
10 api/core/app_assets/parser/__init__.py (new file)
@ -0,0 +1,10 @@
|
||||
from .asset_parser import AssetParser
|
||||
from .base import AssetItemParser, FileAssetParser
|
||||
from .skill_parser import SkillAssetParser
|
||||
|
||||
__all__ = [
|
||||
"AssetItemParser",
|
||||
"AssetParser",
|
||||
"FileAssetParser",
|
||||
"SkillAssetParser",
|
||||
]
|
||||
36 api/core/app_assets/parser/asset_parser.py (new file)
@ -0,0 +1,36 @@
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree
|
||||
from core.app_assets.entities import AssetItem
|
||||
from core.app_assets.paths import AssetPaths
|
||||
|
||||
from .base import AssetItemParser, FileAssetParser
|
||||
|
||||
|
||||
class AssetParser:
|
||||
def __init__(
|
||||
self,
|
||||
tree: AppAssetFileTree,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
) -> None:
|
||||
self._tree = tree
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
self._parsers = {}
|
||||
self._default_parser = FileAssetParser()
|
||||
|
||||
def register(self, extension: str, parser: AssetItemParser) -> None:
|
||||
self._parsers[extension] = parser
|
||||
|
||||
def parse(self) -> list[AssetItem]:
|
||||
assets: list[AssetItem] = []
|
||||
|
||||
for node in self._tree.walk_files():
|
||||
path = self._tree.get_path(node.id).lstrip("/")
|
||||
storage_key = AssetPaths.draft_file(self._tenant_id, self._app_id, node.id)
|
||||
extension = node.extension or ""
|
||||
|
||||
parser = self._parsers.get(extension, self._default_parser)
|
||||
asset = parser.parse(node.id, path, node.name, extension, storage_key)
|
||||
assets.append(asset)
|
||||
|
||||
return assets
|
||||
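A sketch of the extension-based parser registry above: a SkillAssetParser can be registered for a given extension so those files get resolved instead of passing through as plain FileAssets. The "skill" extension string is an assumption for illustration.

from core.app.entities.app_asset_entities import AppAssetFileTree
from core.app_assets.parser import AssetParser, SkillAssetParser

def parse_draft_assets(tree: AppAssetFileTree, tenant_id: str, app_id: str, assets_id: str):
    parser = AssetParser(tree=tree, tenant_id=tenant_id, app_id=app_id)
    # Route files with the (assumed) "skill" extension through the skill parser;
    # every other extension falls back to the default FileAssetParser.
    parser.register(
        "skill",
        SkillAssetParser(tenant_id=tenant_id, app_id=app_id, assets_id=assets_id, tree=tree),
    )
    return parser.parse()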
34 api/core/app_assets/parser/base.py (new file)
@ -0,0 +1,34 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.app_assets.entities import AssetItem, FileAsset
|
||||
|
||||
|
||||
class AssetItemParser(ABC):
|
||||
@abstractmethod
|
||||
def parse(
|
||||
self,
|
||||
node_id: str,
|
||||
path: str,
|
||||
file_name: str,
|
||||
extension: str,
|
||||
storage_key: str,
|
||||
) -> AssetItem:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FileAssetParser(AssetItemParser):
|
||||
def parse(
|
||||
self,
|
||||
node_id: str,
|
||||
path: str,
|
||||
file_name: str,
|
||||
extension: str,
|
||||
storage_key: str,
|
||||
) -> FileAsset:
|
||||
return FileAsset(
|
||||
node_id=node_id,
|
||||
path=path,
|
||||
file_name=file_name,
|
||||
extension=extension,
|
||||
storage_key=storage_key,
|
||||
)
|
||||
161 api/core/app_assets/parser/skill_parser.py (new file)
@ -0,0 +1,161 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
|
||||
from core.app_assets.entities import (
|
||||
SkillAsset,
|
||||
SkillMetadata,
|
||||
)
|
||||
from core.app_assets.entities.skill import FileReference, ToolConfiguration, ToolReference
|
||||
from core.app_assets.paths import AssetPaths
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
from .base import AssetItemParser
|
||||
|
||||
TOOL_REFERENCE_PATTERN = re.compile(r"§\[tool\]\.\[([^\]]+)\]\.\[([^\]]+)\]\.\[([^\]]+)\]§")
|
||||
FILE_REFERENCE_PATTERN = re.compile(r"§\[file\]\.\[([^\]]+)\]\.\[([^\]]+)\]§")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SkillAssetParser(AssetItemParser):
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
assets_id: str,
|
||||
tree: AppAssetFileTree,
|
||||
) -> None:
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
self._assets_id = assets_id
|
||||
self._tree = tree
|
||||
|
||||
def parse(
|
||||
self,
|
||||
node_id: str,
|
||||
path: str,
|
||||
file_name: str,
|
||||
extension: str,
|
||||
storage_key: str,
|
||||
) -> SkillAsset:
|
||||
try:
|
||||
return self._parse_skill_asset(node_id, path, file_name, extension, storage_key)
|
||||
except Exception:
|
||||
logger.exception("Failed to parse skill asset %s", node_id)
|
||||
# handle as plain text
|
||||
return SkillAsset(
|
||||
node_id=node_id,
|
||||
path=path,
|
||||
file_name=file_name,
|
||||
extension=extension,
|
||||
storage_key=storage_key,
|
||||
metadata=SkillMetadata(),
|
||||
)
|
||||
|
||||
def _parse_skill_asset(
|
||||
self, node_id: str, path: str, file_name: str, extension: str, storage_key: str
|
||||
) -> SkillAsset:
|
||||
try:
|
||||
data = json.loads(storage.load_once(storage_key))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
# handle as plain text
|
||||
return SkillAsset(
|
||||
node_id=node_id,
|
||||
path=path,
|
||||
file_name=file_name,
|
||||
extension=extension,
|
||||
storage_key=storage_key,
|
||||
metadata=SkillMetadata(),
|
||||
)
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Skill document {node_id} must be a JSON object")
|
||||
|
||||
data_dict: dict[str, Any] = data
|
||||
metadata_raw = data_dict.get("metadata", {})
|
||||
content = data_dict.get("content", "")
|
||||
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(f"Skill document {node_id} 'content' must be a string")
|
||||
|
||||
resolved_key = AssetPaths.build_resolved_file(self._tenant_id, self._app_id, self._assets_id, node_id)
|
||||
current_file = self._tree.get(node_id)
|
||||
if current_file is None:
|
||||
raise ValueError(f"File not found for id={node_id}")
|
||||
|
||||
metadata = self._resolve_metadata(content, metadata_raw)
|
||||
storage.save(resolved_key, self._resolve_content(current_file, content, metadata).encode("utf-8"))
|
||||
|
||||
return SkillAsset(
|
||||
node_id=node_id,
|
||||
path=path,
|
||||
file_name=file_name,
|
||||
extension=extension,
|
||||
storage_key=resolved_key,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _resolve_content(self, current_file: AppAssetNode, content: str, metadata: SkillMetadata) -> str:
|
||||
for match in FILE_REFERENCE_PATTERN.finditer(content):
|
||||
# replace with file relative path
|
||||
file_id = match.group(2)
|
||||
file = self._tree.get(file_id)
|
||||
if file is None:
|
||||
logger.warning("File not found for id=%s, skipping", file_id)
|
||||
# replace with file not found placeholder
|
||||
content = content.replace(match.group(0), "[File not found]")
|
||||
continue
|
||||
content = content.replace(match.group(0), self._tree.relative_path(current_file, file))
|
||||
|
||||
for match in TOOL_REFERENCE_PATTERN.finditer(content):
|
||||
tool_id = match.group(3)
|
||||
tool = metadata.tools.get(tool_id)
|
||||
if tool is None:
|
||||
logger.warning("Tool not found for id=%s, skipping", tool_id)
|
||||
# replace with tool not found placeholder
|
||||
content = content.replace(match.group(0), f"[Tool not found: {tool_id}]")
|
||||
continue
|
||||
content = content.replace(match.group(0), f"[Bash Command: {tool.tool_name}_{tool_id}]")
|
||||
return content
|
||||
|
||||
def _resolve_file_references(self, content: str) -> list[FileReference]:
|
||||
file_references: list[FileReference] = []
|
||||
for match in FILE_REFERENCE_PATTERN.finditer(content):
|
||||
file_references.append(
|
||||
FileReference(
|
||||
source=match.group(1),
|
||||
uuid=match.group(2),
|
||||
)
|
||||
)
|
||||
return file_references
|
||||
|
||||
def _resolve_tool_references(self, content: str, tools: dict[str, Any]) -> dict[str, ToolReference]:
|
||||
tool_references: dict[str, ToolReference] = {}
|
||||
for match in TOOL_REFERENCE_PATTERN.finditer(content):
|
||||
tool_id = match.group(3)
|
||||
tool_name = match.group(2)
|
||||
tool_provider = match.group(1)
|
||||
metadata = tools.get(tool_id)
|
||||
if metadata is None:
|
||||
raise ValueError(f"Tool metadata for {tool_id} not found")
|
||||
|
||||
configuration = ToolConfiguration.model_validate(metadata.get("configuration", {}))
|
||||
tool_references[tool_id] = ToolReference(
|
||||
uuid=tool_id,
|
||||
type=ToolProviderType.value_of(metadata.get("type", None)),
|
||||
provider=tool_provider,
|
||||
tool_name=tool_name,
|
||||
credential_id=metadata.get("credential_id", None),
|
||||
configuration=configuration,
|
||||
)
|
||||
return tool_references
|
||||
|
||||
def _resolve_metadata(self, content: str, metadata: dict[str, Any]) -> SkillMetadata:
|
||||
return SkillMetadata(
|
||||
files=self._resolve_file_references(content=content),
|
||||
tools=self._resolve_tool_references(content=content, tools=metadata.get("tools", {})),
|
||||
)
|
||||
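A standalone sketch of the reference markers the skill parser rewrites; the patterns are copied from above and the sample content is invented.

import re

TOOL_REFERENCE_PATTERN = re.compile(r"§\[tool\]\.\[([^\]]+)\]\.\[([^\]]+)\]\.\[([^\]]+)\]§")
FILE_REFERENCE_PATTERN = re.compile(r"§\[file\]\.\[([^\]]+)\]\.\[([^\]]+)\]§")

content = (
    "Run §[tool].[google].[google_search].[tool-uuid-1]§ and "
    "read §[file].[draft].[file-uuid-1]§ for details."
)

tool_match = TOOL_REFERENCE_PATTERN.search(content)
file_match = FILE_REFERENCE_PATTERN.search(content)
assert tool_match and file_match
provider, tool_name, tool_id = tool_match.groups()
source, file_id = file_match.groups()
# The parser replaces the tool marker with "[Bash Command: {tool_name}_{tool_id}]"
# and the file marker with a relative path computed from the file tree.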
18 api/core/app_assets/paths.py (new file)
@ -0,0 +1,18 @@
|
||||
class AssetPaths:
|
||||
_BASE = "app_assets"
|
||||
|
||||
@staticmethod
|
||||
def draft_file(tenant_id: str, app_id: str, node_id: str) -> str:
|
||||
return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/draft/{node_id}"
|
||||
|
||||
@staticmethod
|
||||
def build_zip(tenant_id: str, app_id: str, assets_id: str) -> str:
|
||||
return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}.zip"
|
||||
|
||||
@staticmethod
|
||||
def build_resolved_file(tenant_id: str, app_id: str, assets_id: str, node_id: str) -> str:
|
||||
return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}/resolved/{node_id}"
|
||||
|
||||
@staticmethod
|
||||
def build_tool_artifact(tenant_id: str, app_id: str, assets_id: str) -> str:
|
||||
return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}/tool_artifact.json"
|
||||
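For reference, the storage keys AssetPaths produces for a sample tenant and app (the IDs are placeholders):

from core.app_assets.paths import AssetPaths

print(AssetPaths.draft_file("t1", "app1", "node1"))
# app_assets/t1/app1/draft/node1
print(AssetPaths.build_zip("t1", "app1", "a1"))
# app_assets/t1/app1/artifacts/a1.zip
print(AssetPaths.build_resolved_file("t1", "app1", "a1", "node1"))
# app_assets/t1/app1/artifacts/a1/resolved/node1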
@ -5,7 +5,6 @@ from sqlalchemy import select
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.models.document import Document
|
||||
@ -90,6 +89,8 @@ class DatasetIndexToolCallbackHandler:
|
||||
# TODO(-LAN-): Improve type check
|
||||
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
|
||||
"""Handle return_retriever_resource_info."""
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
|
||||
self._queue_manager.publish(
|
||||
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
|
||||
from configs import dify_config
|
||||
@ -10,7 +11,10 @@ from core.model_runtime.entities import (
|
||||
TextPromptMessageContent,
|
||||
VideoPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
MultiModalPromptMessageContent,
|
||||
PromptMessageContentUnionTypes,
|
||||
)
|
||||
from core.tools.signature import sign_tool_file
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
@ -18,6 +22,8 @@ from . import helpers
|
||||
from .enums import FileAttribute
|
||||
from .models import File, FileTransferMethod, FileType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_attr(*, file: File, attr: FileAttribute):
|
||||
match attr:
|
||||
@ -89,6 +95,8 @@ def to_prompt_message_content(
|
||||
"format": f.extension.removeprefix("."),
|
||||
"mime_type": f.mime_type,
|
||||
"filename": f.filename or "",
|
||||
# Encoded file reference for context restoration: "transfer_method:related_id" or "remote:url"
|
||||
"file_ref": _encode_file_ref(f),
|
||||
}
|
||||
if f.type == FileType.IMAGE:
|
||||
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
@ -96,6 +104,17 @@ def to_prompt_message_content(
|
||||
return prompt_class_map[f.type].model_validate(params)
|
||||
|
||||
|
||||
def _encode_file_ref(f: File) -> str | None:
|
||||
"""Encode file reference as 'transfer_method:id_or_url' string."""
|
||||
if f.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return f"remote:{f.remote_url}" if f.remote_url else None
|
||||
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
return f"local:{f.related_id}" if f.related_id else None
|
||||
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
return f"tool:{f.related_id}" if f.related_id else None
|
||||
return None
|
||||
|
||||
|
||||
def download(f: File, /):
|
||||
if f.transfer_method in (
|
||||
FileTransferMethod.TOOL_FILE,
|
||||
@ -164,3 +183,128 @@ def _to_url(f: File, /):
|
||||
return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
|
||||
else:
|
||||
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||
|
||||
|
||||
def restore_multimodal_content(
|
||||
content: MultiModalPromptMessageContent,
|
||||
) -> MultiModalPromptMessageContent:
|
||||
"""
|
||||
Restore base64_data or url for multimodal content from file_ref.
|
||||
|
||||
file_ref format: "transfer_method:id_or_url" (e.g., "local:abc123", "remote:https://...")
|
||||
|
||||
Args:
|
||||
content: MultiModalPromptMessageContent with file_ref field
|
||||
|
||||
Returns:
|
||||
MultiModalPromptMessageContent with restored base64_data or url
|
||||
"""
|
||||
# Skip if no file reference or content already has data
|
||||
if not content.file_ref:
|
||||
return content
|
||||
if content.base64_data or content.url:
|
||||
return content
|
||||
|
||||
try:
|
||||
file = _build_file_from_ref(
|
||||
file_ref=content.file_ref,
|
||||
file_format=content.format,
|
||||
mime_type=content.mime_type,
|
||||
filename=content.filename,
|
||||
)
|
||||
if not file:
|
||||
return content
|
||||
|
||||
# Restore content based on config
|
||||
if dify_config.MULTIMODAL_SEND_FORMAT == "base64":
|
||||
restored_base64 = _get_encoded_string(file)
|
||||
return content.model_copy(update={"base64_data": restored_base64})
|
||||
else:
|
||||
restored_url = _to_url(file)
|
||||
return content.model_copy(update={"url": restored_url})
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to restore multimodal content: %s", e)
|
||||
return content
|
||||
|
||||
|
||||
def _build_file_from_ref(
|
||||
file_ref: str,
|
||||
file_format: str | None,
|
||||
mime_type: str | None,
|
||||
filename: str | None,
|
||||
) -> File | None:
|
||||
"""
|
||||
Build a File object from encoded file_ref string.
|
||||
|
||||
Args:
|
||||
file_ref: Encoded reference "transfer_method:id_or_url"
|
||||
file_format: The file format/extension (without dot)
|
||||
mime_type: The mime type
|
||||
filename: The filename
|
||||
|
||||
Returns:
|
||||
File object with storage_key loaded, or None if not found
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import UploadFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
# Parse file_ref: "method:value"
|
||||
if ":" not in file_ref:
|
||||
logger.warning("Invalid file_ref format: %s", file_ref)
|
||||
return None
|
||||
|
||||
method, value = file_ref.split(":", 1)
|
||||
extension = f".{file_format}" if file_format else None
|
||||
|
||||
if method == "remote":
|
||||
return File(
|
||||
tenant_id="",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url=value,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
filename=filename,
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
# Query database for storage_key
|
||||
with Session(db.engine) as session:
|
||||
if method == "local":
|
||||
stmt = select(UploadFile).where(UploadFile.id == value)
|
||||
upload_file = session.scalar(stmt)
|
||||
if upload_file:
|
||||
return File(
|
||||
tenant_id=upload_file.tenant_id,
|
||||
type=FileType(upload_file.extension)
|
||||
if hasattr(FileType, upload_file.extension.upper())
|
||||
else FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id=value,
|
||||
extension=extension or ("." + upload_file.extension if upload_file.extension else None),
|
||||
mime_type=mime_type or upload_file.mime_type,
|
||||
filename=filename or upload_file.name,
|
||||
storage_key=upload_file.key,
|
||||
)
|
||||
elif method == "tool":
|
||||
stmt = select(ToolFile).where(ToolFile.id == value)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file:
|
||||
return File(
|
||||
tenant_id=tool_file.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=value,
|
||||
extension=extension,
|
||||
mime_type=mime_type or tool_file.mimetype,
|
||||
filename=filename or tool_file.name,
|
||||
storage_key=tool_file.file_key,
|
||||
)
|
||||
|
||||
logger.warning("File not found for file_ref: %s", file_ref)
|
||||
return None
|
||||
|
||||
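A standalone sketch of the file_ref convention used above: _encode_file_ref produces "method:value" strings and _build_file_from_ref later splits them to rebuild a File. Only the string format is shown here; the database lookups are omitted.

def parse_file_ref(file_ref: str) -> tuple[str, str] | None:
    """Mirror the parsing in _build_file_from_ref: 'method:value', where value may itself contain ':'."""
    if ":" not in file_ref:
        return None
    method, value = file_ref.split(":", 1)
    return method, value

assert parse_file_ref("local:abc123") == ("local", "abc123")
assert parse_file_ref("tool:def456") == ("tool", "def456")
# Remote URLs keep their own colons because of split(":", 1)
assert parse_file_ref("remote:https://example.com/cat.png") == ("remote", "https://example.com/cat.png")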
@ -1,11 +1,16 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol, cast
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
import json_repair
|
||||
|
||||
from core.llm_generator.output_models import (
|
||||
CodeNodeStructuredOutput,
|
||||
InstructionModifyOutput,
|
||||
SuggestedQuestionsOutput,
|
||||
)
|
||||
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
|
||||
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
|
||||
from core.llm_generator.prompts import (
|
||||
@ -393,6 +398,432 @@ class LLMGenerator:
|
||||
logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
|
||||
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@classmethod
|
||||
def generate_with_context(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
workflow_id: str,
|
||||
node_id: str,
|
||||
parameter_name: str,
|
||||
language: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_config: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Generate extractor code node based on conversation context.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant/workspace ID
|
||||
workflow_id: Workflow ID
|
||||
node_id: Current tool/llm node ID
|
||||
parameter_name: Parameter name to generate code for
|
||||
language: Code language (python3/javascript)
|
||||
prompt_messages: Multi-turn conversation history (last message is instruction)
|
||||
model_config: Model configuration (provider, name, completion_params)
|
||||
|
||||
Returns:
|
||||
dict with CodeNodeData format:
|
||||
- variables: Input variable selectors
|
||||
- code_language: Code language
|
||||
- code: Generated code
|
||||
- outputs: Output definitions
|
||||
- message: Explanation
|
||||
- error: Error message if any
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
# Get workflow
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(App).where(App.id == workflow_id)
|
||||
app = session.scalar(stmt)
|
||||
if not app:
|
||||
return cls._error_response(f"App {workflow_id} not found")
|
||||
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
if not workflow:
|
||||
return cls._error_response(f"Workflow for app {workflow_id} not found")
|
||||
|
||||
# Get upstream nodes via edge backtracking
|
||||
upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)
|
||||
|
||||
# Get current node info
|
||||
current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
|
||||
if not current_node:
|
||||
return cls._error_response(f"Node {node_id} not found")
|
||||
|
||||
# Get parameter info
|
||||
parameter_info = cls._get_parameter_info(
|
||||
tenant_id=tenant_id,
|
||||
node_data=current_node.get("data", {}),
|
||||
parameter_name=parameter_name,
|
||||
)
|
||||
|
||||
# Build system prompt
|
||||
system_prompt = cls._build_extractor_system_prompt(
|
||||
upstream_nodes=upstream_nodes,
|
||||
current_node=current_node,
|
||||
parameter_info=parameter_info,
|
||||
language=language,
|
||||
)
|
||||
|
||||
# Construct complete prompt_messages with system prompt
|
||||
complete_messages: list[PromptMessage] = [
|
||||
SystemPromptMessage(content=system_prompt),
|
||||
*prompt_messages,
|
||||
]
|
||||
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model
|
||||
|
||||
# Get model instance and schema
|
||||
provider = model_config.get("provider", "")
|
||||
model_name = model_config.get("name", "")
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=provider,
|
||||
model=model_name,
|
||||
)
|
||||
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
|
||||
if not model_schema:
|
||||
return cls._error_response(f"Model schema not found for {model_name}")
|
||||
|
||||
model_parameters = model_config.get("completion_params", {})
|
||||
try:
|
||||
response = invoke_llm_with_pydantic_model(
|
||||
provider=provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=complete_messages,
|
||||
output_model=CodeNodeStructuredOutput,
|
||||
model_parameters=model_parameters,
|
||||
stream=False,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
return cls._parse_code_node_output(
|
||||
response.structured_output, language, parameter_info.get("type", "string")
|
||||
)
|
||||
|
||||
except InvokeError as e:
|
||||
return cls._error_response(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate with context, model: %s", model_config.get("name"))
|
||||
return cls._error_response(f"An unexpected error occurred: {str(e)}")
|
||||
|
||||
@classmethod
|
||||
def _error_response(cls, error: str) -> dict:
|
||||
"""Return error response in CodeNodeData format."""
|
||||
return {
|
||||
"variables": [],
|
||||
"code_language": "python3",
|
||||
"code": "",
|
||||
"outputs": {},
|
||||
"message": "",
|
||||
"error": error,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def generate_suggested_questions(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
workflow_id: str,
|
||||
node_id: str,
|
||||
parameter_name: str,
|
||||
language: str,
|
||||
model_config: dict | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Generate suggested questions for context generation.
|
||||
|
||||
Returns dict with questions array and error field.
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
# Get workflow context (reuse existing logic)
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(App).where(App.id == workflow_id)
|
||||
app = session.scalar(stmt)
|
||||
if not app:
|
||||
return {"questions": [], "error": f"App {workflow_id} not found"}
|
||||
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
if not workflow:
|
||||
return {"questions": [], "error": f"Workflow for app {workflow_id} not found"}
|
||||
|
||||
upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)
|
||||
current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
|
||||
if not current_node:
|
||||
return {"questions": [], "error": f"Node {node_id} not found"}
|
||||
|
||||
parameter_info = cls._get_parameter_info(
|
||||
tenant_id=tenant_id,
|
||||
node_data=current_node.get("data", {}),
|
||||
parameter_name=parameter_name,
|
||||
)
|
||||
|
||||
# Build prompt
|
||||
system_prompt = cls._build_suggested_questions_prompt(
|
||||
upstream_nodes=upstream_nodes,
|
||||
current_node=current_node,
|
||||
parameter_info=parameter_info,
|
||||
language=language,
|
||||
)
|
||||
|
||||
prompt_messages: list[PromptMessage] = [
|
||||
SystemPromptMessage(content=system_prompt),
|
||||
]
|
||||
|
||||
# Get model instance - use default if model_config not provided
|
||||
model_manager = ModelManager()
|
||||
if model_config:
|
||||
provider = model_config.get("provider", "")
|
||||
model_name = model_config.get("name", "")
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=provider,
|
||||
model=model_name,
|
||||
)
|
||||
else:
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
model_name = model_instance.model
|
||||
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
|
||||
if not model_schema:
|
||||
return {"questions": [], "error": f"Model schema not found for {model_name}"}
|
||||
|
||||
completion_params = model_config.get("completion_params", {}) if model_config else {}
|
||||
model_parameters = {**completion_params, "max_tokens": 256}
|
||||
try:
|
||||
response = invoke_llm_with_pydantic_model(
|
||||
provider=model_instance.provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
output_model=SuggestedQuestionsOutput,
|
||||
model_parameters=model_parameters,
|
||||
stream=False,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
questions = response.structured_output.get("questions", []) if response.structured_output else []
|
||||
return {"questions": questions, "error": ""}
|
||||
|
||||
except InvokeError as e:
|
||||
return {"questions": [], "error": str(e)}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate suggested questions, model: %s", model_name)
|
||||
return {"questions": [], "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@classmethod
|
||||
def _build_suggested_questions_prompt(
|
||||
cls,
|
||||
upstream_nodes: list[dict],
|
||||
current_node: dict,
|
||||
parameter_info: dict,
|
||||
language: str = "English",
|
||||
) -> str:
|
||||
"""Build minimal prompt for suggested questions generation."""
|
||||
# Simplify upstream nodes to reduce tokens
|
||||
sources = [f"{n['title']}({','.join(n.get('outputs', {}).keys())})" for n in upstream_nodes[:5]]
|
||||
param_type = parameter_info.get("type", "string")
|
||||
param_desc = parameter_info.get("description", "")[:100]
|
||||
|
||||
return f"""Suggest 3 code generation questions for extracting data.
|
||||
Sources: {", ".join(sources)}
|
||||
Target: {parameter_info.get("name")}({param_type}) - {param_desc}
|
||||
Output 3 short, practical questions in {language}."""
|
||||
|
||||
@classmethod
|
||||
def _get_upstream_nodes(cls, graph_dict: Mapping[str, Any], node_id: str) -> list[dict]:
|
||||
"""
|
||||
Get all upstream nodes via edge backtracking.
|
||||
|
||||
Traverses the graph backwards from node_id to collect all reachable nodes.
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
nodes = {n["id"]: n for n in graph_dict.get("nodes", [])}
|
||||
edges = graph_dict.get("edges", [])
|
||||
|
||||
# Build reverse adjacency list
|
||||
reverse_adj: dict[str, list[str]] = defaultdict(list)
|
||||
for edge in edges:
|
||||
reverse_adj[edge["target"]].append(edge["source"])
|
||||
|
||||
# BFS to find all upstream nodes
|
||||
visited: set[str] = set()
|
||||
queue = [node_id]
|
||||
upstream: list[dict] = []
|
||||
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
for source in reverse_adj.get(current, []):
|
||||
if source not in visited:
|
||||
visited.add(source)
|
||||
queue.append(source)
|
||||
if source in nodes:
|
||||
upstream.append(cls._extract_node_info(nodes[source]))
|
||||
|
||||
return upstream
|
||||
|
||||
@classmethod
|
||||
def _get_node_by_id(cls, graph_dict: Mapping[str, Any], node_id: str) -> dict | None:
|
||||
"""Get node by ID from graph."""
|
||||
for node in graph_dict.get("nodes", []):
|
||||
if node["id"] == node_id:
|
||||
return node
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_node_info(cls, node: dict) -> dict:
|
||||
"""Extract minimal node info with outputs based on node type."""
|
||||
node_type = node["data"]["type"]
|
||||
node_data = node.get("data", {})
|
||||
|
||||
# Build outputs based on node type (only type, no description to reduce tokens)
|
||||
outputs: dict[str, str] = {}
|
||||
match node_type:
|
||||
case "start":
|
||||
for var in node_data.get("variables", []):
|
||||
name = var.get("variable", var.get("name", ""))
|
||||
outputs[name] = var.get("type", "string")
|
||||
case "llm":
|
||||
outputs["text"] = "string"
|
||||
case "code":
|
||||
for name, output in node_data.get("outputs", {}).items():
|
||||
outputs[name] = output.get("type", "string")
|
||||
case "http-request":
|
||||
outputs = {"body": "string", "status_code": "number", "headers": "object"}
|
||||
case "knowledge-retrieval":
|
||||
outputs["result"] = "array[object]"
|
||||
case "tool":
|
||||
outputs = {"text": "string", "json": "object"}
|
||||
case _:
|
||||
outputs["output"] = "string"
|
||||
|
||||
info: dict = {
|
||||
"id": node["id"],
|
||||
"title": node_data.get("title", node["id"]),
|
||||
"outputs": outputs,
|
||||
}
|
||||
# Only include description if not empty
|
||||
desc = node_data.get("desc", "")
|
||||
if desc:
|
||||
info["desc"] = desc
|
||||
|
||||
return info
|
||||
|
||||
@classmethod
|
||||
def _get_parameter_info(cls, tenant_id: str, node_data: dict, parameter_name: str) -> dict:
|
||||
"""Get parameter info from tool schema using ToolManager."""
|
||||
default_info = {"name": parameter_name, "type": "string", "description": ""}
|
||||
|
||||
if node_data.get("type") != "tool":
|
||||
return default_info
|
||||
|
||||
try:
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
provider_type_str = node_data.get("provider_type", "")
|
||||
provider_type = ToolProviderType(provider_type_str) if provider_type_str else ToolProviderType.BUILT_IN
|
||||
|
||||
tool_runtime = ToolManager.get_tool_runtime(
|
||||
provider_type=provider_type,
|
||||
provider_id=node_data.get("provider_id", ""),
|
||||
tool_name=node_data.get("tool_name", ""),
|
||||
tenant_id=tenant_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
|
||||
parameters = tool_runtime.get_merged_runtime_parameters()
|
||||
for param in parameters:
|
||||
if param.name == parameter_name:
|
||||
return {
|
||||
"name": param.name,
|
||||
"type": param.type.value if hasattr(param.type, "value") else str(param.type),
|
||||
"description": param.llm_description
|
||||
or (param.human_description.en_US if param.human_description else ""),
|
||||
"required": param.required,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug("Failed to get parameter info from ToolManager: %s", e)
|
||||
|
||||
return default_info
|
||||
|
||||
@classmethod
|
||||
def _build_extractor_system_prompt(
|
||||
cls,
|
||||
upstream_nodes: list[dict],
|
||||
current_node: dict,
|
||||
parameter_info: dict,
|
||||
language: str,
|
||||
) -> str:
|
||||
"""Build system prompt for extractor code generation."""
|
||||
upstream_json = json.dumps(upstream_nodes, indent=2, ensure_ascii=False)
|
||||
param_type = parameter_info.get("type", "string")
|
||||
return f"""You are a code generator for workflow automation.
|
||||
|
||||
Generate {language} code to extract/transform upstream node outputs for the target parameter.
|
||||
|
||||
## Upstream Nodes
|
||||
{upstream_json}
|
||||
|
||||
## Target
|
||||
Node: {current_node["data"].get("title", current_node["id"])}
|
||||
Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("description", "")}
|
||||
|
||||
## Requirements
|
||||
- Write a main function that returns type: {param_type}
|
||||
- Use value_selector format: ["node_id", "output_name"]
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _parse_code_node_output(cls, content: Mapping[str, Any] | None, language: str, parameter_type: str) -> dict:
|
||||
"""
|
||||
Parse structured output to CodeNodeData format.
|
||||
|
||||
Args:
|
||||
content: Structured output dict from invoke_llm_with_structured_output
|
||||
language: Code language
|
||||
parameter_type: Expected parameter type
|
||||
|
||||
Returns dict with variables, code_language, code, outputs, message, error.
|
||||
"""
|
||||
if content is None:
|
||||
return cls._error_response("Empty or invalid response from LLM")
|
||||
|
||||
# Validate and normalize variables
|
||||
variables = [
|
||||
{"variable": v.get("variable", ""), "value_selector": v.get("value_selector", [])}
|
||||
for v in content.get("variables", [])
|
||||
if isinstance(v, dict)
|
||||
]
|
||||
|
||||
outputs = content.get("outputs", {"result": {"type": parameter_type}})
|
||||
|
||||
return {
|
||||
"variables": variables,
|
||||
"code_language": language,
|
||||
"code": content.get("code", ""),
|
||||
"outputs": outputs,
|
||||
"message": content.get("explanation", ""),
|
||||
"error": "",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def instruction_modify_legacy(
|
||||
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
|
||||
@@ -529,6 +960,10 @@ class LLMGenerator:
|
||||
provider=model_config.get("provider", ""),
|
||||
model=model_config.get("name", ""),
|
||||
)
|
||||
model_name = model_config.get("name", "")
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
|
||||
if not model_schema:
|
||||
return {"error": f"Model schema not found for {model_name}"}
|
||||
match node_type:
|
||||
case "llm" | "agent":
|
||||
system_prompt = LLM_MODIFY_PROMPT_SYSTEM
|
||||
@@ -552,20 +987,18 @@ class LLMGenerator:
|
||||
model_parameters = {"temperature": 0.4}
|
||||
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model
|
||||
|
||||
generated_raw = response.message.get_text_content()
|
||||
first_brace = generated_raw.find("{")
|
||||
last_brace = generated_raw.rfind("}")
|
||||
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
|
||||
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
|
||||
json_str = generated_raw[first_brace : last_brace + 1]
|
||||
data = json_repair.loads(json_str)
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
|
||||
return data
|
||||
response = invoke_llm_with_pydantic_model(
|
||||
provider=model_instance.provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=list(prompt_messages),
|
||||
output_model=InstructionModifyOutput,
|
||||
model_parameters=model_parameters,
|
||||
stream=False,
|
||||
)
|
||||
return response.structured_output or {}
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
return {"error": f"Failed to generate code. Error: {error}"}
|
||||
|
||||
api/core/llm_generator/output_models.py (new file, 34 lines)
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
class SuggestedQuestionsOutput(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
questions: list[str] = Field(min_length=3, max_length=3)
|
||||
|
||||
|
||||
class CodeNodeOutput(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: SegmentType
|
||||
|
||||
|
||||
class CodeNodeStructuredOutput(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
variables: list[VariableSelector]
|
||||
code: str
|
||||
outputs: dict[str, CodeNodeOutput]
|
||||
explanation: str
|
||||
|
||||
|
||||
class InstructionModifyOutput(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
modified: str
|
||||
message: str
|
||||
api/core/llm_generator/output_parser/file_ref.py (new file, 188 lines)
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
File reference detection and conversion for structured output.
|
||||
|
||||
This module provides utilities to:
|
||||
1. Detect file reference fields in JSON Schema (format: "dify-file-ref")
|
||||
2. Convert file ID strings to File objects after LLM returns
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.file import File
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
from factories.file_factory import build_from_mapping
|
||||
|
||||
FILE_REF_FORMAT = "dify-file-ref"
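# Illustrative example (hypothetical field name "invoice_image"): a property declared as
#   {"type": "string", "format": "dify-file-ref"}
# is treated as a file reference, so for the schema
#   {"type": "object", "properties": {"invoice_image": {"type": "string", "format": "dify-file-ref"}}}
# detect_file_ref_fields() below returns ["invoice_image"].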
|
||||
|
||||
|
||||
def is_file_ref_property(schema: dict) -> bool:
|
||||
"""Check if a schema property is a file reference."""
|
||||
return schema.get("type") == "string" and schema.get("format") == FILE_REF_FORMAT
|
||||
|
||||
|
||||
def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
|
||||
"""
|
||||
Recursively detect file reference fields in schema.
|
||||
|
||||
Args:
|
||||
schema: JSON Schema to analyze
|
||||
path: Current path in the schema (used for recursion)
|
||||
|
||||
Returns:
|
||||
List of JSON paths containing file refs, e.g., ["image_id", "files[*]"]
|
||||
"""
|
||||
file_ref_paths: list[str] = []
|
||||
schema_type = schema.get("type")
|
||||
|
||||
if schema_type == "object":
|
||||
for prop_name, prop_schema in schema.get("properties", {}).items():
|
||||
current_path = f"{path}.{prop_name}" if path else prop_name
|
||||
|
||||
if is_file_ref_property(prop_schema):
|
||||
file_ref_paths.append(current_path)
|
||||
elif isinstance(prop_schema, dict):
|
||||
file_ref_paths.extend(detect_file_ref_fields(prop_schema, current_path))
|
||||
|
||||
elif schema_type == "array":
|
||||
items_schema = schema.get("items", {})
|
||||
array_path = f"{path}[*]" if path else "[*]"
|
||||
|
||||
if is_file_ref_property(items_schema):
|
||||
file_ref_paths.append(array_path)
|
||||
elif isinstance(items_schema, dict):
|
||||
file_ref_paths.extend(detect_file_ref_fields(items_schema, array_path))
|
||||
|
||||
return file_ref_paths
|
||||
|
||||
|
||||
def convert_file_refs_in_output(
|
||||
output: Mapping[str, Any],
|
||||
json_schema: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert file ID strings to File objects based on schema.
|
||||
|
||||
Args:
|
||||
output: The structured_output from LLM result
|
||||
json_schema: The original JSON schema (to detect file ref fields)
|
||||
tenant_id: Tenant ID for file lookup
|
||||
|
||||
Returns:
|
||||
Output with file references converted to File objects
|
||||
"""
|
||||
file_ref_paths = detect_file_ref_fields(json_schema)
|
||||
if not file_ref_paths:
|
||||
return dict(output)
|
||||
|
||||
result = _deep_copy_dict(output)
|
||||
|
||||
for path in file_ref_paths:
|
||||
_convert_path_in_place(result, path.split("."), tenant_id)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _deep_copy_dict(obj: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"""Deep copy a mapping to a mutable dict."""
|
||||
result: dict[str, Any] = {}
|
||||
for key, value in obj.items():
|
||||
if isinstance(value, Mapping):
|
||||
result[key] = _deep_copy_dict(value)
|
||||
elif isinstance(value, list):
|
||||
result[key] = [_deep_copy_dict(item) if isinstance(item, Mapping) else item for item in value]
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
def _convert_path_in_place(obj: dict, path_parts: list[str], tenant_id: str) -> None:
|
||||
"""Convert file refs at the given path in place, wrapping in Segment types."""
|
||||
if not path_parts:
|
||||
return
|
||||
|
||||
current = path_parts[0]
|
||||
remaining = path_parts[1:]
|
||||
|
||||
# Handle array notation like "files[*]"
|
||||
if current.endswith("[*]"):
|
||||
key = current[:-3] if current != "[*]" else None
|
||||
target = obj.get(key) if key else obj
|
||||
|
||||
if isinstance(target, list):
|
||||
if remaining:
|
||||
# Nested array with remaining path - recurse into each item
|
||||
for item in target:
|
||||
if isinstance(item, dict):
|
||||
_convert_path_in_place(item, remaining, tenant_id)
|
||||
else:
|
||||
# Array of file IDs - convert all and wrap in ArrayFileSegment
|
||||
files: list[File] = []
|
||||
for item in target:
|
||||
file = _convert_file_id(item, tenant_id)
|
||||
if file is not None:
|
||||
files.append(file)
|
||||
# Replace the array with ArrayFileSegment
|
||||
if key:
|
||||
obj[key] = ArrayFileSegment(value=files)
|
||||
return
|
||||
|
||||
if not remaining:
|
||||
# Leaf node - convert the value and wrap in FileSegment
|
||||
if current in obj:
|
||||
file = _convert_file_id(obj[current], tenant_id)
|
||||
if file is not None:
|
||||
obj[current] = FileSegment(value=file)
|
||||
else:
|
||||
obj[current] = None
|
||||
else:
|
||||
# Recurse into nested object
|
||||
if current in obj and isinstance(obj[current], dict):
|
||||
_convert_path_in_place(obj[current], remaining, tenant_id)
|
||||
|
||||
|
||||
def _convert_file_id(file_id: Any, tenant_id: str) -> File | None:
|
||||
"""
|
||||
Convert a file ID string to a File object.
|
||||
|
||||
Tries multiple file sources in order:
|
||||
1. ToolFile (files generated by tools/workflows)
|
||||
2. UploadFile (files uploaded by users)
|
||||
"""
|
||||
if not isinstance(file_id, str):
|
||||
return None
|
||||
|
||||
# Validate UUID format
|
||||
try:
|
||||
uuid.UUID(file_id)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
# Try ToolFile first (files generated by tools/workflows)
|
||||
try:
|
||||
return build_from_mapping(
|
||||
mapping={
|
||||
"transfer_method": "tool_file",
|
||||
"tool_file_id": file_id,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Try UploadFile (files uploaded by users)
|
||||
try:
|
||||
return build_from_mapping(
|
||||
mapping={
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": file_id,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# File not found in any source
|
||||
return None
|
||||
@@ -2,12 +2,13 @@ import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, cast, overload
|
||||
from typing import Any, Literal, TypeVar, cast, overload
|
||||
|
||||
import json_repair
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from pydantic import BaseModel, TypeAdapter, ValidationError
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
|
||||
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
@@ -43,6 +44,9 @@ class SpecialModelType(StrEnum):
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
*,
|
||||
@@ -57,6 +61,7 @@ def invoke_llm_with_structured_output(
|
||||
stream: Literal[True],
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
@@ -72,6 +77,7 @@ def invoke_llm_with_structured_output(
|
||||
stream: Literal[False],
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
@@ -87,6 +93,7 @@ def invoke_llm_with_structured_output(
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
def invoke_llm_with_structured_output(
|
||||
*,
|
||||
@@ -101,23 +108,30 @@ def invoke_llm_with_structured_output(
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
"""
|
||||
Invoke large language model with structured output
|
||||
1. This method invokes model_instance.invoke_llm with json_schema
|
||||
2. Try to parse the result as structured output
|
||||
Invoke large language model with structured output.
|
||||
|
||||
This method invokes model_instance.invoke_llm with json_schema and parses
|
||||
the result as structured output.
|
||||
|
||||
:param provider: model provider name
|
||||
:param model_schema: model schema entity
|
||||
:param model_instance: model instance to invoke
|
||||
:param prompt_messages: prompt messages
|
||||
:param json_schema: json schema
|
||||
:param json_schema: json schema for structured output
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:param tenant_id: tenant ID for file reference conversion. When provided and
|
||||
json_schema contains file reference fields (format: "dify-file-ref"),
|
||||
file IDs in the output will be automatically converted to File objects.
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
|
||||
# handle native json schema
|
||||
model_parameters_with_json_schema: dict[str, Any] = {
|
||||
**(model_parameters or {}),
|
||||
@@ -153,8 +167,18 @@
|
||||
f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
|
||||
)
|
||||
|
||||
structured_output = _parse_structured_output(llm_result.message.content)
|
||||
|
||||
# Convert file references if tenant_id is provided
|
||||
if tenant_id is not None:
|
||||
structured_output = convert_file_refs_in_output(
|
||||
output=structured_output,
|
||||
json_schema=json_schema,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
return LLMResultWithStructuredOutput(
|
||||
structured_output=_parse_structured_output(llm_result.message.content),
|
||||
structured_output=structured_output,
|
||||
model=llm_result.model,
|
||||
message=llm_result.message,
|
||||
usage=llm_result.usage,
|
||||
@@ -186,8 +210,18 @@
|
||||
delta=event.delta,
|
||||
)
|
||||
|
||||
structured_output = _parse_structured_output(result_text)
|
||||
|
||||
# Convert file references if tenant_id is provided
|
||||
if tenant_id is not None:
|
||||
structured_output = convert_file_refs_in_output(
|
||||
output=structured_output,
|
||||
json_schema=json_schema,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
structured_output=_parse_structured_output(result_text),
|
||||
structured_output=structured_output,
|
||||
model=model_schema.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
@@ -202,6 +236,87 @@
|
||||
return generator()
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_pydantic_model(
|
||||
*,
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
output_model: type[T],
|
||||
model_parameters: Mapping | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: Literal[False] = False,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
|
||||
|
||||
def invoke_llm_with_pydantic_model(
|
||||
*,
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
output_model: type[T],
|
||||
model_parameters: Mapping | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> LLMResultWithStructuredOutput:
|
||||
"""
|
||||
Invoke large language model with a Pydantic output model.
|
||||
|
||||
This helper generates a JSON schema from the Pydantic model, invokes the
|
||||
structured-output LLM path, and validates the result in non-streaming mode.
|
||||
"""
|
||||
if stream:
|
||||
raise ValueError("invoke_llm_with_pydantic_model only supports stream=False")
|
||||
|
||||
json_schema = _schema_from_pydantic(output_model)
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=json_schema,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
structured_output = result.structured_output
|
||||
if structured_output is None:
|
||||
raise OutputParserError("Structured output is empty")
|
||||
|
||||
validated_output = _validate_structured_output(output_model, structured_output)
|
||||
return result.model_copy(update={"structured_output": validated_output})
|
||||
|
||||
|
||||
def _schema_from_pydantic(output_model: type[BaseModel]) -> dict[str, Any]:
|
||||
return output_model.model_json_schema()
|
||||
|
||||
|
||||
def _validate_structured_output(
|
||||
output_model: type[T],
|
||||
structured_output: Mapping[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
validated_output = output_model.model_validate(structured_output)
|
||||
except ValidationError as exc:
|
||||
raise OutputParserError(f"Structured output validation failed: {exc}") from exc
|
||||
return validated_output.model_dump(mode="python")
|
||||
|
||||
|
||||
def _handle_native_json_schema(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
|
||||
api/core/llm_generator/utils.py (new file, 45 lines)
@@ -0,0 +1,45 @@
|
||||
"""Utility functions for LLM generator."""
|
||||
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
|
||||
def deserialize_prompt_messages(messages: list[dict]) -> list[PromptMessage]:
|
||||
"""
|
||||
Deserialize list of dicts to list[PromptMessage].
|
||||
|
||||
Expected format:
|
||||
[
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."},
|
||||
]
|
||||
"""
|
||||
result: list[PromptMessage] = []
|
||||
for msg in messages:
|
||||
role = PromptMessageRole.value_of(msg["role"])
|
||||
content = msg.get("content", "")
|
||||
|
||||
match role:
|
||||
case PromptMessageRole.USER:
|
||||
result.append(UserPromptMessage(content=content))
|
||||
case PromptMessageRole.ASSISTANT:
|
||||
result.append(AssistantPromptMessage(content=content))
|
||||
case PromptMessageRole.SYSTEM:
|
||||
result.append(SystemPromptMessage(content=content))
|
||||
case PromptMessageRole.TOOL:
|
||||
result.append(ToolPromptMessage(content=content, tool_call_id=msg.get("tool_call_id", "")))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def serialize_prompt_messages(messages: list[PromptMessage]) -> list[dict]:
|
||||
"""
|
||||
Serialize list[PromptMessage] to list of dicts.
|
||||
"""
|
||||
return [{"role": msg.role.value, "content": msg.content} for msg in messages]
|
||||
api/core/memory/README.md (new file, 267 lines)
@@ -0,0 +1,267 @@
|
||||
# Memory Module
|
||||
|
||||
This module provides memory management for LLM conversations, enabling context retention across dialogue turns.
|
||||
|
||||
## Overview
|
||||
|
||||
The memory module contains two types of memory implementations:
|
||||
|
||||
1. **TokenBufferMemory** - Conversation-level memory (existing)
|
||||
2. **NodeTokenBufferMemory** - Node-level memory (**Chatflow only**)
|
||||
|
||||
> **Note**: `NodeTokenBufferMemory` is only available in **Chatflow** (advanced-chat mode).
|
||||
> This is because it requires both `conversation_id` and `node_id`, which are only present in Chatflow.
|
||||
> Standard Workflow mode does not have `conversation_id` and therefore cannot use node-level memory.
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────────────┐
|
||||
│ Memory Architecture │
|
||||
├─────────────────────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ ┌─────────────────────────────────────────────────────────────────────-┐ │
|
||||
│ │ TokenBufferMemory │ │
|
||||
│ │ Scope: Conversation │ │
|
||||
│ │ Storage: Database (Message table) │ │
|
||||
│ │ Key: conversation_id │ │
|
||||
│ └─────────────────────────────────────────────────────────────────────-┘ │
|
||||
│ │
|
||||
│ ┌─────────────────────────────────────────────────────────────────────-┐ │
|
||||
│ │ NodeTokenBufferMemory │ │
|
||||
│ │ Scope: Node within Conversation │ │
|
||||
│ │ Storage: WorkflowNodeExecutionModel.outputs["context"] │ │
|
||||
│ │ Key: (conversation_id, node_id, workflow_run_id) │ │
|
||||
│ └─────────────────────────────────────────────────────────────────────-┘ │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## TokenBufferMemory (Existing)
|
||||
|
||||
### Purpose
|
||||
|
||||
`TokenBufferMemory` retrieves conversation history from the `Message` table and converts it to `PromptMessage` objects for LLM context.
|
||||
|
||||
### Key Features
|
||||
|
||||
- **Conversation-scoped**: All messages within a conversation are candidates
|
||||
- **Thread-aware**: Uses `parent_message_id` to extract only the current thread (supports regeneration scenarios)
|
||||
- **Token-limited**: Truncates history to fit within `max_token_limit`
|
||||
- **File support**: Handles `MessageFile` attachments (images, documents, etc.)
|
||||
|
||||
### Data Flow
|
||||
|
||||
```
|
||||
Message Table TokenBufferMemory LLM
|
||||
│ │ │
|
||||
│ SELECT * FROM messages │ │
|
||||
│ WHERE conversation_id = ? │ │
|
||||
│ ORDER BY created_at DESC │ │
|
||||
├─────────────────────────────────▶│ │
|
||||
│ │ │
|
||||
│ extract_thread_messages() │
|
||||
│ │ │
|
||||
│ build_prompt_message_with_files() │
|
||||
│ │ │
|
||||
│ truncate by max_token_limit │
|
||||
│ │ │
|
||||
│ │ Sequence[PromptMessage]
|
||||
│ ├───────────────────────▶│
|
||||
│ │ │
|
||||
```
|
||||
|
||||
### Thread Extraction
|
||||
|
||||
When a user regenerates a response, a new thread is created:
|
||||
|
||||
```
|
||||
Message A (user)
|
||||
└── Message A' (assistant)
|
||||
└── Message B (user)
|
||||
└── Message B' (assistant)
|
||||
└── Message A'' (assistant, regenerated) ← New thread
|
||||
└── Message C (user)
|
||||
└── Message C' (assistant)
|
||||
```
|
||||
|
||||
`extract_thread_messages()` traces back from the latest message using `parent_message_id` to get only the current thread: `[A, A'', C, C']`
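A minimal sketch of the backtracking idea (illustrative only, not the actual implementation in `core.prompt.utils.extract_thread_messages`; attribute access is simplified):

```python
def extract_thread_sketch(messages):
    """messages are ordered newest first; walk parent_message_id links from the newest one."""
    by_id = {m.id: m for m in messages}
    thread = []
    current = messages[0] if messages else None
    while current is not None:
        thread.append(current)
        parent_id = current.parent_message_id
        current = by_id.get(parent_id) if parent_id else None
    return thread  # newest first; reverse for chronological order, e.g. [A, A'', C, C']
```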
|
||||
|
||||
### Usage
|
||||
|
||||
```python
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=100)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## NodeTokenBufferMemory
|
||||
|
||||
### Purpose
|
||||
|
||||
`NodeTokenBufferMemory` provides **node-scoped memory** within a conversation. Each LLM node in a workflow can maintain its own independent conversation history.
|
||||
|
||||
### Use Cases
|
||||
|
||||
1. **Multi-LLM Workflows**: Different LLM nodes need separate context
|
||||
2. **Iterative Processing**: An LLM node in a loop needs to accumulate context across iterations
|
||||
3. **Specialized Agents**: Each agent node maintains its own dialogue history
|
||||
|
||||
### Design: Zero Extra Storage
|
||||
|
||||
**Key insight**: the LLM node already saves the complete dialogue context in `outputs["context"]`.
|
||||
|
||||
Each LLM node execution outputs:
|
||||
```python
|
||||
outputs = {
|
||||
"text": clean_text,
|
||||
"context": self._build_context(prompt_messages, clean_text), # Complete dialogue history!
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
This `outputs["context"]` contains:
|
||||
- All previous user/assistant messages (excluding system prompt)
|
||||
- The current assistant response
|
||||
|
||||
**No separate storage needed** - we just read from the last execution's `outputs["context"]`.
|
||||
|
||||
### Benefits
|
||||
|
||||
| Aspect | Old Design (Object Storage) | New Design (outputs["context"]) |
|
||||
|--------|----------------------------|--------------------------------|
|
||||
| Storage | Separate JSON file | Already in WorkflowNodeExecutionModel |
|
||||
| Concurrency | Race condition risk | No issue (each execution is INSERT) |
|
||||
| Cleanup | Need separate cleanup task | Follows node execution lifecycle |
|
||||
| Migration | Required | None |
|
||||
| Complexity | High | Low |
|
||||
|
||||
### Data Flow
|
||||
|
||||
```
|
||||
WorkflowNodeExecutionModel NodeTokenBufferMemory LLM Node
|
||||
│ │ │
|
||||
│ │◀── get_history_prompt_messages()
|
||||
│ │ │
|
||||
│ SELECT outputs FROM │ │
|
||||
│ workflow_node_executions │ │
|
||||
│ WHERE workflow_run_id = ? │ │
|
||||
│ AND node_id = ? │ │
|
||||
│◀─────────────────────────────────┤ │
|
||||
│ │ │
|
||||
│ outputs["context"] │ │
|
||||
├─────────────────────────────────▶│ │
|
||||
│ │ │
|
||||
│ deserialize PromptMessages │
|
||||
│ │ │
|
||||
│ truncate by max_token_limit │
|
||||
│ │ │
|
||||
│ │ Sequence[PromptMessage] │
|
||||
│ ├──────────────────────────▶│
|
||||
│ │ │
|
||||
```
|
||||
|
||||
### Thread Tracking
|
||||
|
||||
Thread extraction still uses `Message` table's `parent_message_id` structure:
|
||||
|
||||
1. Query `Message` table for conversation → get thread's `workflow_run_ids`
|
||||
2. Get the last completed `workflow_run_id` in the thread
|
||||
3. Query `WorkflowNodeExecutionModel` for that execution's `outputs["context"]`
|
||||
|
||||
### API
|
||||
|
||||
```python
|
||||
class NodeTokenBufferMemory:
|
||||
def __init__(
|
||||
self,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
node_id: str,
|
||||
tenant_id: str,
|
||||
model_instance: ModelInstance,
|
||||
):
|
||||
"""Initialize node-level memory."""
|
||||
...
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
*,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Retrieve history as PromptMessage sequence.
|
||||
|
||||
Reads from last completed execution's outputs["context"].
|
||||
"""
|
||||
...
|
||||
|
||||
# Legacy methods (no-op, kept for compatibility)
|
||||
def add_messages(self, *args, **kwargs) -> None: pass
|
||||
def flush(self) -> None: pass
|
||||
def clear(self) -> None: pass
|
||||
```
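
Example usage (mirrors the `TokenBufferMemory` example above; the argument values are assumed to come from the node's runtime context):

```python
from core.memory import NodeTokenBufferMemory

memory = NodeTokenBufferMemory(
    app_id=app_id,
    conversation_id=conversation_id,
    node_id=node_id,
    tenant_id=tenant_id,
    model_instance=model_instance,
)
history = memory.get_history_prompt_messages(max_token_limit=2000)
```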
|
||||
|
||||
### Configuration
|
||||
|
||||
Add to `MemoryConfig` in `core/workflow/nodes/llm/entities.py`:
|
||||
|
||||
```python
|
||||
class MemoryMode(StrEnum):
|
||||
CONVERSATION = "conversation" # Use TokenBufferMemory (default)
|
||||
NODE = "node" # Use NodeTokenBufferMemory (Chatflow only)
|
||||
|
||||
class MemoryConfig(BaseModel):
|
||||
role_prefix: RolePrefix | None = None
|
||||
window: MemoryWindowConfig | None = None
|
||||
query_prompt_template: str | None = None
|
||||
mode: MemoryMode = MemoryMode.CONVERSATION
|
||||
```
|
||||
|
||||
**Mode Behavior:**
|
||||
|
||||
| Mode | Memory Class | Scope | Availability |
|
||||
| -------------- | --------------------- | ------------------------ | ------------- |
|
||||
| `conversation` | TokenBufferMemory | Entire conversation | All app modes |
|
||||
| `node` | NodeTokenBufferMemory | Per-node in conversation | Chatflow only |
|
||||
|
||||
> When `mode=node` is used in a non-Chatflow context (no conversation_id), it falls back to no memory.
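
A rough sketch of how a node could pick the memory implementation from this config (the glue code and the fallback behaviour shown here are illustrative, not the exact LLM node code; import locations follow the Configuration section above):

```python
from core.memory import NodeTokenBufferMemory, TokenBufferMemory
from core.workflow.nodes.llm.entities import MemoryConfig, MemoryMode


def select_memory(memory_config: MemoryConfig, *, conversation, conversation_id, node_id, app_id, tenant_id, model_instance):
    # Node-level memory needs a conversation_id, so it is Chatflow-only.
    if memory_config.mode == MemoryMode.NODE:
        if not conversation_id:
            return None  # fall back to no memory outside Chatflow
        return NodeTokenBufferMemory(
            app_id=app_id,
            conversation_id=conversation_id,
            node_id=node_id,
            tenant_id=tenant_id,
            model_instance=model_instance,
        )
    # Default: conversation-level memory backed by the Message table.
    return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
```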
|
||||
|
||||
---
|
||||
|
||||
## Comparison
|
||||
|
||||
| Feature | TokenBufferMemory | NodeTokenBufferMemory |
|
||||
| -------------- | ------------------------ | ---------------------------------- |
|
||||
| Scope | Conversation | Node within Conversation |
|
||||
| Storage | Database (Message table) | WorkflowNodeExecutionModel.outputs |
|
||||
| Thread Support | Yes | Yes |
|
||||
| File Support | Yes (via MessageFile) | Yes (via context serialization) |
|
||||
| Token Limit | Yes | Yes |
|
||||
| Use Case | Standard chat apps | Complex workflows |
|
||||
|
||||
---
|
||||
|
||||
## Extending to Other Nodes
|
||||
|
||||
Currently only the **LLM Node** includes `context` in its outputs. To enable node-level memory for other nodes:
|
||||
|
||||
1. Add `outputs["context"] = self._build_context(prompt_messages, response)` in the node (a sketch follows the list of candidate nodes below)
|
||||
2. The `NodeTokenBufferMemory` will automatically pick it up
|
||||
|
||||
Nodes that could potentially support this:
|
||||
- `question_classifier`
|
||||
- `parameter_extractor`
|
||||
- `agent`
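
A hypothetical sketch of step 1 inside such a node's run logic (`prompt_messages` and `response_text` are assumed local names, and `_build_context` would have to exist on that node as it does on the LLM node):

```python
outputs = {
    "text": response_text,
    # Persist the dialogue history so NodeTokenBufferMemory can read it back from
    # WorkflowNodeExecutionModel.outputs["context"] on the next turn.
    "context": self._build_context(prompt_messages, response_text),
}
```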
|
||||
|
||||
---
|
||||
|
||||
## Future Considerations
|
||||
|
||||
1. **Cleanup**: Node memory lifecycle follows `WorkflowNodeExecutionModel`, which already has cleanup mechanisms
|
||||
2. **Compression**: For very long conversations, consider summarization strategies
|
||||
3. **Extension**: Other nodes may benefit from node-level memory
|
||||
api/core/memory/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
|
||||
from core.memory.base import BaseMemory
|
||||
from core.memory.node_token_buffer_memory import (
|
||||
NodeTokenBufferMemory,
|
||||
)
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
|
||||
__all__ = [
|
||||
"BaseMemory",
|
||||
"NodeTokenBufferMemory",
|
||||
"TokenBufferMemory",
|
||||
]
|
||||
api/core/memory/base.py (new file, 83 lines)
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
Base memory interfaces and types.
|
||||
|
||||
This module defines the common protocol for memory implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
|
||||
from core.model_runtime.entities import ImagePromptMessageContent, PromptMessage
|
||||
|
||||
|
||||
class BaseMemory(ABC):
|
||||
"""
|
||||
Abstract base class for memory implementations.
|
||||
|
||||
Provides a common interface for both conversation-level and node-level memory.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
*,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Get history prompt messages.
|
||||
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: Maximum number of messages
|
||||
:return: Sequence of PromptMessage for LLM context
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_history_prompt_text(
|
||||
self,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get history prompt as formatted text.
|
||||
|
||||
:param human_prefix: Prefix for human messages
|
||||
:param ai_prefix: Prefix for assistant messages
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: Maximum number of messages
|
||||
:return: Formatted history text
|
||||
"""
|
||||
from core.model_runtime.entities import (
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
|
||||
prompt_messages = self.get_history_prompt_messages(
|
||||
max_token_limit=max_token_limit,
|
||||
message_limit=message_limit,
|
||||
)
|
||||
|
||||
string_messages = []
|
||||
for m in prompt_messages:
|
||||
if m.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif m.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(m.content, list):
|
||||
inner_msg = ""
|
||||
for content in m.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
inner_msg += f"{content.data}\n"
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
inner_msg += "[image]\n"
|
||||
string_messages.append(f"{role}: {inner_msg.strip()}")
|
||||
else:
|
||||
message = f"{role}: {m.content}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
api/core/memory/node_token_buffer_memory.py (new file, 197 lines)
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Node-level Token Buffer Memory for Chatflow.
|
||||
|
||||
This module provides node-scoped memory within a conversation.
|
||||
Each LLM node in a workflow can maintain its own independent conversation history.
|
||||
|
||||
Note: This is only available in Chatflow (advanced-chat mode) because it requires
|
||||
both conversation_id and node_id.
|
||||
|
||||
Design:
|
||||
- History is read directly from WorkflowNodeExecutionModel.outputs["context"]
|
||||
- No separate storage needed - the context is already saved during node execution
|
||||
- Thread tracking leverages Message table's parent_message_id structure
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.file import file_manager
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
MultiModalPromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeTokenBufferMemory(BaseMemory):
|
||||
"""
|
||||
Node-level Token Buffer Memory.
|
||||
|
||||
Provides node-scoped memory within a conversation. Each LLM node can maintain
|
||||
its own independent conversation history.
|
||||
|
||||
Key design: History is read directly from WorkflowNodeExecutionModel.outputs["context"],
|
||||
which is already saved during node execution. No separate storage needed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
node_id: str,
|
||||
tenant_id: str,
|
||||
model_instance: ModelInstance,
|
||||
):
|
||||
self.app_id = app_id
|
||||
self.conversation_id = conversation_id
|
||||
self.node_id = node_id
|
||||
self.tenant_id = tenant_id
|
||||
self.model_instance = model_instance
|
||||
|
||||
def _get_thread_workflow_run_ids(self) -> list[str]:
|
||||
"""
|
||||
Get workflow_run_ids for the current thread by querying Message table.
|
||||
Returns workflow_run_ids in chronological order (oldest first).
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(Message)
|
||||
.where(Message.conversation_id == self.conversation_id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(500)
|
||||
)
|
||||
messages = list(session.scalars(stmt).all())
|
||||
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
# Extract thread messages using existing logic
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
|
||||
# For a newly created message, the answer is temporarily empty; skip it
|
||||
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
|
||||
thread_messages.pop(0)
|
||||
|
||||
# Reverse to get chronological order, extract workflow_run_ids
|
||||
return [msg.workflow_run_id for msg in reversed(thread_messages) if msg.workflow_run_id]
|
||||
|
||||
def _deserialize_prompt_message(self, msg_dict: dict) -> PromptMessage:
|
||||
"""Deserialize a dict to PromptMessage based on role."""
|
||||
role = msg_dict.get("role")
|
||||
if role in (PromptMessageRole.USER, "user"):
|
||||
return UserPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.ASSISTANT, "assistant"):
|
||||
return AssistantPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.SYSTEM, "system"):
|
||||
return SystemPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.TOOL, "tool"):
|
||||
return ToolPromptMessage.model_validate(msg_dict)
|
||||
else:
|
||||
return PromptMessage.model_validate(msg_dict)
|
||||
|
||||
def _deserialize_context(self, context_data: list[dict]) -> list[PromptMessage]:
|
||||
"""Deserialize context data from outputs to list of PromptMessage."""
|
||||
messages = []
|
||||
for msg_dict in context_data:
|
||||
try:
|
||||
msg = self._deserialize_prompt_message(msg_dict)
|
||||
msg = self._restore_multimodal_content(msg)
|
||||
messages.append(msg)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to deserialize prompt message: %s", e)
|
||||
return messages
|
||||
|
||||
def _restore_multimodal_content(self, message: PromptMessage) -> PromptMessage:
|
||||
"""
|
||||
Restore multimodal content (base64 or url) from file_ref.
|
||||
|
||||
When context is saved, base64_data is cleared to save storage space.
|
||||
This method restores the content by parsing file_ref (format: "method:id_or_url").
|
||||
"""
|
||||
content = message.content
|
||||
if content is None or isinstance(content, str):
|
||||
return message
|
||||
|
||||
# Process list content, restoring multimodal data from file references
|
||||
restored_content: list[PromptMessageContentUnionTypes] = []
|
||||
for item in content:
|
||||
if isinstance(item, MultiModalPromptMessageContent):
|
||||
# restore_multimodal_content preserves the concrete subclass type
|
||||
restored_item = file_manager.restore_multimodal_content(item)
|
||||
restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
|
||||
else:
|
||||
restored_content.append(item)
|
||||
|
||||
return message.model_copy(update={"content": restored_content})
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
*,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Retrieve history as PromptMessage sequence.
|
||||
History is read directly from the last completed node execution's outputs["context"].
|
||||
"""
|
||||
_ = message_limit # unused, kept for interface compatibility
|
||||
|
||||
thread_workflow_run_ids = self._get_thread_workflow_run_ids()
|
||||
if not thread_workflow_run_ids:
|
||||
return []
|
||||
|
||||
# Get the last completed workflow_run_id (contains accumulated context)
|
||||
last_run_id = thread_workflow_run_ids[-1]
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == last_run_id,
|
||||
WorkflowNodeExecutionModel.node_id == self.node_id,
|
||||
WorkflowNodeExecutionModel.status == "succeeded",
|
||||
)
|
||||
execution = session.scalars(stmt).first()
|
||||
|
||||
if not execution:
|
||||
return []
|
||||
|
||||
outputs = execution.outputs_dict
|
||||
if not outputs:
|
||||
return []
|
||||
|
||||
context_data = outputs.get("context")
|
||||
|
||||
if not context_data or not isinstance(context_data, list):
|
||||
return []
|
||||
|
||||
prompt_messages = self._deserialize_context(context_data)
|
||||
if not prompt_messages:
|
||||
return []
|
||||
|
||||
# Truncate by token limit
|
||||
try:
|
||||
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
while current_tokens > max_token_limit and len(prompt_messages) > 1:
|
||||
prompt_messages.pop(0)
|
||||
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to count tokens for truncation: %s", e)
|
||||
|
||||
return prompt_messages
|
||||
@@ -5,12 +5,12 @@ from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.file import file_manager
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
@@ -24,7 +24,7 @@ from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
class TokenBufferMemory:
|
||||
class TokenBufferMemory(BaseMemory):
|
||||
def __init__(
|
||||
self,
|
||||
conversation: Conversation,
|
||||
@@ -115,10 +115,14 @@ class TokenBufferMemory:
|
||||
return AssistantPromptMessage(content=prompt_message_contents)
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self, max_token_limit: int = 2000, message_limit: int | None = None
|
||||
self,
|
||||
*,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Get history prompt messages.
|
||||
|
||||
:param max_token_limit: max token limit
|
||||
:param message_limit: message limit
|
||||
"""
|
||||
@@ -200,44 +204,3 @@ class TokenBufferMemory:
|
||||
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def get_history_prompt_text(
|
||||
self,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get history prompt text.
|
||||
:param human_prefix: human prefix
|
||||
:param ai_prefix: ai prefix
|
||||
:param max_token_limit: max token limit
|
||||
:param message_limit: message limit
|
||||
:return:
|
||||
"""
|
||||
prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
|
||||
|
||||
string_messages = []
|
||||
for m in prompt_messages:
|
||||
if m.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif m.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(m.content, list):
|
||||
inner_msg = ""
|
||||
for content in m.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
inner_msg += f"{content.data}\n"
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
inner_msg += "[image]\n"
|
||||
|
||||
string_messages.append(f"{role}: {inner_msg.strip()}")
|
||||
else:
|
||||
message = f"{role}: {m.content}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
@@ -91,6 +91,9 @@ class MultiModalPromptMessageContent(PromptMessageContent):
    mime_type: str = Field(default=..., description="the mime type of multi-modal file")
    filename: str = Field(default="", description="the filename of multi-modal file")

    # File reference for context restoration, format: "transfer_method:related_id" or "remote:url"
    file_ref: str | None = Field(default=None, description="Encoded file reference for restoration")

    @property
    def data(self):
        return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
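The new file_ref field encodes just enough to re-locate a file when restoring serialized context. Going by the comment's stated format, a hedged sketch of encoding and decoding such a reference (the helper names are hypothetical, not Dify APIs):

# Hypothetical helpers illustrating the "transfer_method:related_id" / "remote:url" format.
def encode_file_ref(transfer_method: str, value: str) -> str:
    return f"{transfer_method}:{value}"


def decode_file_ref(file_ref: str) -> tuple[str, str]:
    # Split only on the first ":" so a "remote:https://..." reference keeps its URL intact.
    method, _, value = file_ref.partition(":")
    return method, value


print(decode_file_ref("remote:https://example.com/cat.png"))  # ('remote', 'https://example.com/cat.png')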
@@ -276,7 +279,5 @@ class ToolPromptMessage(PromptMessage):

        :return: True if prompt message is empty, False otherwise
        """
-        if not super().is_empty() and not self.tool_call_id:
-            return False
-
-        return True
+        # ToolPromptMessage is not empty if it has content OR has a tool_call_id
+        return super().is_empty() and not self.tool_call_id
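The rewritten is_empty treats a tool_call_id as content in its own right: the message only counts as empty when both the content and the tool_call_id are missing. A small check of that predicate in isolation (plain booleans stand in for super().is_empty(); the id value is illustrative):

def tool_message_is_empty(content_is_empty: bool, tool_call_id: str | None) -> bool:
    # Mirrors the new expression: empty only when there is no content AND no tool_call_id.
    return content_is_empty and not tool_call_id


assert tool_message_is_empty(True, None) is True         # nothing at all -> empty
assert tool_message_is_empty(True, "call_1") is False    # an id alone makes it non-empty
assert tool_message_is_empty(False, None) is False       # content alone makes it non-empty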
@@ -4,6 +4,7 @@ from typing import Any

from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.signature import sign_tool_file
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
@@ -41,6 +42,43 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
                response, user_id=user_id, tenant_id=tenant_id
            )

-            return response
+            return cls._sign_tool_file_urls(response)
        except Exception as e:
            raise e

    # FIXME: this method should be gracefully deprecated
    @classmethod
    def _sign_tool_file_urls(
        cls, messages: Generator[ToolInvokeMessage, None, None]
    ) -> Generator[ToolInvokeMessage, None, None]:
        """
        Sign file URLs in tool invoke messages for external access.
        """
        for message in messages:
            if message.type in {
                ToolInvokeMessage.MessageType.IMAGE_LINK,
                ToolInvokeMessage.MessageType.BINARY_LINK,
                ToolInvokeMessage.MessageType.FILE,
            }:
                if isinstance(message.message, ToolInvokeMessage.TextMessage):
                    url = message.message.text
                    # Check if it's an unsigned internal path
                    if url.startswith("/files/tools/"):
                        parts = url.split("/")[-1]
                        if "." in parts:
                            file_id, ext = parts.rsplit(".", 1)
                            extension = f".{ext}"
                        else:
                            file_id = parts
                            extension = ".bin"

                        signed_url = sign_tool_file(tool_file_id=file_id, extension=extension)

                        yield ToolInvokeMessage(
                            type=message.type,
                            message=ToolInvokeMessage.TextMessage(text=signed_url),
                            meta=message.meta,
                        )
                        continue

            yield message
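_sign_tool_file_urls only rewrites unsigned internal paths of the form /files/tools/<file_id>.<extension>. A standalone sketch of that path parsing, with the same split logic as above but no call to sign_tool_file so the snippet runs on its own (the function name is hypothetical):

def parse_tool_file_path(url: str) -> tuple[str, str] | None:
    """Return (file_id, extension) for unsigned internal tool-file paths, else None."""
    if not url.startswith("/files/tools/"):
        return None
    last_segment = url.split("/")[-1]
    if "." in last_segment:
        file_id, ext = last_segment.rsplit(".", 1)
        return file_id, f".{ext}"
    return last_segment, ".bin"  # fall back to a generic binary extension


print(parse_tool_file_path("/files/tools/abc123.png"))    # ('abc123', '.png')
print(parse_tool_file_path("https://example.com/x.png"))  # None: already an external URL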
@@ -282,3 +282,11 @@ class TriggerDispatchResponse(BaseModel):
            return deserialize_response(binascii.unhexlify(v.encode()))
        except Exception as e:
            raise ValueError("Failed to deserialize response from hex string") from e


class RequestListTools(BaseModel):
    """
    Request to list all available tools
    """

    pass

@@ -5,7 +5,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import file_manager
from core.file.models import File
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
-from core.memory.token_buffer_memory import TokenBufferMemory
+from core.memory.base import BaseMemory
from core.model_runtime.entities import (
    AssistantPromptMessage,
    PromptMessage,
@@ -43,7 +43,7 @@ class AdvancedPromptTransform(PromptTransform):
        files: Sequence[File],
        context: str | None,
        memory_config: MemoryConfig | None,
-        memory: TokenBufferMemory | None,
+        memory: BaseMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:
@@ -84,7 +84,7 @@ class AdvancedPromptTransform(PromptTransform):
        files: Sequence[File],
        context: str | None,
        memory_config: MemoryConfig | None,
-        memory: TokenBufferMemory | None,
+        memory: BaseMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:
@@ -145,7 +145,7 @@ class AdvancedPromptTransform(PromptTransform):
        files: Sequence[File],
        context: str | None,
        memory_config: MemoryConfig | None,
-        memory: TokenBufferMemory | None,
+        memory: BaseMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:
@@ -270,7 +270,7 @@ class AdvancedPromptTransform(PromptTransform):

    def _set_histories_variable(
        self,
-        memory: TokenBufferMemory,
+        memory: BaseMemory,
        memory_config: MemoryConfig,
        raw_prompt: str,
        role_prefix: MemoryConfig.RolePrefix,

@@ -1,3 +1,4 @@
from enum import StrEnum
from typing import Literal

from pydantic import BaseModel
@@ -5,6 +6,13 @@ from pydantic import BaseModel
from core.model_runtime.entities.message_entities import PromptMessageRole


class MemoryMode(StrEnum):
    """Memory mode for LLM nodes."""

    CONVERSATION = "conversation"  # Use TokenBufferMemory (default, existing behavior)
    NODE = "node"  # Use NodeTokenBufferMemory (Chatflow only)


class ChatModelMessage(BaseModel):
    """
    Chat Message.
@@ -48,3 +56,4 @@ class MemoryConfig(BaseModel):
    role_prefix: RolePrefix | None = None
    window: WindowConfig
    query_prompt_template: str | None = None
    mode: MemoryMode = MemoryMode.CONVERSATION
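With the new mode field defaulting to CONVERSATION, existing configs keep their behavior and only opt into node-scoped memory explicitly. A hedged sketch of what an LLM-node memory config selecting the new mode might look like in dict form (the window fields are assumptions about WindowConfig's shape; only mode comes straight from the diff):

# Illustrative node-level memory configuration, as it might appear in a workflow definition.
memory_config = {
    "window": {"enabled": True, "size": 20},   # assumed WindowConfig fields
    "query_prompt_template": None,
    "mode": "node",  # MemoryMode.NODE -> node-scoped memory instead of conversation memory
}

Omitting mode, or passing "conversation", keeps the existing TokenBufferMemory path.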
@@ -1,7 +1,7 @@
from typing import Any

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.memory.token_buffer_memory import TokenBufferMemory
+from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
@@ -11,7 +11,7 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
class PromptTransform:
    def _append_chat_histories(
        self,
-        memory: TokenBufferMemory,
+        memory: BaseMemory,
        memory_config: MemoryConfig,
        prompt_messages: list[PromptMessage],
        model_config: ModelConfigWithCredentialsEntity,
@@ -52,7 +52,7 @@ class PromptTransform:

    def _get_history_messages_from_memory(
        self,
-        memory: TokenBufferMemory,
+        memory: BaseMemory,
        memory_config: MemoryConfig,
        max_token_limit: int,
        human_prefix: str | None = None,
@@ -73,7 +73,7 @@ class PromptTransform:
        return memory.get_history_prompt_text(**kwargs)

    def _get_history_messages_list_from_memory(
-        self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
+        self, memory: BaseMemory, memory_config: MemoryConfig, max_token_limit: int
    ) -> list[PromptMessage]:
        """Get memory messages."""
        return list(

@@ -29,6 +29,7 @@ from models import (
    Account,
    CreatorUserRole,
    EndUser,
    LLMGenerationDetail,
    WorkflowNodeExecutionModel,
    WorkflowNodeExecutionTriggeredFrom,
)
@@ -457,6 +458,113 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
            session.merge(db_model)
            session.flush()

            # Save LLMGenerationDetail for LLM nodes with successful execution
            if (
                domain_model.node_type == NodeType.LLM
                and domain_model.status == WorkflowNodeExecutionStatus.SUCCEEDED
                and domain_model.outputs is not None
            ):
                self._save_llm_generation_detail(session, domain_model)

    def _save_llm_generation_detail(self, session, execution: WorkflowNodeExecution) -> None:
        """
        Save LLM generation detail for LLM nodes.
        Extracts reasoning_content, tool_calls, and sequence from outputs and metadata.
        """
        outputs = execution.outputs or {}
        metadata = execution.metadata or {}

        reasoning_list = self._extract_reasoning(outputs)
        tool_calls_list = self._extract_tool_calls(metadata.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG))

        if not reasoning_list and not tool_calls_list:
            return

        sequence = self._build_generation_sequence(outputs.get("text", ""), reasoning_list, tool_calls_list)
        self._upsert_generation_detail(session, execution, reasoning_list, tool_calls_list, sequence)

    def _extract_reasoning(self, outputs: Mapping[str, Any]) -> list[str]:
        """Extract reasoning_content as a clean list of non-empty strings."""
        reasoning_content = outputs.get("reasoning_content")
        if isinstance(reasoning_content, str):
            trimmed = reasoning_content.strip()
            return [trimmed] if trimmed else []
        if isinstance(reasoning_content, list):
            return [item.strip() for item in reasoning_content if isinstance(item, str) and item.strip()]
        return []

    def _extract_tool_calls(self, agent_log: Any) -> list[dict[str, str]]:
        """Extract tool call records from agent logs."""
        if not agent_log or not isinstance(agent_log, list):
            return []

        tool_calls: list[dict[str, str]] = []
        for log in agent_log:
            log_data = log.data if hasattr(log, "data") else (log.get("data", {}) if isinstance(log, dict) else {})
            tool_name = log_data.get("tool_name")
            if tool_name and str(tool_name).strip():
                tool_calls.append(
                    {
                        "id": log_data.get("tool_call_id", ""),
                        "name": tool_name,
                        "arguments": json.dumps(log_data.get("tool_args", {})),
                        "result": str(log_data.get("output", "")),
                    }
                )
        return tool_calls

    def _build_generation_sequence(
        self, text: str, reasoning_list: list[str], tool_calls_list: list[dict[str, str]]
    ) -> list[dict[str, Any]]:
        """Build a simple content/reasoning/tool_call sequence."""
        sequence: list[dict[str, Any]] = []
        if text:
            sequence.append({"type": "content", "start": 0, "end": len(text)})
        for index in range(len(reasoning_list)):
            sequence.append({"type": "reasoning", "index": index})
        for index in range(len(tool_calls_list)):
            sequence.append({"type": "tool_call", "index": index})
        return sequence
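Concretely, for a generation whose text is 18 characters long, with one reasoning block and two tool calls, _build_generation_sequence produces the shape below (a worked example of the code above, with made-up input values):

# Worked example of _build_generation_sequence's output shape (inputs are illustrative).
text = "Paris is in France"  # len(text) == 18
reasoning_list = ["The user asks about geography."]
tool_calls_list = [
    {"id": "1", "name": "search", "arguments": "{}", "result": ""},
    {"id": "2", "name": "lookup", "arguments": "{}", "result": ""},
]

expected_sequence = [
    {"type": "content", "start": 0, "end": 18},
    {"type": "reasoning", "index": 0},
    {"type": "tool_call", "index": 0},
    {"type": "tool_call", "index": 1},
]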
    def _upsert_generation_detail(
        self,
        session,
        execution: WorkflowNodeExecution,
        reasoning_list: list[str],
        tool_calls_list: list[dict[str, str]],
        sequence: list[dict[str, Any]],
    ) -> None:
        """Insert or update LLMGenerationDetail with serialized fields."""
        existing = (
            session.query(LLMGenerationDetail)
            .filter_by(
                workflow_run_id=execution.workflow_execution_id,
                node_id=execution.node_id,
            )
            .first()
        )

        reasoning_json = json.dumps(reasoning_list) if reasoning_list else None
        tool_calls_json = json.dumps(tool_calls_list) if tool_calls_list else None
        sequence_json = json.dumps(sequence) if sequence else None

        if existing:
            existing.reasoning_content = reasoning_json
            existing.tool_calls = tool_calls_json
            existing.sequence = sequence_json
            return

        generation_detail = LLMGenerationDetail(
            tenant_id=self._tenant_id,
            app_id=self._app_id,
            workflow_run_id=execution.workflow_execution_id,
            node_id=execution.node_id,
            reasoning_content=reasoning_json,
            tool_calls=tool_calls_json,
            sequence=sequence_json,
        )
        session.add(generation_detail)

    def get_db_models_by_workflow_run(
        self,
        workflow_run_id: str,

api/core/sandbox/__init__.py (new file, 41 lines)
@@ -0,0 +1,41 @@
from .bash.dify_cli import (
    DifyCliBinary,
    DifyCliConfig,
    DifyCliEnvConfig,
    DifyCliLocator,
    DifyCliToolConfig,
)
from .bash.session import SandboxBashSession
from .builder import SandboxBuilder, VMConfig
from .entities import AppAssets, DifyCli, SandboxProviderApiEntity, SandboxType
from .initializer import AppAssetsInitializer, DifyCliInitializer, SandboxInitializer
from .manager import SandboxManager
from .sandbox import Sandbox
from .storage import ArchiveSandboxStorage, SandboxStorage
from .utils.debug import sandbox_debug
from .utils.encryption import create_sandbox_config_encrypter, masked_config

__all__ = [
    "AppAssets",
    "AppAssetsInitializer",
    "ArchiveSandboxStorage",
    "DifyCli",
    "DifyCliBinary",
    "DifyCliConfig",
    "DifyCliEnvConfig",
    "DifyCliInitializer",
    "DifyCliLocator",
    "DifyCliToolConfig",
    "Sandbox",
    "SandboxBashSession",
    "SandboxBuilder",
    "SandboxInitializer",
    "SandboxManager",
    "SandboxProviderApiEntity",
    "SandboxStorage",
    "SandboxType",
    "VMConfig",
    "create_sandbox_config_encrypter",
    "masked_config",
    "sandbox_debug",
]

api/core/sandbox/bash/TODO.md (new file, 1 line)
@@ -0,0 +1 @@
# refactor the package import paths

api/core/sandbox/bash/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from .dify_cli import (
    DifyCliBinary,
    DifyCliConfig,
    DifyCliEnvConfig,
    DifyCliLocator,
    DifyCliToolConfig,
)

__all__ = [
    "DifyCliBinary",
    "DifyCliConfig",
    "DifyCliEnvConfig",
    "DifyCliLocator",
    "DifyCliToolConfig",
]