mirror of
https://github.com/langgenius/dify.git
synced 2026-01-20 20:19:28 +08:00
Compare commits
368 Commits
fix/handle
...
dev/plugin
| Author | SHA1 | Date | |
|---|---|---|---|
| 086aeea181 | |||
| 1d7c4a87d0 | |||
| 9042b368e9 | |||
| f1bcd26c69 | |||
| 3dcd8b6330 | |||
| 10c088029c | |||
| 73b1adf862 | |||
| ae76dbd92c | |||
| 782df0c383 | |||
| 089207240e | |||
| 53d30d537f | |||
| 53512a4650 | |||
| 1fb7dcda24 | |||
| 3c3e0a35f4 | |||
| 202a246e83 | |||
| 08b968eca5 | |||
| b1ac71db3e | |||
| 55405c1a26 | |||
| 779770dae5 | |||
| 002b16e1c6 | |||
| 7710d8e83b | |||
| cf75fcdffc | |||
| 6e8601b52c | |||
| 96cf0ed5af | |||
| ddf9eb1f9a | |||
| 46a798bea8 | |||
| bb4fecf3d1 | |||
| 9e258c495d | |||
| 4fbe52da40 | |||
| 1e3197a1ea | |||
| 5f692dfce2 | |||
| 78a7d7fa21 | |||
| a9dda1554e | |||
| c53786d229 | |||
| 17f23f4798 | |||
| 67f2c766bc | |||
| 9a417bfc5e | |||
| 90bc51ed2e | |||
| 02dc835721 | |||
| a05e8f0e37 | |||
| b10cbb9b20 | |||
| 1aaab741a0 | |||
| bafa46393c | |||
| 45d43c41bc | |||
| e944646541 | |||
| 21e1443ed5 | |||
| 93a5ffb037 | |||
| d5711589cd | |||
| 375a359c97 | |||
| 3228bac56d | |||
| c66b4e32db | |||
| 57b60dd51f | |||
| ff911d0dc5 | |||
| 7a71498a3e | |||
| 76bcdc2581 | |||
| 91a218b29d | |||
| 4a6cbda1b4 | |||
| 8c08153e33 | |||
| b44b3866a1 | |||
| c242bb372b | |||
| 8c9e34133c | |||
| 3403ac361a | |||
| 07d6cb3f4a | |||
| 545aa61cf4 | |||
| 9fb78ce827 | |||
| 490b6d092e | |||
| 42b13bd312 | |||
| 28add22f20 | |||
| ce545274a6 | |||
| aa6c951e8c | |||
| c4f4dfc3fb | |||
| 548f6ef2b6 | |||
| b15ff4eb8c | |||
| 7790214620 | |||
| 3942e45cab | |||
| 2ace9ae4e4 | |||
| 5ac0ef6253 | |||
| f552667312 | |||
| 5669a18bd8 | |||
| a97d73ab05 | |||
| 252d2c425b | |||
| 09fc4bba61 | |||
| 5f995fac32 | |||
| 79d4db8541 | |||
| 9c42626772 | |||
| bbfe83c86b | |||
| 55aa4e424a | |||
| 8015f5c0c5 | |||
| f3fe14863d | |||
| d96c368660 | |||
| 3f34b8b0d1 | |||
| 6a58ea9e56 | |||
| 23888398d1 | |||
| bfbc5eb91e | |||
| 98b0d4169e | |||
| 356cd271b2 | |||
| baf7561cf8 | |||
| b09f22961c | |||
| f3ad3a5dfd | |||
| ee49d321c5 | |||
| f88f9d6970 | |||
| 3467ad3d02 | |||
| 6741604027 | |||
| 35312cf96c | |||
| 15f028fe59 | |||
| 8a2301af56 | |||
| 66747a8eef | |||
| 19d413ac1e | |||
| 4a332ff1af | |||
| dc942db52f | |||
| f535a2aa71 | |||
| dfdd6dfa20 | |||
| 2af81d1ee3 | |||
| ece25bce1a | |||
| 6fc234183a | |||
| 15a56f705f | |||
| 899f7e125f | |||
| aa19bb3f30 | |||
| 562852a0ae | |||
| a4b992c1ab | |||
| 3460c1dfbd | |||
| 653f6c2d46 | |||
| ed7851a4b3 | |||
| cb841e5cde | |||
| 4dae0e514e | |||
| 363c46ace8 | |||
| abe5aca3e2 | |||
| d2cc502c71 | |||
| bea10b4356 | |||
| f5f83f1924 | |||
| 403e2d58b9 | |||
| 222df44d21 | |||
| 566e548713 | |||
| 1434d54e7a | |||
| 4229d0f9a7 | |||
| 7f9eb35e1f | |||
| ed7d7a74ea | |||
| 035e54ba4d | |||
| 284707c3a8 | |||
| 1f63028a83 | |||
| 8a0aa91ed7 | |||
| 62079991b7 | |||
| 4e7e172ff3 | |||
| 33a565a719 | |||
| f0b9257387 | |||
| c398c9cb6a | |||
| a3d3e30e3a | |||
| 2b86465d4c | |||
| 6529240da6 | |||
| 0751ad1eeb | |||
| 786550bdc9 | |||
| bde756a1ab | |||
| 423fb2d7bc | |||
| f96b4f287a | |||
| c00e7d3f65 | |||
| 1f38d4846b | |||
| 47a64610ca | |||
| f0a845f0f9 | |||
| abec23118d | |||
| 0957119550 | |||
| b88194d1c6 | |||
| 2b95e54d54 | |||
| f48fa3e4e8 | |||
| 5ffc58d6ca | |||
| 7d958635f0 | |||
| 33990426c1 | |||
| 9f3fc7ebf8 | |||
| c8357da13b | |||
| 2290f14fb1 | |||
| 7796984444 | |||
| 75113c26c6 | |||
| 939a9ecd21 | |||
| f307c7cd88 | |||
| 33ecceb90c | |||
| e0d1cab079 | |||
| 811d72a727 | |||
| c3c575c2e1 | |||
| c189629eca | |||
| 37117c22d4 | |||
| b05e9d2ab4 | |||
| 0451333990 | |||
| ab2e6c19a4 | |||
| f7959bc887 | |||
| 45874c699d | |||
| 286cdc41ab | |||
| 78708eb5d5 | |||
| cf36745770 | |||
| 6622c7f98d | |||
| 3112b74527 | |||
| b3ae6b634f | |||
| 982bca5d40 | |||
| c8dcde6cd0 | |||
| 8f9db61688 | |||
| ebdbaf34e6 | |||
| a081b1e79e | |||
| 38c31e64db | |||
| ae6f67420c | |||
| ca19bd31d4 | |||
| 413dfd5628 | |||
| f9515901cc | |||
| 3f42fabff8 | |||
| 9bff9b5c9e | |||
| 1caa578771 | |||
| b7c11c1818 | |||
| 3eb3db0663 | |||
| be46f32056 | |||
| 6e5c915f96 | |||
| 3dd2c170e7 | |||
| 04d13a8116 | |||
| e638ede3f2 | |||
| 2348abe4bf | |||
| f7e7a399d9 | |||
| ba91f34636 | |||
| 16865d43a8 | |||
| 88f41f164f | |||
| 0d13aee15c | |||
| 49b4144ffd | |||
| 186e2d972e | |||
| 40dd63ecef | |||
| 6d66d6da15 | |||
| 03ec3513f3 | |||
| 87763fc234 | |||
| f6c44cae2e | |||
| da2ee04fce | |||
| 7673c36af3 | |||
| 9457b2af2f | |||
| 7203991032 | |||
| 5a685f7156 | |||
| a6a25030ad | |||
| 00458a31d5 | |||
| c6ddf6d6cc | |||
| 34b21b3065 | |||
| 8fbb355cd2 | |||
| e8b3b7e578 | |||
| 59ca44f493 | |||
| 9e1457c2c3 | |||
| cd932519b3 | |||
| fac83e14bc | |||
| a97cec57e4 | |||
| 38c10b47d3 | |||
| 1a2523fd15 | |||
| 03243cb422 | |||
| 2ad7ee0344 | |||
| 2ff2b08739 | |||
| 55ce3618ce | |||
| e9e34c1ab2 | |||
| d4c916b496 | |||
| 8fbc9c9342 | |||
| 1b6fd9dfe8 | |||
| 304467e3f5 | |||
| 7452032d81 | |||
| 87e2048f1b | |||
| d876084392 | |||
| 840729afa5 | |||
| 941ad03f3c | |||
| d73d191f99 | |||
| c2664e0283 | |||
| ee61cede4e | |||
| b47669b80b | |||
| c0d0c63592 | |||
| b09c39c8dc | |||
| b4b09ddc3c | |||
| d0a21086bd | |||
| d44882c1b5 | |||
| 23c68efa2d | |||
| 560c5de1b7 | |||
| 5d91dbd000 | |||
| 6c31ee36cd | |||
| edc29780ed | |||
| aad7e4dd1c | |||
| a6a727e8a4 | |||
| d1fc65fabc | |||
| d4be5ef9de | |||
| 1374be5a31 | |||
| b2bbc28580 | |||
| 59b3e672aa | |||
| a2f8bce8f5 | |||
| a2b9adb3a2 | |||
| 28067640b5 | |||
| da67916843 | |||
| a4a45421cc | |||
| aafab1b59e | |||
| 7f49f96c3f | |||
| 5673f03db5 | |||
| e54ce479ad | |||
| 278adbc10e | |||
| 5d4e517397 | |||
| c2671c16a8 | |||
| 6024d8a42d | |||
| 10991cbc03 | |||
| f565f08aa0 | |||
| 3fcf7e88b0 | |||
| ffa5af1356 | |||
| fd4afe09f8 | |||
| dd0904f95c | |||
| 4c3076f2a4 | |||
| 1e73f63ff8 | |||
| d167d5b1be | |||
| 71fa14f791 | |||
| 8dd1873e76 | |||
| f91f5c7401 | |||
| c62b7cc679 | |||
| 3ee213ddca | |||
| 8429877b02 | |||
| 05a0faff6a | |||
| e09f6e4987 | |||
| e23f4b0265 | |||
| f582d4a13e | |||
| 2f41bd495d | |||
| 162a8c4393 | |||
| 3d1ce4c53f | |||
| 6db3ae9b8e | |||
| 6d0cb9dc33 | |||
| 46e95e8309 | |||
| a7b9375877 | |||
| 0c6a8a130e | |||
| 9903f1e703 | |||
| 6fad719e42 | |||
| 9aaee8ee47 | |||
| 166221d784 | |||
| 925d69a2ee | |||
| 5ff08e241a | |||
| 3defd24087 | |||
| 9d86147d20 | |||
| 80801ac4ab | |||
| 210926cd91 | |||
| 677a69deed | |||
| 8dfdee21ce | |||
| 6ea77ab4cd | |||
| e3c996688d | |||
| 066516b54d | |||
| 49415e5e7f | |||
| bc3a570dda | |||
| 0800021a2d | |||
| 435eddd867 | |||
| 6e0fb055d1 | |||
| 1e9ac7ffeb | |||
| b4873ecb43 | |||
| 1859d57784 | |||
| 69d58fbb50 | |||
| cb34991663 | |||
| c700364e1c | |||
| 9a6b1dc3a1 | |||
| 54b5b80a07 | |||
| 831459b895 | |||
| 4e101604c3 | |||
| a6455269f0 | |||
| cd257b91c5 | |||
| d8f57bf899 | |||
| 989fb11fd7 | |||
| 140965b738 | |||
| 14ee51aead | |||
| 2e97ba5700 | |||
| f549d53b68 | |||
| a085ad4719 | |||
| f230a9232e | |||
| e84bf35e2a | |||
| 20f090537f | |||
| dbe7a7c4fd | |||
| b7a4e3903e | |||
| a697bbdfa7 | |||
| d5c31f8728 | |||
| 508005b741 | |||
| 4f0ecdbb6e | |||
| ab2e69faef | |||
| e46a3343b8 | |||
| 47637da734 | |||
| 525bde28f6 |
4
.github/workflows/api-tests.yml
vendored
4
.github/workflows/api-tests.yml
vendored
@ -4,7 +4,6 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- plugins/beta
|
||||
paths:
|
||||
- api/**
|
||||
- docker/**
|
||||
@ -27,6 +26,9 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Poetry and Python ${{ matrix.python-version }}
|
||||
uses: ./.github/actions/setup-poetry
|
||||
|
||||
27
.github/workflows/build-push.yml
vendored
27
.github/workflows/build-push.yml
vendored
@ -5,7 +5,6 @@ on:
|
||||
branches:
|
||||
- "main"
|
||||
- "deploy/dev"
|
||||
- "plugins/beta"
|
||||
- "dev/plugin-deploy"
|
||||
release:
|
||||
types: [published]
|
||||
@ -81,10 +80,12 @@ jobs:
|
||||
cache-to: type=gha,mode=max,scope=${{ matrix.service_name }}
|
||||
|
||||
- name: Export digest
|
||||
env:
|
||||
DIGEST: ${{ steps.build.outputs.digest }}
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
sanitized_digest=${DIGEST#sha256:}
|
||||
touch "/tmp/digests/${sanitized_digest}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v4
|
||||
@ -134,23 +135,15 @@ jobs:
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
env:
|
||||
IMAGE_NAME: ${{ env[matrix.image_name_env] }}
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf '${{ env[matrix.image_name_env] }}@sha256:%s ' *)
|
||||
$(printf "$IMAGE_NAME@sha256:%s " *)
|
||||
|
||||
- name: Inspect image
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ env[matrix.image_name_env] }}:${{ steps.meta.outputs.version }}
|
||||
|
||||
- name: print context var
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: deploy pod in plugin env
|
||||
if: github.ref == 'refs/heads/dev/plugin-deploy'
|
||||
env:
|
||||
IMAGEHASH: ${{ github.sha }}
|
||||
APICMD: "${{ secrets.PLUGIN_CD_API_CURL }}"
|
||||
WEBCMD: "${{ secrets.PLUGIN_CD_WEB_CURL }}"
|
||||
IMAGE_NAME: ${{ env[matrix.image_name_env] }}
|
||||
IMAGE_VERSION: ${{ steps.meta.outputs.version }}
|
||||
run: |
|
||||
bash -c "${APICMD/yourNewVersion/$IMAGEHASH}"
|
||||
bash -c "${WEBCMD/yourNewVersion/$IMAGEHASH}"
|
||||
docker buildx imagetools inspect "$IMAGE_NAME:$IMAGE_VERSION"
|
||||
|
||||
3
.github/workflows/db-migration-test.yml
vendored
3
.github/workflows/db-migration-test.yml
vendored
@ -20,6 +20,9 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Poetry and Python
|
||||
uses: ./.github/actions/setup-poetry
|
||||
|
||||
23
.github/workflows/deploy-plugin-dev.yml
vendored
23
.github/workflows/deploy-plugin-dev.yml
vendored
@ -1,23 +0,0 @@
|
||||
name: Deploy Plugin Dev
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "dev/plugin-deploy"
|
||||
types:
|
||||
- completed
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: "echo 123"
|
||||
2
.github/workflows/expose_service_ports.sh
vendored
2
.github/workflows/expose_service_ports.sh
vendored
@ -9,6 +9,6 @@ yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compos
|
||||
yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.yaml
|
||||
|
||||
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"
|
||||
|
||||
17
.github/workflows/style.yml
vendored
17
.github/workflows/style.yml
vendored
@ -4,7 +4,6 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- plugins/beta
|
||||
|
||||
concurrency:
|
||||
group: style-${{ github.head_ref || github.run_id }}
|
||||
@ -18,6 +17,9 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
@ -60,6 +62,9 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
@ -87,7 +92,7 @@ jobs:
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: yarn run lint
|
||||
run: pnpm run lint
|
||||
|
||||
docker-compose-template:
|
||||
name: Docker Compose Template
|
||||
@ -96,6 +101,9 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
@ -124,6 +132,9 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
@ -141,7 +152,7 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
env:
|
||||
BASH_SEVERITY: warning
|
||||
DEFAULT_BRANCH: plugins/beta
|
||||
DEFAULT_BRANCH: main
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
IGNORE_GENERATED_FILES: true
|
||||
IGNORE_GITIGNORED_FILES: true
|
||||
|
||||
5
.github/workflows/tool-test-sdks.yaml
vendored
5
.github/workflows/tool-test-sdks.yaml
vendored
@ -26,6 +26,9 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Use Node.js ${{ matrix.node-version }}
|
||||
uses: actions/setup-node@v4
|
||||
@ -35,7 +38,7 @@ jobs:
|
||||
cache-dependency-path: 'pnpm-lock.yaml'
|
||||
|
||||
- name: Install Dependencies
|
||||
run: pnpm install
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Test
|
||||
run: pnpm test
|
||||
|
||||
@ -16,6 +16,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 2 # last 2 commits
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check for file changes in i18n/en-US
|
||||
id: check_files
|
||||
|
||||
17
.github/workflows/vdb-tests.yml
vendored
17
.github/workflows/vdb-tests.yml
vendored
@ -28,6 +28,9 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Poetry and Python ${{ matrix.python-version }}
|
||||
uses: ./.github/actions/setup-poetry
|
||||
@ -51,7 +54,15 @@ jobs:
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
- name: Set up Vector Stores (TiDB, Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
|
||||
- name: Set up Vector Store (TiDB)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
with:
|
||||
compose-file: docker/tidb/docker-compose.yaml
|
||||
services: |
|
||||
tidb
|
||||
tiflash
|
||||
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
with:
|
||||
compose-file: |
|
||||
@ -67,7 +78,9 @@ jobs:
|
||||
pgvector
|
||||
chroma
|
||||
elasticsearch
|
||||
tidb
|
||||
|
||||
- name: Check TiDB Ready
|
||||
run: poetry run -P api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: poetry run -P api bash dev/pytest/pytest_vdb.sh
|
||||
|
||||
35
.github/workflows/web-tests.yml
vendored
35
.github/workflows/web-tests.yml
vendored
@ -22,25 +22,34 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v45
|
||||
with:
|
||||
files: web/**
|
||||
# pnpm needs the canvas package to be installed, but that installation always fails on amd64 under ubuntu-latest
|
||||
# - name: Install pnpm
|
||||
# uses: pnpm/action-setup@v4
|
||||
# with:
|
||||
# version: 10
|
||||
# run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 20
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/package.json
|
||||
# - name: Setup Node.js
|
||||
# uses: actions/setup-node@v4
|
||||
# if: steps.changed-files.outputs.any_changed == 'true'
|
||||
# with:
|
||||
# node-version: 20
|
||||
# cache: pnpm
|
||||
# cache-dependency-path: ./web/package.json
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: pnpm install --frozen-lockfile
|
||||
# - name: Install dependencies
|
||||
# if: steps.changed-files.outputs.any_changed == 'true'
|
||||
# run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run tests
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: pnpm test
|
||||
# - name: Run tests
|
||||
# if: steps.changed-files.outputs.any_changed == 'true'
|
||||
# run: pnpm test
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -163,6 +163,7 @@ docker/volumes/db/data/*
|
||||
docker/volumes/redis/data/*
|
||||
docker/volumes/weaviate/*
|
||||
docker/volumes/qdrant/*
|
||||
docker/tidb/volumes/*
|
||||
docker/volumes/etcd/*
|
||||
docker/volumes/minio/*
|
||||
docker/volumes/milvus/*
|
||||
|
||||
@ -73,7 +73,7 @@ Dify requires the following dependencies to build, make sure they're installed o
|
||||
* [Docker](https://www.docker.com/)
|
||||
* [Docker Compose](https://docs.docker.com/compose/install/)
|
||||
* [Node.js v18.x (LTS)](http://nodejs.org)
|
||||
* [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
|
||||
* [pnpm](https://pnpm.io/)
|
||||
* [Python](https://www.python.org/) version 3.11.x or 3.12.x
|
||||
|
||||
### 4. Installations
|
||||
|
||||
@ -70,7 +70,7 @@ Dify 依赖以下工具和库:
|
||||
- [Docker](https://www.docker.com/)
|
||||
- [Docker Compose](https://docs.docker.com/compose/install/)
|
||||
- [Node.js v18.x (LTS)](http://nodejs.org)
|
||||
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
|
||||
- [pnpm](https://pnpm.io/)
|
||||
- [Python](https://www.python.org/) version 3.11.x or 3.12.x
|
||||
|
||||
### 4. 安装
|
||||
|
||||
@ -73,7 +73,7 @@ Dify を構築するには次の依存関係が必要です。それらがシス
|
||||
- [Docker](https://www.docker.com/)
|
||||
- [Docker Compose](https://docs.docker.com/compose/install/)
|
||||
- [Node.js v18.x (LTS)](http://nodejs.org)
|
||||
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
|
||||
- [pnpm](https://pnpm.io/)
|
||||
- [Python](https://www.python.org/) version 3.11.x or 3.12.x
|
||||
|
||||
### 4. インストール
|
||||
|
||||
@ -72,7 +72,7 @@ Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đ
|
||||
- [Docker](https://www.docker.com/)
|
||||
- [Docker Compose](https://docs.docker.com/compose/install/)
|
||||
- [Node.js v18.x (LTS)](http://nodejs.org)
|
||||
- [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/)
|
||||
- [pnpm](https://pnpm.io/)
|
||||
- [Python](https://www.python.org/) phiên bản 3.11.x hoặc 3.12.x
|
||||
|
||||
### 4. Cài đặt
|
||||
|
||||
23
LICENSE
23
LICENSE
@ -1,12 +1,12 @@
|
||||
# Open Source License
|
||||
|
||||
Dify is licensed under the Apache License 2.0, with the following additional conditions:
|
||||
Dify is licensed under a modified version of the Apache License 2.0, with the following additional conditions:
|
||||
|
||||
1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:
|
||||
|
||||
a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
|
||||
a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
|
||||
- Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
|
||||
|
||||
|
||||
b. LOGO and copyright information: In the process of using Dify's frontend, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend.
|
||||
- Frontend Definition: For the purposes of this license, the "frontend" of Dify includes all components located in the `web/` directory when running Dify from the raw source code, or the "web" image when running Dify with Docker.
|
||||
|
||||
@ -21,19 +21,4 @@ Apart from the specific conditions mentioned above, all other rights and restric
|
||||
|
||||
The interactive design of this product is protected by appearance patent.
|
||||
|
||||
© 2024 LangGenius, Inc.
|
||||
|
||||
|
||||
----------
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
© 2025 LangGenius, Inc.
|
||||
|
||||
66
README.md
66
README.md
@ -108,6 +108,72 @@ Please refer to our [FAQ](https://docs.dify.ai/getting-started/install-self-host
|
||||
**7. Backend-as-a-Service**:
|
||||
All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic.
|
||||
|
||||
## Feature Comparison
|
||||
<table style="width: 100%;">
|
||||
<tr>
|
||||
<th align="center">Feature</th>
|
||||
<th align="center">Dify.AI</th>
|
||||
<th align="center">LangChain</th>
|
||||
<th align="center">Flowise</th>
|
||||
<th align="center">OpenAI Assistants API</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Programming Approach</td>
|
||||
<td align="center">API + App-oriented</td>
|
||||
<td align="center">Python Code</td>
|
||||
<td align="center">App-oriented</td>
|
||||
<td align="center">API-oriented</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Supported LLMs</td>
|
||||
<td align="center">Rich Variety</td>
|
||||
<td align="center">Rich Variety</td>
|
||||
<td align="center">Rich Variety</td>
|
||||
<td align="center">OpenAI-only</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">RAG Engine</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Agent</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Workflow</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Observability</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Enterprise Feature (SSO/Access control)</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Local Deployment</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Using Dify
|
||||
|
||||
|
||||
16
README_FR.md
16
README_FR.md
@ -55,7 +55,7 @@
|
||||
Dify est une plateforme de développement d'applications LLM open source. Son interface intuitive combine un flux de travail d'IA, un pipeline RAG, des capacités d'agent, une gestion de modèles, des fonctionnalités d'observabilité, et plus encore, vous permettant de passer rapidement du prototype à la production. Voici une liste des fonctionnalités principales:
|
||||
</br> </br>
|
||||
|
||||
**1. Flux de travail**:
|
||||
**1. Flux de travail** :
|
||||
Construisez et testez des flux de travail d'IA puissants sur un canevas visuel, en utilisant toutes les fonctionnalités suivantes et plus encore.
|
||||
|
||||
|
||||
@ -63,27 +63,25 @@ Dify est une plateforme de développement d'applications LLM open source. Son in
|
||||
|
||||
|
||||
|
||||
**2. Prise en charge complète des modèles**:
|
||||
**2. Prise en charge complète des modèles** :
|
||||
Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||

|
||||
|
||||
|
||||
**3. IDE de prompt**:
|
||||
**3. IDE de prompt** :
|
||||
Interface intuitive pour créer des prompts, comparer les performances des modèles et ajouter des fonctionnalités supplémentaires telles que la synthèse vocale à une application basée sur des chats.
|
||||
|
||||
**4. Pipeline RAG**:
|
||||
**4. Pipeline RAG** :
|
||||
Des capacités RAG étendues qui couvrent tout, de l'ingestion de documents à la récupération, avec un support prêt à l'emploi pour l'extraction de texte à partir de PDF, PPT et autres formats de document courants.
|
||||
|
||||
**5. Capac
|
||||
|
||||
ités d'agent**:
|
||||
**5. Capacités d'agent** :
|
||||
Vous pouvez définir des agents basés sur l'appel de fonction LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DALL·E, Stable Diffusion et WolframAlpha.
|
||||
|
||||
**6. LLMOps**:
|
||||
**6. LLMOps** :
|
||||
Surveillez et analysez les journaux d'application et les performances au fil du temps. Vous pouvez continuellement améliorer les prompts, les ensembles de données et les modèles en fonction des données de production et des annotations.
|
||||
|
||||
**7. Backend-as-a-Service**:
|
||||
**7. Backend-as-a-Service** :
|
||||
Toutes les offres de Dify sont accompagnées d'API correspondantes, vous permettant d'intégrer facilement Dify dans votre propre logique métier.
|
||||
|
||||
|
||||
|
||||
@ -164,7 +164,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
|
||||
|
||||
- **企業/組織向けのDify</br>**
|
||||
企業中心の機能を提供しています。[メールを送信](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)して企業のニーズについて相談してください。 </br>
|
||||
> AWSを使用しているスタートアップ企業や中小企業の場合は、[AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6)のDify Premiumをチェックして、ワンクリックで自分のAWS VPCにデプロイできます。さらに、手頃な価格のAMIオファリングどして、ロゴやブランディングをカスタマイズしてアプリケーションを作成するオプションがあります。
|
||||
> AWSを使用しているスタートアップ企業や中小企業の場合は、[AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t23mebxzwjhu6)のDify Premiumをチェックして、ワンクリックで自分のAWS VPCにデプロイできます。さらに、手頃な価格のAMIオファリングとして、ロゴやブランディングをカスタマイズしてアプリケーションを作成するオプションがあります。
|
||||
|
||||
|
||||
## 最新の情報を入手
|
||||
|
||||
@ -87,9 +87,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
|
||||
|
||||
## Feature Comparison
|
||||
<table style="width: 100%;">
|
||||
<tr
|
||||
|
||||
>
|
||||
<tr>
|
||||
<th align="center">Feature</th>
|
||||
<th align="center">Dify.AI</th>
|
||||
<th align="center">LangChain</th>
|
||||
|
||||
69
README_SI.md
69
README_SI.md
@ -106,6 +106,73 @@ Prosimo, glejte naša pogosta vprašanja [FAQ](https://docs.dify.ai/getting-star
|
||||
**7. Backend-as-a-Service**:
|
||||
Vse ponudbe Difyja so opremljene z ustreznimi API-ji, tako da lahko Dify brez težav integrirate v svojo poslovno logiko.
|
||||
|
||||
## Primerjava Funkcij
|
||||
|
||||
<table style="width: 100%;">
|
||||
<tr>
|
||||
<th align="center">Funkcija</th>
|
||||
<th align="center">Dify.AI</th>
|
||||
<th align="center">LangChain</th>
|
||||
<th align="center">Flowise</th>
|
||||
<th align="center">OpenAI Assistants API</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Programski pristop</td>
|
||||
<td align="center">API + usmerjeno v aplikacije</td>
|
||||
<td align="center">Python koda</td>
|
||||
<td align="center">Usmerjeno v aplikacije</td>
|
||||
<td align="center">Usmerjeno v API</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Podprti LLM-ji</td>
|
||||
<td align="center">Bogata izbira</td>
|
||||
<td align="center">Bogata izbira</td>
|
||||
<td align="center">Bogata izbira</td>
|
||||
<td align="center">Samo OpenAI</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">RAG pogon</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Agent</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Potek dela</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Spremljanje</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Funkcija za podjetja (SSO/nadzor dostopa)</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Lokalna namestitev</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Uporaba Dify
|
||||
|
||||
@ -187,4 +254,4 @@ Zaradi zaščite vaše zasebnosti se izogibajte objavljanju varnostnih vprašanj
|
||||
|
||||
## Licenca
|
||||
|
||||
To skladišče je na voljo pod [odprtokodno licenco Dify](LICENSE) , ki je v bistvu Apache 2.0 z nekaj dodatnimi omejitvami.
|
||||
To skladišče je na voljo pod [odprtokodno licenco Dify](LICENSE) , ki je v bistvu Apache 2.0 z nekaj dodatnimi omejitvami.
|
||||
|
||||
@ -55,9 +55,11 @@ RUN \
|
||||
# basic environment
|
||||
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
|
||||
# For Security
|
||||
# expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
|
||||
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
|
||||
# install a chinese font to support the use of tools like matplotlib
|
||||
fonts-noto-cjk \
|
||||
# install a package to improve the accuracy of guessing mime type and file extension
|
||||
media-types \
|
||||
# install libmagic to support the use of python-magic guess MIMETYPE
|
||||
libmagic1 \
|
||||
&& apt-get autoremove -y \
|
||||
|
||||
@ -37,7 +37,13 @@
|
||||
|
||||
4. Create environment.
|
||||
|
||||
Dify API service uses [Poetry](https://python-poetry.org/docs/) to manage dependencies. You can execute `poetry shell` to activate the environment.
|
||||
Dify API service uses [Poetry](https://python-poetry.org/docs/) to manage dependencies. `poetry shell` is no longer a built-in Poetry command, so to run in a virtual environment you first need to install the shell plugin if you don't have it already:
|
||||
|
||||
```bash
|
||||
poetry self add poetry-plugin-shell
|
||||
```
|
||||
|
||||
Then you can execute `poetry shell` to activate the environment.
|
||||
|
||||
5. Install dependencies
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ import logging
|
||||
import time
|
||||
|
||||
from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from dify_app import DifyApp
|
||||
|
||||
|
||||
@ -16,6 +17,12 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
dify_app = DifyApp(__name__)
|
||||
dify_app.config.from_mapping(dify_config.model_dump())
|
||||
|
||||
# add before request hook
|
||||
@dify_app.before_request
|
||||
def before_request():
|
||||
# bump the recycle counter at the start of each request so RecyclableContextVar values set during an earlier request are treated as stale
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
return dify_app
|
||||
|
||||
|
||||
|
||||
@ -707,12 +707,13 @@ def extract_unique_plugins(output_file: str, input_file: str):
|
||||
@click.option(
|
||||
"--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
|
||||
)
|
||||
def install_plugins(input_file: str, output_file: str):
|
||||
@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100)
|
||||
def install_plugins(input_file: str, output_file: str, workers: int):
|
||||
"""
|
||||
Install plugins.
|
||||
"""
|
||||
click.echo(click.style("Starting install plugins.", fg="white"))
|
||||
|
||||
PluginMigration.install_plugins(input_file, output_file)
|
||||
PluginMigration.install_plugins(input_file, output_file, workers)
|
||||
|
||||
click.echo(click.style("Install plugins completed.", fg="green"))
|
||||
|
||||
@ -373,8 +373,8 @@ class HttpConfig(BaseSettings):
|
||||
)
|
||||
|
||||
RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field(
|
||||
description="Enable or disable the X-Forwarded-For Proxy Fix middleware from Werkzeug"
|
||||
" to respect X-* headers to redirect clients",
|
||||
description="Enable handling of X-Forwarded-For, X-Forwarded-Proto, and X-Forwarded-Port headers"
|
||||
" when the app is behind a single trusted reverse proxy.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import Any, Literal, Optional
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
@ -166,6 +167,11 @@ class DatabaseConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field(
|
||||
description="Number of processes for the retrieval service, default to CPU cores.",
|
||||
default=os.cpu_count(),
|
||||
)
|
||||
|
||||
@computed_field
|
||||
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
||||
@ -15,7 +15,7 @@ AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||
|
||||
if dify_config.ETL_TYPE == "Unstructured":
|
||||
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls"]
|
||||
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
||||
DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
DOCUMENT_EXTENSIONS.append("ppt")
|
||||
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||
|
||||
@ -2,6 +2,8 @@ from contextvars import ContextVar
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
@ -12,8 +14,17 @@ tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
||||
|
||||
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
||||
|
||||
plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
|
||||
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
|
||||
"""
|
||||
To avoid race conditions caused by gunicorn thread recycling, the plugin context variables below are wrapped in RecyclableContextVar instead of being plain ContextVar instances.
|
||||
"""
|
||||
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_tool_providers")
|
||||
)
|
||||
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))
|
||||
|
||||
plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers")
|
||||
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")
|
||||
plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
|
||||
ContextVar("plugin_model_providers")
|
||||
)
|
||||
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("plugin_model_providers_lock")
|
||||
)
|
||||
|
||||
65
api/contexts/wrapper.py
Normal file
65
api/contexts/wrapper.py
Normal file
@ -0,0 +1,65 @@
|
||||
from contextvars import ContextVar
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class HiddenValue:
|
||||
pass
|
||||
|
||||
|
||||
_default = HiddenValue()
|
||||
|
||||
|
||||
class RecyclableContextVar(Generic[T]):
|
||||
"""
|
||||
RecyclableContextVar is a wrapper around ContextVar
|
||||
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now
|
||||
|
||||
NOTE: you need to call `increment_thread_recycles` before requests
|
||||
"""
|
||||
|
||||
_thread_recycles: ContextVar[int] = ContextVar("thread_recycles")
|
||||
|
||||
@classmethod
|
||||
def increment_thread_recycles(cls):
|
||||
try:
|
||||
recycles = cls._thread_recycles.get()
|
||||
cls._thread_recycles.set(recycles + 1)
|
||||
except LookupError:
|
||||
cls._thread_recycles.set(0)
|
||||
|
||||
def __init__(self, context_var: ContextVar[T]):
|
||||
self._context_var = context_var
|
||||
self._updates = ContextVar[int](context_var.name + "_updates", default=0)
|
||||
|
||||
def get(self, default: T | HiddenValue = _default) -> T:
|
||||
thread_recycles = self._thread_recycles.get(0)
|
||||
self_updates = self._updates.get()
|
||||
if thread_recycles > self_updates:
|
||||
self._updates.set(thread_recycles)
|
||||
|
||||
# check if thread is recycled and should be updated
|
||||
if thread_recycles < self_updates:
|
||||
return self._context_var.get()
|
||||
else:
|
||||
# thread_recycles >= self_updates, means current context is invalid
|
||||
if isinstance(default, HiddenValue) or default is _default:
|
||||
raise LookupError
|
||||
else:
|
||||
return default
|
||||
|
||||
def set(self, value: T):
|
||||
# it leads to a situation that self.updates is less than cls.thread_recycles if `set` was never called before
|
||||
# increase it manually
|
||||
thread_recycles = self._thread_recycles.get(0)
|
||||
self_updates = self._updates.get()
|
||||
if thread_recycles > self_updates:
|
||||
self._updates.set(thread_recycles)
|
||||
|
||||
if self._updates.get() == self._thread_recycles.get(0):
|
||||
# after increment,
|
||||
self._updates.set(self._updates.get() + 1)
|
||||
|
||||
# set the context
|
||||
self._context_var.set(value)
|
||||
@ -623,7 +623,6 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
match vector_type:
|
||||
case (
|
||||
VectorType.RELYT
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.TENCENT
|
||||
|
||||
@ -617,7 +617,7 @@ class DocumentDetailApi(DocumentResource):
|
||||
raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
|
||||
|
||||
if metadata == "only":
|
||||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
|
||||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
|
||||
elif metadata == "without":
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
@ -678,7 +678,7 @@ class DocumentDetailApi(DocumentResource):
|
||||
"disabled_by": document.disabled_by,
|
||||
"archived": document.archived,
|
||||
"doc_type": document.doc_type,
|
||||
"doc_metadata": document.doc_metadata,
|
||||
"doc_metadata": document.doc_metadata_details,
|
||||
"segment_count": document.segment_count,
|
||||
"average_segment_length": document.average_segment_length,
|
||||
"hit_count": document.hit_count,
|
||||
|
||||
143
api/controllers/console/datasets/metadata.py
Normal file
143
api/controllers/console/datasets/metadata.py
Normal file
@ -0,0 +1,143 @@
|
||||
from flask_login import current_user # type: ignore # type: ignore
|
||||
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from libs.login import login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
MetadataArgs,
|
||||
MetadataOperationData,
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class DatasetListApi(Resource):
    """Console resource whose POST asks ``MetadataService`` to create metadata for a dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    @marshal_with(dataset_metadata_fields)
    def post(self, dataset_id):
        """
        Create metadata on the dataset from the JSON fields ``type`` and ``name``.

        :param dataset_id: dataset identifier taken from the URL
        :return: the created metadata and HTTP status 201
        :raises NotFound: when the dataset does not exist
        """
        request_parser = reqparse.RequestParser()
        for field_name in ("type", "name"):
            request_parser.add_argument(field_name, type=str, required=True, nullable=True, location="json")
        payload = MetadataArgs(**request_parser.parse_args())

        dataset_key = str(dataset_id)
        target = DatasetService.get_dataset(dataset_key)
        if target is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(target, current_user)

        return MetadataService.create_metadata(dataset_key, payload), 201
|
||||
|
||||
|
||||
class DatasetMetadataApi(Resource):
    """Console resource for renaming (PATCH) and deleting (DELETE) one metadata entry of a dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def patch(self, dataset_id, metadata_id):
        """
        Rename a metadata entry using the JSON field ``name``.

        :param dataset_id: dataset identifier taken from the URL
        :param metadata_id: metadata identifier taken from the URL
        :return: the result of ``MetadataService.update_metadata_name`` and HTTP status 200
        :raises NotFound: when the dataset does not exist
        """
        request_parser = reqparse.RequestParser()
        request_parser.add_argument("name", type=str, required=True, nullable=True, location="json")
        payload = request_parser.parse_args()

        dataset_key, metadata_key = str(dataset_id), str(metadata_id)
        target = DatasetService.get_dataset(dataset_key)
        if target is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(target, current_user)

        renamed = MetadataService.update_metadata_name(dataset_key, metadata_key, payload.get("name"))
        return renamed, 200

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def delete(self, dataset_id, metadata_id):
        """
        Delete a metadata entry.

        :param dataset_id: dataset identifier taken from the URL
        :param metadata_id: metadata identifier taken from the URL
        :return: the bare integer 200
        :raises NotFound: when the dataset does not exist
        """
        dataset_key, metadata_key = str(dataset_id), str(metadata_id)
        target = DatasetService.get_dataset(dataset_key)
        if target is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(target, current_user)

        MetadataService.delete_metadata(dataset_key, metadata_key)
        return 200
|
||||
|
||||
|
||||
class DatasetMetadataBuiltInFieldApi(Resource):
    """Console resource that returns the built-in metadata fields."""

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self):
        """
        Return whatever ``MetadataService.get_built_in_fields`` reports.

        :return: the built-in fields and HTTP status 200
        """
        return MetadataService.get_built_in_fields(), 200
|
||||
|
||||
|
||||
class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, dataset_id, action):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
if action == "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
elif action == "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
return 200
|
||||
|
||||
|
||||
class DocumentMetadataApi(Resource):
    """Console resource that applies a batch of metadata operations to a dataset's documents."""

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def post(self, dataset_id):
        """
        Update document metadata from the JSON list ``operation_data``.

        :param dataset_id: dataset identifier taken from the URL
        :return: the bare integer 200
        :raises NotFound: when the dataset does not exist
        """
        target = DatasetService.get_dataset(str(dataset_id))
        if target is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(target, current_user)

        # The request body is parsed only after the permission check has passed.
        request_parser = reqparse.RequestParser()
        request_parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json")
        operations = MetadataOperationData(**request_parser.parse_args())

        MetadataService.update_documents_metadata(target, operations)
        return 200
|
||||
|
||||
|
||||
# URL rules for the dataset metadata console endpoints.
api.add_resource(DatasetListApi, "/datasets/<uuid:dataset_id>/metadata")
api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
# DatasetMetadataBuiltInFieldActionApi.post takes (dataset_id, action), so the
# rule has to capture both; without ``dataset_id`` every request fails with a
# missing-argument TypeError.
api.add_resource(
    DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>"
)
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
@ -1,3 +1,5 @@
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -71,7 +73,8 @@ class FilePreviewApi(Resource):
|
||||
if upload_file.size > 0:
|
||||
response.headers["Content-Length"] = str(upload_file.size)
|
||||
if args["as_attachment"]:
|
||||
response.headers["Content-Disposition"] = f"attachment; filename={upload_file.name}"
|
||||
encoded_filename = quote(upload_file.name)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@ -50,8 +50,8 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
|
||||
"plan": tenant.plan,
|
||||
"status": tenant.status,
|
||||
"custom_config": json.loads(tenant.custom_config) if tenant.custom_config else {},
|
||||
"created_at": tenant.created_at.isoformat() if tenant.created_at else None,
|
||||
"updated_at": tenant.updated_at.isoformat() if tenant.updated_at else None,
|
||||
"created_at": tenant.created_at.isoformat() + "Z" if tenant.created_at else None,
|
||||
"updated_at": tenant.updated_at.isoformat() + "Z" if tenant.updated_at else None,
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@ -10,6 +10,7 @@ from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from fields.message_fields import feedback_fields, retriever_resource_fields
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
@ -18,26 +19,6 @@ from services.message_service import MessageService
|
||||
|
||||
|
||||
class MessageListApi(Resource):
|
||||
feedback_fields = {"rating": fields.String}
|
||||
retriever_resource_fields = {
|
||||
"id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"dataset_id": fields.String,
|
||||
"dataset_name": fields.String,
|
||||
"document_id": fields.String,
|
||||
"document_name": fields.String,
|
||||
"data_source_type": fields.String,
|
||||
"segment_id": fields.String,
|
||||
"score": fields.Float,
|
||||
"hit_count": fields.Integer,
|
||||
"word_count": fields.Integer,
|
||||
"segment_position": fields.Integer,
|
||||
"index_node_hash": fields.String,
|
||||
"content": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
agent_thought_fields = {
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
@ -89,7 +70,7 @@ class MessageListApi(Resource):
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@ -336,6 +336,10 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset is not exist.")
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if "file" in request.files:
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
|
||||
@ -154,7 +154,7 @@ def validate_dataset_token(view=None):
|
||||
) # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = Account.query.filter_by(id=ta.account_id).first()
|
||||
account = db.session.query(Account).filter(Account.id == ta.account_id).first()
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
|
||||
@ -21,7 +21,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from fields.message_fields import agent_thought_fields
|
||||
from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields
|
||||
from fields.raws import FilesContainedField
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
@ -34,27 +34,6 @@ from services.message_service import MessageService
|
||||
|
||||
|
||||
class MessageListApi(WebApiResource):
|
||||
feedback_fields = {"rating": fields.String}
|
||||
|
||||
retriever_resource_fields = {
|
||||
"id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"dataset_id": fields.String,
|
||||
"dataset_name": fields.String,
|
||||
"document_id": fields.String,
|
||||
"document_name": fields.String,
|
||||
"data_source_type": fields.String,
|
||||
"segment_id": fields.String,
|
||||
"score": fields.Float,
|
||||
"hit_count": fields.Integer,
|
||||
"word_count": fields.Integer,
|
||||
"segment_position": fields.Integer,
|
||||
"index_node_hash": fields.String,
|
||||
"content": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
|
||||
@ -329,6 +329,7 @@ class BaseAgentRunner(AppRunner):
|
||||
)
|
||||
if not updated_agent_thought:
|
||||
raise ValueError("agent thought not found")
|
||||
agent_thought = updated_agent_thought
|
||||
|
||||
if thought:
|
||||
agent_thought.thought = thought
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
|
||||
@ -14,7 +14,7 @@ class AgentToolEntity(BaseModel):
|
||||
provider_type: ToolProviderType
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
tool_parameters: dict[str, Any] = {}
|
||||
tool_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
plugin_unique_identifier: str | None = None
|
||||
|
||||
|
||||
|
||||
@ -2,9 +2,9 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
|
||||
@ -61,9 +61,7 @@ class ModelConfigManager:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
if "/" not in config["model"]["provider"]:
|
||||
config["model"]["provider"] = (
|
||||
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
|
||||
)
|
||||
config["model"]["provider"] = str(ModelProviderID(config["model"]["provider"]))
|
||||
|
||||
if config["model"]["provider"] not in model_provider_names:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
@ -17,8 +17,8 @@ class ModelConfigEntity(BaseModel):
|
||||
provider: str
|
||||
model: str
|
||||
mode: Optional[str] = None
|
||||
parameters: dict[str, Any] = {}
|
||||
stop: list[str] = []
|
||||
parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
stop: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AdvancedChatMessageEntity(BaseModel):
|
||||
@ -132,7 +132,7 @@ class ExternalDataVariableEntity(BaseModel):
|
||||
|
||||
variable: str
|
||||
type: str
|
||||
config: dict[str, Any] = {}
|
||||
config: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class DatasetRetrieveConfigEntity(BaseModel):
|
||||
@ -188,7 +188,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
|
||||
"""
|
||||
|
||||
type: str
|
||||
config: dict[str, Any] = {}
|
||||
config: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TextToSpeechEntity(BaseModel):
|
||||
|
||||
@ -140,9 +140,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
app_config=app_config,
|
||||
file_upload_config=file_extra_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=conversation.inputs
|
||||
if conversation
|
||||
else self._prepare_user_inputs(
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
|
||||
@ -384,6 +384,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
if node_finish_resp:
|
||||
yield node_finish_resp
|
||||
|
||||
@ -149,9 +149,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
file_upload_config=file_extra_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=conversation.inputs
|
||||
if conversation
|
||||
else self._prepare_user_inputs(
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
|
||||
@ -141,9 +141,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
file_upload_config=file_extra_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=conversation.inputs
|
||||
if conversation
|
||||
else self._prepare_user_inputs(
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
|
||||
@ -42,7 +42,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
],
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
|
||||
@ -387,7 +387,6 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
metadata=event.metadata,
|
||||
node_id=event.node_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
|
||||
@ -63,9 +63,9 @@ class ModelConfigWithCredentialsEntity(BaseModel):
|
||||
model_schema: AIModelEntity
|
||||
mode: str
|
||||
provider_model_bundle: ProviderModelBundle
|
||||
credentials: dict[str, Any] = {}
|
||||
parameters: dict[str, Any] = {}
|
||||
stop: list[str] = []
|
||||
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||
parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
stop: list[str] = Field(default_factory=list)
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
@ -94,7 +94,7 @@ class AppGenerateEntity(BaseModel):
|
||||
call_depth: int = 0
|
||||
|
||||
# extra parameters, like: auto_generate_conversation_name
|
||||
extras: dict[str, Any] = {}
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# tracing instance
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
|
||||
@ -331,7 +331,6 @@ class QueueAgentLogEvent(AppQueueEvent):
|
||||
status: str
|
||||
data: Mapping[str, Any]
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
node_id: str
|
||||
|
||||
|
||||
class QueueNodeRetryEvent(QueueNodeStartedEvent):
|
||||
|
||||
@ -719,7 +719,6 @@ class AgentLogStreamResponse(StreamResponse):
|
||||
status: str
|
||||
data: Mapping[str, Any]
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
node_id: str
|
||||
|
||||
event: StreamEvent = StreamEvent.AGENT_LOG
|
||||
data: Data
|
||||
|
||||
@ -844,7 +844,7 @@ class WorkflowCycleManage:
|
||||
if node_execution_id not in self._workflow_node_executions:
|
||||
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
||||
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
|
||||
return cached_workflow_node_execution
|
||||
return session.merge(cached_workflow_node_execution)
|
||||
|
||||
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
||||
"""
|
||||
@ -864,6 +864,5 @@ class WorkflowCycleManage:
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
metadata=event.metadata,
|
||||
node_id=event.node_id,
|
||||
),
|
||||
)
|
||||
|
||||
@ -6,10 +6,10 @@ from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import or_
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||
from core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
@ -28,6 +28,7 @@ from core.model_runtime.entities.provider_entities import (
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from extensions.ext_database import db
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
@ -190,8 +191,11 @@ class ProviderConfiguration(BaseModel):
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
or_(
|
||||
Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@ -279,7 +283,10 @@ class ProviderConfiguration(BaseModel):
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
or_(
|
||||
Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
),
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
)
|
||||
.first()
|
||||
@ -996,7 +1003,7 @@ class ProviderConfigurations(BaseModel):
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
configurations: dict[str, ProviderConfiguration] = {}
|
||||
configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)
|
||||
|
||||
def __init__(self, tenant_id: str):
|
||||
super().__init__(tenant_id=tenant_id)
|
||||
@ -1052,7 +1059,7 @@ class ProviderConfigurations(BaseModel):
|
||||
|
||||
def __getitem__(self, key):
|
||||
if "/" not in key:
|
||||
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
||||
key = str(ModelProviderID(key))
|
||||
|
||||
return self.configurations[key]
|
||||
|
||||
@ -1067,7 +1074,7 @@ class ProviderConfigurations(BaseModel):
|
||||
|
||||
def get(self, key, default=None) -> ProviderConfiguration | None:
|
||||
if "/" not in key:
|
||||
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
||||
key = str(ModelProviderID(key))
|
||||
|
||||
return self.configurations.get(key, default) # type: ignore
|
||||
|
||||
|
||||
@ -65,8 +65,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
retries += 1
|
||||
if retries <= max_retries:
|
||||
time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
|
||||
raise MaxRetriesExceededError(
|
||||
f"Reached maximum retries ({max_retries}) for URL {url}")
|
||||
raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
|
||||
|
||||
|
||||
def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
|
||||
@ -41,9 +41,13 @@ class HostedModerationConfig(BaseModel):
|
||||
|
||||
|
||||
class HostingConfiguration:
|
||||
provider_map: dict[str, HostingProvider] = {}
|
||||
provider_map: dict[str, HostingProvider]
|
||||
moderation_config: Optional[HostedModerationConfig] = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.provider_map = {}
|
||||
self.moderation_config = None
|
||||
|
||||
def init_app(self, app: Flask) -> None:
|
||||
if dify_config.EDITION != "CLOUD":
|
||||
return
|
||||
|
||||
@ -228,7 +228,7 @@ class LargeLanguageModel(AIModel):
|
||||
:return: result generator
|
||||
"""
|
||||
callbacks = callbacks or []
|
||||
prompt_message = AssistantPromptMessage(content="")
|
||||
assistant_message = AssistantPromptMessage(content="")
|
||||
usage = None
|
||||
system_fingerprint = None
|
||||
real_model = model
|
||||
@ -250,7 +250,7 @@ class LargeLanguageModel(AIModel):
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
prompt_message.content += chunk.delta.message.content
|
||||
assistant_message.content += chunk.delta.message.content
|
||||
real_model = chunk.model
|
||||
if chunk.delta.usage:
|
||||
usage = chunk.delta.usage
|
||||
@ -265,7 +265,7 @@ class LargeLanguageModel(AIModel):
|
||||
result=LLMResult(
|
||||
model=real_model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=prompt_message,
|
||||
message=assistant_message,
|
||||
usage=usage or LLMUsage.empty_usage(),
|
||||
system_fingerprint=system_fingerprint,
|
||||
),
|
||||
|
||||
@ -7,7 +7,6 @@ from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
import contexts
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
|
||||
@ -34,9 +33,11 @@ class ModelProviderExtension(BaseModel):
|
||||
|
||||
|
||||
class ModelProviderFactory:
|
||||
provider_position_map: dict[str, int] = {}
|
||||
provider_position_map: dict[str, int]
|
||||
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
self.provider_position_map = {}
|
||||
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_model_manager = PluginModelManager()
|
||||
|
||||
@ -360,11 +361,5 @@ class ModelProviderFactory:
|
||||
:param provider: provider name
|
||||
:return: plugin id and provider name
|
||||
"""
|
||||
plugin_id = DEFAULT_PLUGIN_ID
|
||||
provider_name = provider
|
||||
if "/" in provider:
|
||||
# get the plugin_id before provider
|
||||
plugin_id = "/".join(provider.split("/")[:-1])
|
||||
provider_name = provider.split("/")[-1]
|
||||
|
||||
return str(plugin_id), provider_name
|
||||
provider_id = ModelProviderID(provider)
|
||||
return provider_id.plugin_id, provider_id.provider_name
|
||||
|
||||
@ -0,0 +1,22 @@
|
||||
- claude-3-haiku@20240307
|
||||
- claude-3-opus@20240229
|
||||
- claude-3-sonnet@20240229
|
||||
- claude-3-5-sonnet-v2@20241022
|
||||
- claude-3-5-sonnet@20240620
|
||||
- gemini-1.0-pro-vision-001
|
||||
- gemini-1.0-pro-002
|
||||
- gemini-1.5-flash-001
|
||||
- gemini-1.5-flash-002
|
||||
- gemini-1.5-pro-001
|
||||
- gemini-1.5-pro-002
|
||||
- gemini-2.0-flash-001
|
||||
- gemini-2.0-flash-exp
|
||||
- gemini-2.0-flash-lite-preview-02-05
|
||||
- gemini-2.0-flash-thinking-exp-01-21
|
||||
- gemini-2.0-flash-thinking-exp-1219
|
||||
- gemini-2.0-pro-exp-02-05
|
||||
- gemini-exp-1114
|
||||
- gemini-exp-1121
|
||||
- gemini-exp-1206
|
||||
- gemini-flash-experimental
|
||||
- gemini-pro-experimental
|
||||
@ -159,7 +159,7 @@ class GenericProviderID:
|
||||
if re.match(r"^[a-z0-9_-]+$", value):
|
||||
value = f"langgenius/{value}/{value}"
|
||||
else:
|
||||
raise ValueError("Invalid plugin id")
|
||||
raise ValueError(f"Invalid plugin id {value}")
|
||||
|
||||
self.organization, self.plugin_name, self.provider_name = value.split("/")
|
||||
self.is_hardcoded = is_hardcoded
|
||||
@ -180,7 +180,7 @@ class ToolProviderID(GenericProviderID):
|
||||
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
|
||||
super().__init__(value, is_hardcoded)
|
||||
if self.organization == "langgenius":
|
||||
if self.provider_name in ["jina", "siliconflow"]:
|
||||
if self.provider_name in ["jina", "siliconflow", "stepfun"]:
|
||||
self.plugin_name = f"{self.provider_name}_tool"
|
||||
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||
from core.plugin.manager.base import BasePluginManager
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
@ -45,7 +45,7 @@ class PluginToolManager(BasePluginManager):
|
||||
"""
|
||||
Fetch tool provider for the given tenant and plugin.
|
||||
"""
|
||||
tool_provider_id = GenericProviderID(provider)
|
||||
tool_provider_id = ToolProviderID(provider)
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
data = json_response.get("data")
|
||||
|
||||
@ -100,6 +100,15 @@ class ProviderManager:
|
||||
tenant_id, provider_name_to_provider_records_dict
|
||||
)
|
||||
|
||||
# also index each provider's records under its fully-qualified ID (e.g. langgenius/openai/openai)
|
||||
provider_name_list = list(provider_name_to_provider_records_dict.keys())
|
||||
for provider_name in provider_name_list:
|
||||
provider_id = ModelProviderID(provider_name)
|
||||
if str(provider_id) not in provider_name_list:
|
||||
provider_name_to_provider_records_dict[str(provider_id)] = provider_name_to_provider_records_dict[
|
||||
provider_name
|
||||
]
|
||||
|
||||
# Get all provider model records of the workspace
|
||||
provider_name_to_provider_model_records_dict = self._get_all_provider_models(tenant_id)
|
||||
|
||||
@ -360,7 +369,8 @@ class ProviderManager:
|
||||
|
||||
provider_name_to_provider_records_dict = defaultdict(list)
|
||||
for provider in providers:
|
||||
provider_name_to_provider_records_dict[provider.provider_name].append(provider)
|
||||
# TODO: Use provider name with prefix after the data migration
|
||||
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
|
||||
|
||||
return provider_name_to_provider_records_dict
|
||||
|
||||
@ -454,11 +464,9 @@ class ProviderManager:
|
||||
|
||||
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
|
||||
for provider_load_balancing_config in provider_load_balancing_configs:
|
||||
(
|
||||
provider_name_to_provider_load_balancing_model_configs_dict[
|
||||
provider_load_balancing_config.provider_name
|
||||
].append(provider_load_balancing_config)
|
||||
)
|
||||
provider_name_to_provider_load_balancing_model_configs_dict[
|
||||
provider_load_balancing_config.provider_name
|
||||
].append(provider_load_balancing_config)
|
||||
|
||||
return provider_name_to_provider_load_balancing_model_configs_dict
|
||||
|
||||
@ -501,7 +509,8 @@ class ProviderManager:
|
||||
# FIXME: ignore the type error; only TrialHostingQuota has a limit, so this logic needs to change
|
||||
provider_record = Provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_name,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
provider_name=ModelProviderID(provider_name).provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=ProviderQuotaType.TRIAL.value,
|
||||
quota_limit=quota.quota_limit, # type: ignore
|
||||
@ -516,13 +525,12 @@ class ProviderManager:
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == provider_name,
|
||||
Provider.provider_name == ModelProviderID(provider_name).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL.value,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider_record and not provider_record.is_valid:
|
||||
provider_record.is_valid = True
|
||||
db.session.commit()
|
||||
|
||||
@ -88,16 +88,17 @@ class Jieba(BaseKeyword):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
|
||||
k = kwargs.get("top_k", 4)
|
||||
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
|
||||
|
||||
documents = []
|
||||
for chunk_index in sorted_chunk_indices:
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index)
|
||||
.first()
|
||||
segment_query = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
|
||||
)
|
||||
if document_ids_filter:
|
||||
segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
|
||||
segment = segment_query.first()
|
||||
|
||||
if segment:
|
||||
documents.append(
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
import threading
|
||||
import concurrent.futures
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy.orm import load_only
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
@ -26,6 +30,7 @@ default_retrieval_model = {
|
||||
|
||||
|
||||
class RetrievalService:
|
||||
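# Cache precompiled regular expressions to avoid repeated compilation (NOTE(review): no compiled pattern appears after this comment in the visible hunk; confirm or remove)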
|
||||
@classmethod
|
||||
def retrieve(
|
||||
cls,
|
||||
@ -37,77 +42,68 @@ class RetrievalService:
|
||||
reranking_model: Optional[dict] = None,
|
||||
reranking_mode: str = "reranking_model",
|
||||
weights: Optional[dict] = None,
|
||||
document_ids_filter: Optional[list[str]] = None,
|
||||
):
|
||||
if not query:
|
||||
return []
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
return []
|
||||
|
||||
dataset = cls._get_dataset(dataset_id)
|
||||
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||
return []
|
||||
|
||||
all_documents: list[Document] = []
|
||||
threads: list[threading.Thread] = []
|
||||
exceptions: list[str] = []
|
||||
# retrieval_model source with keyword
|
||||
if retrieval_method == "keyword_search":
|
||||
keyword_thread = threading.Thread(
|
||||
target=RetrievalService.keyword_search,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"all_documents": all_documents,
|
||||
"exceptions": exceptions,
|
||||
},
|
||||
)
|
||||
threads.append(keyword_thread)
|
||||
keyword_thread.start()
|
||||
# retrieval_model source with semantic
|
||||
if RetrievalMethod.is_support_semantic_search(retrieval_method):
|
||||
embedding_thread = threading.Thread(
|
||||
target=RetrievalService.embedding_search,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
"reranking_model": reranking_model,
|
||||
"all_documents": all_documents,
|
||||
"retrieval_method": retrieval_method,
|
||||
"exceptions": exceptions,
|
||||
},
|
||||
)
|
||||
threads.append(embedding_thread)
|
||||
embedding_thread.start()
|
||||
|
||||
# retrieval source with full text
|
||||
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
|
||||
full_text_index_thread = threading.Thread(
|
||||
target=RetrievalService.full_text_index_search,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"retrieval_method": retrieval_method,
|
||||
"score_threshold": score_threshold,
|
||||
"top_k": top_k,
|
||||
"reranking_model": reranking_model,
|
||||
"all_documents": all_documents,
|
||||
"exceptions": exceptions,
|
||||
},
|
||||
)
|
||||
threads.append(full_text_index_thread)
|
||||
full_text_index_thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
# Optimize multithreading with thread pools
|
||||
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
|
||||
futures = []
|
||||
if retrieval_method == "keyword_search":
|
||||
futures.append(
|
||||
executor.submit(
|
||||
cls.keyword_search,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
all_documents=all_documents,
|
||||
exceptions=exceptions,
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
)
|
||||
if RetrievalMethod.is_support_semantic_search(retrieval_method):
|
||||
futures.append(
|
||||
executor.submit(
|
||||
cls.embedding_search,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
all_documents=all_documents,
|
||||
retrieval_method=retrieval_method,
|
||||
exceptions=exceptions,
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
)
|
||||
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
|
||||
futures.append(
|
||||
executor.submit(
|
||||
cls.full_text_index_search,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
all_documents=all_documents,
|
||||
retrieval_method=retrieval_method,
|
||||
exceptions=exceptions,
|
||||
)
|
||||
)
|
||||
concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
|
||||
|
||||
if exceptions:
|
||||
exception_message = ";\n".join(exceptions)
|
||||
raise ValueError(exception_message)
|
||||
raise ValueError(";\n".join(exceptions))
|
||||
|
||||
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
|
||||
data_post_processor = DataPostProcessor(
|
||||
@ -132,19 +128,32 @@ class RetrievalService:
|
||||
)
|
||||
return all_documents
|
||||
|
||||
@classmethod
|
||||
def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
|
||||
return db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
|
||||
@classmethod
|
||||
def keyword_search(
|
||||
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
|
||||
cls,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
all_documents: list,
|
||||
exceptions: list,
|
||||
document_ids_filter: Optional[list[str]] = None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
dataset = cls._get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("dataset not found")
|
||||
|
||||
keyword = Keyword(dataset=dataset)
|
||||
|
||||
documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
|
||||
documents = keyword.search(
|
||||
cls.escape_query_for_search(query), top_k=top_k, document_ids_filter=document_ids_filter
|
||||
)
|
||||
all_documents.extend(documents)
|
||||
except Exception as e:
|
||||
exceptions.append(str(e))
|
||||
@ -161,21 +170,22 @@ class RetrievalService:
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
document_ids_filter: Optional[list[str]] = None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
dataset = cls._get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("dataset not found")
|
||||
|
||||
vector = Vector(dataset=dataset)
|
||||
|
||||
documents = vector.search_by_vector(
|
||||
cls.escape_query_for_search(query),
|
||||
query,
|
||||
search_type="similarity_score_threshold",
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
filter={"group_id": [dataset.id]},
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
|
||||
if documents:
|
||||
@ -186,7 +196,7 @@ class RetrievalService:
|
||||
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
|
||||
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
|
||||
)
|
||||
all_documents.extend(
|
||||
data_post_processor.invoke(
|
||||
@ -216,13 +226,11 @@ class RetrievalService:
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
dataset = cls._get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("dataset not found")
|
||||
|
||||
vector_processor = Vector(
|
||||
dataset=dataset,
|
||||
)
|
||||
vector_processor = Vector(dataset=dataset)
|
||||
|
||||
documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)
|
||||
if documents:
|
||||
@ -233,7 +241,7 @@ class RetrievalService:
|
||||
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
|
||||
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
|
||||
)
|
||||
all_documents.extend(
|
||||
data_post_processor.invoke(
|
||||
@ -250,66 +258,106 @@ class RetrievalService:
|
||||
|
||||
@staticmethod
|
||||
def escape_query_for_search(query: str) -> str:
|
||||
return query.replace('"', '\\"')
|
||||
return json.dumps(query).strip('"')
|
||||
|
||||
@classmethod
|
||||
def format_retrieval_documents(cls, documents: list[Document]) -> list[RetrievalSegments]:
|
||||
"""Format retrieval documents with optimized batch processing"""
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Collect document IDs
|
||||
document_ids = {doc.metadata.get("document_id") for doc in documents if "document_id" in doc.metadata}
|
||||
if not document_ids:
|
||||
return []
|
||||
|
||||
# Batch query dataset documents
|
||||
dataset_documents = {
|
||||
doc.id: doc
|
||||
for doc in db.session.query(DatasetDocument)
|
||||
.filter(DatasetDocument.id.in_(document_ids))
|
||||
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
|
||||
.all()
|
||||
}
|
||||
|
||||
records = []
|
||||
include_segment_ids = set()
|
||||
segment_child_map = {}
|
||||
|
||||
# Process documents
|
||||
for document in documents:
|
||||
document_id = document.metadata.get("document_id")
|
||||
if document_id not in dataset_documents:
|
||||
continue
|
||||
|
||||
dataset_document = dataset_documents[document_id]
|
||||
|
||||
@staticmethod
|
||||
def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]:
|
||||
records = []
|
||||
include_segment_ids = []
|
||||
segment_child_map = {}
|
||||
for document in documents:
|
||||
document_id = document.metadata.get("document_id")
|
||||
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
|
||||
if dataset_document:
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
# Handle parent-child documents
|
||||
child_index_node_id = document.metadata.get("doc_id")
|
||||
result = (
|
||||
db.session.query(ChildChunk, DocumentSegment)
|
||||
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
|
||||
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
|
||||
)
|
||||
|
||||
if not child_chunk:
|
||||
continue
|
||||
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
ChildChunk.index_node_id == child_index_node_id,
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id == child_chunk.segment_id,
|
||||
)
|
||||
.options(
|
||||
load_only(
|
||||
DocumentSegment.id,
|
||||
DocumentSegment.content,
|
||||
DocumentSegment.answer,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if result:
|
||||
child_chunk, segment = result
|
||||
if not segment:
|
||||
continue
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.append(segment.id)
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
map_detail = {
|
||||
"max_score": document.metadata.get("score", 0.0),
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
segment_child_map[segment.id] = map_detail
|
||||
record = {
|
||||
"segment": segment,
|
||||
}
|
||||
records.append(record)
|
||||
else:
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
|
||||
segment_child_map[segment.id]["max_score"] = max(
|
||||
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
|
||||
)
|
||||
else:
|
||||
|
||||
if not segment:
|
||||
continue
|
||||
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
map_detail = {
|
||||
"max_score": document.metadata.get("score", 0.0),
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
segment_child_map[segment.id] = map_detail
|
||||
record = {
|
||||
"segment": segment,
|
||||
}
|
||||
records.append(record)
|
||||
else:
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
|
||||
segment_child_map[segment.id]["max_score"] = max(
|
||||
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
|
||||
)
|
||||
else:
|
||||
index_node_id = document.metadata["doc_id"]
|
||||
# Handle normal documents
|
||||
index_node_id = document.metadata.get("doc_id")
|
||||
if not index_node_id:
|
||||
continue
|
||||
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
@ -324,16 +372,21 @@ class RetrievalService:
|
||||
|
||||
if not segment:
|
||||
continue
|
||||
include_segment_ids.append(segment.id)
|
||||
|
||||
include_segment_ids.add(segment.id)
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": document.metadata.get("score", None),
|
||||
"score": document.metadata.get("score"), # type: ignore
|
||||
}
|
||||
|
||||
records.append(record)
|
||||
|
||||
# Add child chunks information to records
|
||||
for record in records:
|
||||
if record["segment"].id in segment_child_map:
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
||||
|
||||
return [RetrievalSegments(**record) for record in records]
|
||||
return [RetrievalSegments(**record) for record in records]
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
raise e
|
||||
|
||||
@ -53,7 +53,7 @@ class AnalyticdbVector(BaseVector):
|
||||
self.analyticdb_vector.delete_by_metadata_field(key, value)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
return self.analyticdb_vector.search_by_vector(query_vector)
|
||||
return self.analyticdb_vector.search_by_vector(query_vector, **kwargs)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self.analyticdb_vector.search_by_full_text(query, **kwargs)
|
||||
|
||||
@ -194,6 +194,11 @@ class AnalyticdbVectorBySql:
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = "WHERE 1=1"
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
with self._get_cursor() as cur:
|
||||
query_vector_str = json.dumps(query_vector)
|
||||
@ -202,7 +207,7 @@ class AnalyticdbVectorBySql:
|
||||
f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
|
||||
f"t.page_content as page_content, t.metadata_ AS metadata_ "
|
||||
f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
|
||||
f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
|
||||
f"FROM {self.table_name} {where_clause} ORDER BY score LIMIT {top_k} ) t",
|
||||
(query_vector_str,),
|
||||
)
|
||||
documents = []
|
||||
@ -220,12 +225,17 @@ class AnalyticdbVectorBySql:
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"""SELECT id, vector, page_content, metadata_,
|
||||
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
|
||||
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
|
||||
ORDER BY score DESC
|
||||
LIMIT {top_k}""",
|
||||
(f"'{query}'", f"'{query}'"),
|
||||
|
||||
@ -123,11 +123,21 @@ class BaiduVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
|
||||
filter=f"document_id IN ({document_ids})",
|
||||
)
|
||||
else:
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
|
||||
)
|
||||
res = self._db.table(self._collection_name).search(
|
||||
anns=anns,
|
||||
projections=[self.field_id, self.field_text, self.field_metadata],
|
||||
|
||||
@ -95,7 +95,15 @@ class ChromaVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
results: QueryResult = collection.query(
|
||||
query_embeddings=query_vector,
|
||||
n_results=kwargs.get("top_k", 4),
|
||||
where={"document_id": {"$in": document_ids_filter}},
|
||||
)
|
||||
else:
|
||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
|
||||
# Check if results contain data
|
||||
@ -111,8 +119,9 @@ class ChromaVector(BaseVector):
|
||||
for index in range(len(ids)):
|
||||
distance = distances[index]
|
||||
metadata = dict(metadatas[index])
|
||||
if distance >= score_threshold:
|
||||
metadata["score"] = distance
|
||||
score = 1 - distance
|
||||
if score > score_threshold:
|
||||
metadata["score"] = score
|
||||
doc = Document(
|
||||
page_content=documents[index],
|
||||
metadata=metadata,
|
||||
|
||||
@ -117,6 +117,9 @@ class ElasticSearchVector(BaseVector):
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
num_candidates = math.ceil(top_k * 1.5)
|
||||
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
|
||||
|
||||
results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
|
||||
|
||||
@ -145,6 +148,9 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
query_str = {"match": {Field.CONTENT_KEY.value: query}}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
|
||||
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
|
||||
docs = []
|
||||
for hit in results["hits"]["hits"]:
|
||||
|
||||
@ -168,7 +168,12 @@ class LindormVectorStore(BaseVector):
|
||||
raise ValueError("All elements in query_vector should be floats")
|
||||
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filters = []
|
||||
if document_ids_filter:
|
||||
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
|
||||
query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)
|
||||
|
||||
try:
|
||||
params = {}
|
||||
if self._using_ugc:
|
||||
@ -206,7 +211,10 @@ class LindormVectorStore(BaseVector):
|
||||
should = kwargs.get("should")
|
||||
minimum_should_match = kwargs.get("minimum_should_match", 0)
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
filters = kwargs.get("filter")
|
||||
filters = kwargs.get("filter", [])
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
|
||||
routing = self._routing
|
||||
full_text_query = default_text_search_query(
|
||||
query_text=query,
|
||||
|
||||
@ -218,12 +218,18 @@ class MilvusVector(BaseVector):
|
||||
"""
|
||||
Search for documents by vector similarity.
|
||||
"""
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
filter = f'metadata["document_id"] in ({document_ids})'
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
data=[query_vector],
|
||||
anns_field=Field.VECTOR.value,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
return self._process_search_results(
|
||||
@ -239,6 +245,11 @@ class MilvusVector(BaseVector):
|
||||
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
|
||||
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
|
||||
return []
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
filter = f'metadata["document_id"] in ({document_ids})'
|
||||
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
@ -246,6 +257,7 @@ class MilvusVector(BaseVector):
|
||||
anns_field=Field.SPARSE_VECTOR.value,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
return self._process_search_results(
|
||||
|
||||
@ -131,6 +131,10 @@ class MyScaleVector(BaseVector):
|
||||
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
|
||||
else ""
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})"
|
||||
sql = f"""
|
||||
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
|
||||
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
|
||||
|
||||
@ -154,6 +154,11 @@ class OceanBaseVector(BaseVector):
|
||||
return []
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = None
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f"metadata->>'$.document_id' in ({document_ids})"
|
||||
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
|
||||
if ef_search != self._hnsw_ef_search:
|
||||
self._client.set_ob_hnsw_ef_search(ef_search)
|
||||
@ -167,6 +172,7 @@ class OceanBaseVector(BaseVector):
|
||||
distance_func=func.l2_distance,
|
||||
output_column_names=["text", "metadata"],
|
||||
with_dist=True,
|
||||
where_clause=where_clause,
|
||||
)
|
||||
docs = []
|
||||
for text, metadata, distance in cur:
|
||||
|
||||
@ -154,6 +154,9 @@ class OpenSearchVector(BaseVector):
|
||||
"size": kwargs.get("top_k", 4),
|
||||
"query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
|
||||
}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
query["query"] = {"terms": {"metadata.document_id": document_ids_filter}}
|
||||
|
||||
try:
|
||||
response = self._client.search(index=self._collection_name.lower(), body=query)
|
||||
@ -179,6 +182,9 @@ class OpenSearchVector(BaseVector):
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
full_text_query["query"]["terms"] = {"metadata.document_id": document_ids_filter}
|
||||
|
||||
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
|
||||
|
||||
|
||||
@ -185,10 +185,15 @@ class OracleVector(BaseVector):
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
|
||||
f" ORDER BY distance fetch first {top_k} rows only",
|
||||
f" {where_clause} ORDER BY distance fetch first {top_k} rows only",
|
||||
[numpy.array(query_vector)],
|
||||
)
|
||||
docs = []
|
||||
@ -241,9 +246,15 @@ class OracleVector(BaseVector):
|
||||
if token not in stop_words:
|
||||
entities.append(token)
|
||||
with self._get_cursor() as cur:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
||||
cur.execute(
|
||||
f"select meta, text, embedding FROM {self.table_name}"
|
||||
f" WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
|
||||
f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
|
||||
f"order by score(1) desc fetch first {top_k} rows only",
|
||||
[" ACCUM ".join(entities)],
|
||||
)
|
||||
docs = []
|
||||
|
||||
@ -189,6 +189,9 @@ class PGVectoRS(BaseVector):
|
||||
.limit(kwargs.get("top_k", 4))
|
||||
.order_by("distance")
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
stmt = stmt.where(self._table.meta["document_id"].in_(document_ids_filter))
|
||||
res = session.execute(stmt)
|
||||
results = [(row[0], row[1]) for row in res]
|
||||
|
||||
|
||||
@ -155,10 +155,16 @@ class PGVector(BaseVector):
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" WHERE metadata->>'document_id' in ({document_ids}) "
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
|
||||
f" {where_clause}"
|
||||
f" ORDER BY distance LIMIT {top_k}",
|
||||
(json.dumps(query_vector),),
|
||||
)
|
||||
@ -176,10 +182,16 @@ class PGVector(BaseVector):
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
||||
cur.execute(
|
||||
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
|
||||
{where_clause}
|
||||
ORDER BY score DESC
|
||||
LIMIT {top_k}""",
|
||||
# f"'{query}'" is required in order to account for whitespace in query
|
||||
|
||||
@ -286,27 +286,26 @@ class QdrantVector(BaseVector):
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
for node_id in ids:
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(value=node_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchAny(any=ids),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
all_collection_name = []
|
||||
@ -331,6 +330,14 @@ class QdrantVector(BaseVector):
|
||||
),
|
||||
],
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
filter.must.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchAny(any=document_ids_filter),
|
||||
)
|
||||
)
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
query_vector=query_vector,
|
||||
@ -377,6 +384,14 @@ class QdrantVector(BaseVector):
|
||||
),
|
||||
]
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
scroll_filter.must.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchAny(any=document_ids_filter),
|
||||
)
|
||||
)
|
||||
response = self._client.scroll(
|
||||
collection_name=self._collection_name,
|
||||
scroll_filter=scroll_filter,
|
||||
|
||||
@ -223,8 +223,12 @@ class RelytVector(BaseVector):
|
||||
return len(result) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = kwargs.get("filter", {})
|
||||
if document_ids_filter:
|
||||
filter["document_id"] = document_ids_filter
|
||||
results = self.similarity_search_with_score_by_vector(
|
||||
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter")
|
||||
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=filter
|
||||
)
|
||||
|
||||
# Organize results.
|
||||
@ -246,9 +250,9 @@ class RelytVector(BaseVector):
|
||||
filter_condition = ""
|
||||
if filter is not None:
|
||||
conditions = [
|
||||
f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
|
||||
f"metadata->>'{key!r}' in ({', '.join(map(repr, value))})"
|
||||
if len(value) > 1
|
||||
else f"metadata->>{key!r} = {value[0]!r}"
|
||||
else f"metadata->>'{key!r}' = {value[0]!r}"
|
||||
for key, value in filter.items()
|
||||
]
|
||||
filter_condition = f"WHERE {' AND '.join(conditions)}"
|
||||
|
||||
@ -145,11 +145,16 @@ class TencentVector(BaseVector):
|
||||
self._db.collection(self._collection_name).delete(document_ids=ids)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
|
||||
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(f"metadata.{key}", [value])))
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = None
|
||||
if document_ids_filter:
|
||||
filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
|
||||
res = self._db.collection(self._collection_name).search(
|
||||
vectors=[query_vector],
|
||||
filter=filter,
|
||||
params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
|
||||
retrieve_vector=False,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
|
||||
@ -326,6 +326,14 @@ class TidbOnQdrantVector(BaseVector):
|
||||
),
|
||||
],
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
filter.must.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchAny(any=document_ids_filter),
|
||||
)
|
||||
)
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
query_vector=query_vector,
|
||||
@ -368,6 +376,14 @@ class TidbOnQdrantVector(BaseVector):
|
||||
)
|
||||
]
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
scroll_filter.must.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchAny(any=document_ids_filter),
|
||||
)
|
||||
)
|
||||
response = self._client.scroll(
|
||||
collection_name=self._collection_name,
|
||||
scroll_filter=scroll_filter,
|
||||
|
||||
@ -9,6 +9,7 @@ from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.orm import Session, declarative_base
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -54,14 +55,13 @@ class TiDBVector(BaseVector):
|
||||
return Table(
|
||||
self._collection_name,
|
||||
self._orm_base.metadata,
|
||||
Column("id", String(36), primary_key=True, nullable=False),
|
||||
Column(Field.PRIMARY_KEY.value, String(36), primary_key=True, nullable=False),
|
||||
Column(
|
||||
"vector",
|
||||
Field.VECTOR.value,
|
||||
VectorType(dim),
|
||||
nullable=False,
|
||||
comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})",
|
||||
),
|
||||
Column("text", TEXT, nullable=False),
|
||||
Column(Field.TEXT_KEY.value, TEXT, nullable=False),
|
||||
Column("meta", JSON, nullable=False),
|
||||
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
|
||||
Column(
|
||||
@ -96,6 +96,7 @@ class TiDBVector(BaseVector):
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
tidb_dist_func = self._get_distance_func()
|
||||
with Session(self._engine) as session:
|
||||
session.begin()
|
||||
create_statement = sql_text(f"""
|
||||
@ -104,14 +105,14 @@ class TiDBVector(BaseVector):
|
||||
text TEXT NOT NULL,
|
||||
meta JSON NOT NULL,
|
||||
doc_id VARCHAR(64) AS (JSON_UNQUOTE(JSON_EXTRACT(meta, '$.doc_id'))) STORED,
|
||||
KEY (doc_id),
|
||||
vector VECTOR<FLOAT>({dimension}) NOT NULL COMMENT "hnsw(distance={self._distance_func})",
|
||||
vector VECTOR<FLOAT>({dimension}) NOT NULL,
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
KEY (doc_id),
|
||||
VECTOR INDEX idx_vector (({tidb_dist_func}(vector))) USING HNSW
|
||||
);
|
||||
""")
|
||||
session.execute(create_statement)
|
||||
# tidb vector not support 'CREATE/ADD INDEX' now
|
||||
session.commit()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
@ -194,23 +195,36 @@ class TiDBVector(BaseVector):
|
||||
)
|
||||
|
||||
docs = []
|
||||
if self._distance_func == "l2":
|
||||
tidb_func = "Vec_l2_distance"
|
||||
elif self._distance_func == "cosine":
|
||||
tidb_func = "Vec_Cosine_distance"
|
||||
else:
|
||||
tidb_func = "Vec_Cosine_distance"
|
||||
tidb_dist_func = self._get_distance_func()
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" WHERE meta->>'$.document_id' in ({document_ids}) "
|
||||
|
||||
with Session(self._engine) as session:
|
||||
select_statement = sql_text(
|
||||
f"""SELECT meta, text, distance FROM (
|
||||
SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
|
||||
FROM {self._collection_name}
|
||||
ORDER BY distance
|
||||
LIMIT {top_k}
|
||||
) t WHERE distance < {distance};"""
|
||||
select_statement = sql_text(f"""
|
||||
SELECT meta, text, distance
|
||||
FROM (
|
||||
SELECT
|
||||
meta,
|
||||
text,
|
||||
{tidb_dist_func}(vector, :query_vector_str) AS distance
|
||||
FROM {self._collection_name}
|
||||
{where_clause}
|
||||
ORDER BY distance ASC
|
||||
LIMIT :top_k
|
||||
) t
|
||||
WHERE distance <= :distance
|
||||
""")
|
||||
res = session.execute(
|
||||
select_statement,
|
||||
params={
|
||||
"query_vector_str": query_vector_str,
|
||||
"distance": distance,
|
||||
"top_k": top_k,
|
||||
},
|
||||
)
|
||||
res = session.execute(select_statement)
|
||||
results = [(row[0], row[1], row[2]) for row in res]
|
||||
for meta, text, distance in results:
|
||||
metadata = json.loads(meta)
|
||||
@ -227,6 +241,16 @@ class TiDBVector(BaseVector):
|
||||
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
|
||||
session.commit()
|
||||
|
||||
def _get_distance_func(self) -> str:
    """Return the TiDB SQL distance-function name for the configured metric.

    Reads ``self._distance_func``: the value ``"l2"`` selects
    ``VEC_L2_DISTANCE``; ``"cosine"`` and every other value (including
    ``None``) select ``VEC_COSINE_DISTANCE``.

    :return: the SQL function name to interpolate into TiDB statements.
    """
    # Only L2 needs its own branch; cosine is both an explicit option and
    # the fallback, so a single default return covers the remaining cases.
    if self._distance_func == "l2":
        return "VEC_L2_DISTANCE"
    return "VEC_COSINE_DISTANCE"
|
||||
|
||||
|
||||
class TiDBVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:
|
||||
|
||||
@ -88,7 +88,20 @@ class UpstashVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
filter = f"document_id in ({document_ids})"
|
||||
else:
|
||||
filter = ""
|
||||
result = self.index.query(
|
||||
vector=query_vector,
|
||||
top_k=top_k,
|
||||
include_metadata=True,
|
||||
include_data=True,
|
||||
include_vectors=False,
|
||||
filter=filter,
|
||||
)
|
||||
docs = []
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
for record in result:
|
||||
|
||||
@ -49,6 +49,10 @@ class BaseVector(ABC):
|
||||
def delete(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
def update_metadata(self, document_id: str, metadata: dict) -> None:
    """Abstract hook for updating the metadata tied to ``document_id``.

    The base implementation does nothing except raise; the exact update
    semantics are left to the concrete vector-store subclasses.

    :param document_id: identifier of the document whose metadata changes.
    :param metadata: metadata mapping supplied by the caller.
    :raises NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
|
||||
|
||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||
for text in texts.copy():
|
||||
if text.metadata and "doc_id" in text.metadata:
|
||||
|
||||
@ -177,7 +177,11 @@ class VikingDBVector(BaseVector):
|
||||
query_vector, limit=kwargs.get("top_k", 4)
|
||||
)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._get_search_res(results, score_threshold)
|
||||
docs = self._get_search_res(results, score_threshold)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
docs = [doc for doc in docs if doc.metadata.get("document_id") in document_ids_filter]
|
||||
return docs
|
||||
|
||||
def _get_search_res(self, results, score_threshold) -> list[Document]:
|
||||
if len(results) == 0:
|
||||
|
||||
@ -168,16 +168,16 @@ class WeaviateVector(BaseVector):
|
||||
# check whether the index already exists
|
||||
schema = self._default_schema(self._collection_name)
|
||||
if self._client.schema.contains(schema):
|
||||
for uuid in ids:
|
||||
try:
|
||||
self._client.data_object.delete(
|
||||
class_name=self._collection_name,
|
||||
uuid=uuid,
|
||||
)
|
||||
except weaviate.UnexpectedStatusCodeException as e:
|
||||
# tolerate not found error
|
||||
if e.status_code != 404:
|
||||
raise e
|
||||
try:
|
||||
self._client.batch.delete_objects(
|
||||
class_name=self._collection_name,
|
||||
where={"operator": "ContainsAny", "path": ["id"], "valueTextArray": ids},
|
||||
output="minimal",
|
||||
)
|
||||
except weaviate.UnexpectedStatusCodeException as e:
|
||||
# tolerate not found error
|
||||
if e.status_code != 404:
|
||||
raise e
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""Look up similar documents by embedding vector in Weaviate."""
|
||||
@ -187,8 +187,10 @@ class WeaviateVector(BaseVector):
|
||||
query_obj = self._client.query.get(collection_name, properties)
|
||||
|
||||
vector = {"vector": query_vector}
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
|
||||
query_obj = query_obj.with_where(where_filter)
|
||||
result = (
|
||||
query_obj.with_near_vector(vector)
|
||||
.with_limit(kwargs.get("top_k", 4))
|
||||
@ -233,8 +235,10 @@ class WeaviateVector(BaseVector):
|
||||
if kwargs.get("search_distance"):
|
||||
content["certainty"] = kwargs.get("search_distance")
|
||||
query_obj = self._client.query.get(collection_name, properties)
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
|
||||
query_obj = query_obj.with_where(where_filter)
|
||||
query_obj = query_obj.with_additional(["vector"])
|
||||
properties = ["text"]
|
||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do()
|
||||
|
||||
9
api/core/rag/index_processor/constant/built_in_field.py
Normal file
9
api/core/rag/index_processor/constant/built_in_field.py
Normal file
@ -0,0 +1,9 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class BuiltInField(str, Enum):
    """Closed set of built-in field names.

    Every member's value is identical to its name, and because the enum
    mixes in ``str`` each member compares equal to the plain string value
    (``BuiltInField.source == "source"``).
    """

    document_name = "document_name"
    uploader = "uploader"
    upload_date = "upload_date"
    last_update_date = "last_update_date"
    source = "source"
|
||||
@ -237,6 +237,7 @@ class DatasetRetrieval:
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
planning_strategy: PlanningStrategy,
|
||||
message_id: Optional[str] = None,
|
||||
metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
|
||||
):
|
||||
tools = []
|
||||
for dataset in available_datasets:
|
||||
@ -291,6 +292,11 @@ class DatasetRetrieval:
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
else:
|
||||
document_ids_filter = None
|
||||
if metadata_filter_document_ids:
|
||||
document_ids = metadata_filter_document_ids.get(dataset.id, [])
|
||||
if document_ids:
|
||||
document_ids_filter = document_ids
|
||||
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
|
||||
|
||||
# get top k
|
||||
@ -322,6 +328,7 @@ class DatasetRetrieval:
|
||||
reranking_model=reranking_model,
|
||||
reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
|
||||
weights=retrieval_model_config.get("weights", None),
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
||||
|
||||
|
||||
@ -105,10 +105,10 @@ class ApiTool(Tool):
|
||||
needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required]
|
||||
for parameter in needed_parameters:
|
||||
if parameter.required and parameter.name not in parameters:
|
||||
raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
|
||||
|
||||
if parameter.default is not None and parameter.name not in parameters:
|
||||
parameters[parameter.name] = parameter.default
|
||||
if parameter.default is not None:
|
||||
parameters[parameter.name] = parameter.default
|
||||
else:
|
||||
raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
@ -246,10 +246,11 @@ class ToolEngine:
|
||||
+ "you do not need to create it, just tell the user to check it now."
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
text = json.dumps(cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False)
|
||||
result += f"tool response: {text}."
|
||||
result = json.dumps(
|
||||
cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
|
||||
)
|
||||
else:
|
||||
result += f"tool response: {response.message!r}."
|
||||
result += str(response.message)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Union, cast
|
||||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.manager.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
@ -160,8 +160,8 @@ class ToolManager:
|
||||
"""
|
||||
get the tool runtime
|
||||
|
||||
:param provider_type: the type of the provider
|
||||
:param provider_name: the name of the provider
|
||||
:param provider_type: the type of the provider
|
||||
:param provider_name: the name of the provider
|
||||
:param tool_name: the name of the tool
|
||||
|
||||
:return: the tool
|
||||
@ -188,7 +188,7 @@ class ToolManager:
|
||||
)
|
||||
|
||||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
provider_id_entity = GenericProviderID(provider_id)
|
||||
provider_id_entity = ToolProviderID(provider_id)
|
||||
# get credentials
|
||||
builtin_provider: BuiltinToolProvider | None = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
@ -572,95 +572,96 @@ class ToolManager:
|
||||
else:
|
||||
filters.append(typ)
|
||||
|
||||
if "builtin" in filters:
|
||||
# get builtin providers
|
||||
builtin_providers = cls.list_builtin_providers(tenant_id)
|
||||
with db.session.no_autoflush:
|
||||
if "builtin" in filters:
|
||||
# get builtin providers
|
||||
builtin_providers = cls.list_builtin_providers(tenant_id)
|
||||
|
||||
# get db builtin providers
|
||||
db_builtin_providers: list[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
# rewrite db_builtin_providers
|
||||
for db_provider in db_builtin_providers:
|
||||
tool_provider_id = GenericProviderID(db_provider.provider)
|
||||
db_provider.provider = tool_provider_id.to_string()
|
||||
|
||||
def find_db_builtin_provider(provider):
|
||||
return next((x for x in db_builtin_providers if x.provider == provider), None)
|
||||
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
|
||||
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
|
||||
data=provider,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.entity.identity.name),
|
||||
decrypt_credentials=False,
|
||||
# get db builtin providers
|
||||
db_builtin_providers: list[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
|
||||
else:
|
||||
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
|
||||
# rewrite db_builtin_providers
|
||||
for db_provider in db_builtin_providers:
|
||||
tool_provider_id = str(ToolProviderID(db_provider.provider))
|
||||
db_provider.provider = tool_provider_id
|
||||
|
||||
# get db api providers
|
||||
def find_db_builtin_provider(provider):
|
||||
return next((x for x in db_builtin_providers if x.provider == provider), None)
|
||||
|
||||
if "api" in filters:
|
||||
db_api_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
|
||||
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
|
||||
data=provider,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
|
||||
api_provider_controllers: list[dict[str, Any]] = [
|
||||
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
|
||||
for provider in db_api_providers
|
||||
]
|
||||
|
||||
# get labels
|
||||
labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
|
||||
|
||||
for api_provider_controller in api_provider_controllers:
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller=api_provider_controller["controller"],
|
||||
db_provider=api_provider_controller["provider"],
|
||||
decrypt_credentials=False,
|
||||
labels=labels.get(api_provider_controller["controller"].provider_id, []),
|
||||
)
|
||||
result_providers[f"api_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
if "workflow" in filters:
|
||||
# get workflow providers
|
||||
workflow_providers: list[WorkflowToolProvider] = (
|
||||
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||
for provider in workflow_providers:
|
||||
try:
|
||||
workflow_provider_controllers.append(
|
||||
ToolTransformService.workflow_provider_to_controller(db_provider=provider)
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.entity.identity.name),
|
||||
decrypt_credentials=False,
|
||||
)
|
||||
except Exception:
|
||||
# app has been deleted
|
||||
pass
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels(
|
||||
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
|
||||
)
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
|
||||
else:
|
||||
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
for provider_controller in workflow_provider_controllers:
|
||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
labels=labels.get(provider_controller.provider_id, []),
|
||||
# get db api providers
|
||||
|
||||
if "api" in filters:
|
||||
db_api_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
api_provider_controllers: list[dict[str, Any]] = [
|
||||
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
|
||||
for provider in db_api_providers
|
||||
]
|
||||
|
||||
# get labels
|
||||
labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
|
||||
|
||||
for api_provider_controller in api_provider_controllers:
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller=api_provider_controller["controller"],
|
||||
db_provider=api_provider_controller["provider"],
|
||||
decrypt_credentials=False,
|
||||
labels=labels.get(api_provider_controller["controller"].provider_id, []),
|
||||
)
|
||||
result_providers[f"api_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
if "workflow" in filters:
|
||||
# get workflow providers
|
||||
workflow_providers: list[WorkflowToolProvider] = (
|
||||
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||
for provider in workflow_providers:
|
||||
try:
|
||||
workflow_provider_controllers.append(
|
||||
ToolTransformService.workflow_provider_to_controller(db_provider=provider)
|
||||
)
|
||||
except Exception:
|
||||
# app has been deleted
|
||||
pass
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels(
|
||||
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
|
||||
)
|
||||
|
||||
for provider_controller in workflow_provider_controllers:
|
||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
labels=labels.get(provider_controller.provider_id, []),
|
||||
)
|
||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
return BuiltinToolProviderSort.sort(list(result_providers.values()))
|
||||
|
||||
|
||||
@ -3,11 +3,13 @@ from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
@ -54,7 +56,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
|
||||
if not dataset:
|
||||
return ""
|
||||
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_query(query, dataset.id)
|
||||
if dataset.provider == "external":
|
||||
@ -125,7 +126,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
)
|
||||
else:
|
||||
documents = []
|
||||
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_tool_end(documents)
|
||||
document_score_list = {}
|
||||
@ -134,50 +134,46 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
if item.metadata is not None and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata["doc_id"] for document in documents]
|
||||
segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.dataset_id == self.dataset_id,
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
).all()
|
||||
|
||||
if segments:
|
||||
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
|
||||
sorted_segments = sorted(
|
||||
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
|
||||
)
|
||||
for segment in sorted_segments:
|
||||
records = RetrievalService.format_retrieval_documents(documents)
|
||||
if records:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
if segment.answer:
|
||||
document_context_list.append(
|
||||
f"question:{segment.get_sign_content()} answer:{segment.answer}"
|
||||
DocumentContext(
|
||||
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
|
||||
score=record.score,
|
||||
)
|
||||
)
|
||||
else:
|
||||
document_context_list.append(segment.get_sign_content())
|
||||
document_context_list.append(
|
||||
DocumentContext(
|
||||
content=segment.get_sign_content(),
|
||||
score=record.score,
|
||||
)
|
||||
)
|
||||
retrieval_resource_list = []
|
||||
if self.return_resource:
|
||||
context_list = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
document_segment = Document.query.filter(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = DatasetDocument.query.filter(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).first()
|
||||
if not document_segment:
|
||||
continue
|
||||
if dataset and document_segment:
|
||||
if dataset and document:
|
||||
source = {
|
||||
"position": resource_number,
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document_segment.id,
|
||||
"document_name": document_segment.name,
|
||||
"data_source_type": document_segment.data_source_type,
|
||||
"document_id": document.id, # type: ignore
|
||||
"document_name": document.name, # type: ignore
|
||||
"data_source_type": document.data_source_type, # type: ignore
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": document_score_list.get(segment.index_node_id, None),
|
||||
"score": record.score or 0.0,
|
||||
}
|
||||
|
||||
if self.retriever_from == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
@ -187,10 +183,19 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
context_list.append(source)
|
||||
resource_number += 1
|
||||
retrieval_resource_list.append(source)
|
||||
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
if self.return_resource and retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(
|
||||
retrieval_resource_list,
|
||||
key=lambda x: x.get("score") or 0.0,
|
||||
reverse=True,
|
||||
)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
|
||||
item["position"] = position # type: ignore
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||
if document_context_list:
|
||||
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||
return str("\n".join([document_context.content for document_context in document_context_list]))
|
||||
return ""
|
||||
|
||||
@ -207,7 +207,6 @@ class AgentLogEvent(BaseAgentEvent):
|
||||
status: str = Field(..., description="status")
|
||||
data: Mapping[str, Any] = Field(..., description="data")
|
||||
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
|
||||
node_id: str = Field(..., description="agent node id")
|
||||
|
||||
|
||||
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent
|
||||
|
||||
@ -18,7 +18,6 @@ from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunM
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseAgentEvent,
|
||||
BaseIterationEvent,
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
@ -502,7 +501,7 @@ class GraphEngine:
|
||||
break
|
||||
|
||||
yield event
|
||||
if not isinstance(event, BaseAgentEvent) and event.parallel_id == parallel_id:
|
||||
if event.parallel_id == parallel_id:
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
succeeded_count += 1
|
||||
if succeeded_count == len(futures):
|
||||
@ -666,7 +665,7 @@ class GraphEngine:
|
||||
retries += 1
|
||||
route_node_state.node_run_result = run_result
|
||||
yield NodeRunRetryEvent(
|
||||
id=node_instance.id,
|
||||
id=str(uuid.uuid4()),
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
@ -681,7 +680,7 @@ class GraphEngine:
|
||||
start_at=retry_start_at,
|
||||
)
|
||||
time.sleep(retry_interval)
|
||||
continue
|
||||
break
|
||||
route_node_state.set_finished(run_result=run_result)
|
||||
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
|
||||
@ -8,12 +8,12 @@ from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from core.plugin.manager.plugin import PluginInstallationManager
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData, ParamsAutoGenerated
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event.event import RunCompletedEvent
|
||||
@ -156,38 +156,16 @@ class AgentNode(ToolNode):
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
parameters = tool.get("parameters", {})
|
||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
||||
params = {}
|
||||
for key, param in parameters.items():
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN.value) == ParamsAutoGenerated.CLOSE.value:
|
||||
value_param = param.get("value", {})
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
params[key] = None
|
||||
parameters = params
|
||||
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
|
||||
tool["parameters"] = parameters
|
||||
|
||||
if not for_log:
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN.value))
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
|
||||
parameters = {**parameters, **setting_params}
|
||||
entity = AgentToolEntity(
|
||||
provider_id=tool.get("provider_name", ""),
|
||||
provider_type=provider_type,
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
tool_name=tool.get("tool_name", ""),
|
||||
tool_parameters=parameters,
|
||||
tool_parameters=tool.get("parameters", {}),
|
||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||
)
|
||||
|
||||
@ -200,27 +178,13 @@ class AgentNode(ToolNode):
|
||||
tool_runtime.entity.description.llm = (
|
||||
extra.get("descrption", "") or tool_runtime.entity.description.llm
|
||||
)
|
||||
for params in tool_runtime.entity.parameters:
|
||||
params.form = (
|
||||
ToolParameter.ToolParameterForm.FORM
|
||||
if params.name in manual_input_params
|
||||
else params.form
|
||||
)
|
||||
if tool_runtime.entity.parameters:
|
||||
manual_input_value = {
|
||||
key: value for key, value in parameters.items() if key in manual_input_params
|
||||
|
||||
tool_value.append(
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": tool_runtime.runtime.runtime_parameters,
|
||||
}
|
||||
runtime_parameters = {
|
||||
**tool_runtime.runtime.runtime_parameters,
|
||||
**manual_input_value,
|
||||
}
|
||||
tool_value.append(
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": runtime_parameters,
|
||||
"provider_type": provider_type.value,
|
||||
}
|
||||
)
|
||||
)
|
||||
value = tool_value
|
||||
if parameter.type == "model-selector":
|
||||
value = cast(dict[str, Any], value)
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -17,8 +16,3 @@ class AgentNodeData(BaseNodeData):
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
|
||||
agent_parameters: dict[str, AgentInput]
|
||||
|
||||
|
||||
class ParamsAutoGenerated(Enum):
|
||||
CLOSE = 0
|
||||
OPEN = 1
|
||||
|
||||
@ -107,8 +107,10 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
|
||||
return _extract_text_from_plain_text(file_content)
|
||||
case "application/pdf":
|
||||
return _extract_text_from_pdf(file_content)
|
||||
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document" | "application/msword":
|
||||
case "application/msword":
|
||||
return _extract_text_from_doc(file_content)
|
||||
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
|
||||
return _extract_text_from_docx(file_content)
|
||||
case "text/csv":
|
||||
return _extract_text_from_csv(file_content)
|
||||
case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel":
|
||||
@ -142,8 +144,10 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
|
||||
return _extract_text_from_yaml(file_content)
|
||||
case ".pdf":
|
||||
return _extract_text_from_pdf(file_content)
|
||||
case ".doc" | ".docx":
|
||||
case ".doc":
|
||||
return _extract_text_from_doc(file_content)
|
||||
case ".docx":
|
||||
return _extract_text_from_docx(file_content)
|
||||
case ".csv":
|
||||
return _extract_text_from_csv(file_content)
|
||||
case ".xls" | ".xlsx":
|
||||
@ -203,7 +207,33 @@ def _extract_text_from_pdf(file_content: bytes) -> str:
|
||||
|
||||
def _extract_text_from_doc(file_content: bytes) -> str:
|
||||
"""
|
||||
Extract text from a DOC/DOCX file.
|
||||
Extract text from a DOC file.
|
||||
"""
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
if not (dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY):
|
||||
raise TextExtractionError("UNSTRUCTURED_API_URL and UNSTRUCTURED_API_KEY must be set")
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file:
|
||||
temp_file.write(file_content)
|
||||
temp_file.flush()
|
||||
with open(temp_file.name, "rb") as file:
|
||||
elements = partition_via_api(
|
||||
file=file,
|
||||
metadata_filename=temp_file.name,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY,
|
||||
)
|
||||
os.unlink(temp_file.name)
|
||||
return "\n".join([getattr(element, "text", "") for element in elements])
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from DOC: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_docx(file_content: bytes) -> str:
|
||||
"""
|
||||
Extract text from a DOCX file.
|
||||
For now support only paragraph and table add more if needed
|
||||
"""
|
||||
try:
|
||||
@ -255,13 +285,13 @@ def _extract_text_from_doc(file_content: bytes) -> str:
|
||||
|
||||
text.append(markdown_table)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
|
||||
logger.warning(f"Failed to extract table from DOC: {e}")
|
||||
continue
|
||||
|
||||
return "\n".join(text)
|
||||
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
|
||||
raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e
|
||||
|
||||
|
||||
def _download_file_content(file: File) -> bytes:
|
||||
@ -329,14 +359,29 @@ def _extract_text_from_excel(file_content: bytes) -> str:
|
||||
|
||||
|
||||
def _extract_text_from_ppt(file_content: bytes) -> str:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
from unstructured.partition.ppt import partition_ppt
|
||||
|
||||
try:
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_ppt(file=file)
|
||||
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
|
||||
with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file:
|
||||
temp_file.write(file_content)
|
||||
temp_file.flush()
|
||||
with open(temp_file.name, "rb") as file:
|
||||
elements = partition_via_api(
|
||||
file=file,
|
||||
metadata_filename=temp_file.name,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY,
|
||||
)
|
||||
os.unlink(temp_file.name)
|
||||
else:
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_ppt(file=file)
|
||||
return "\n".join([getattr(element, "text", "") for element in elements])
|
||||
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from PPT: {str(e)}") from e
|
||||
raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_pptx(file_content: bytes) -> str:
|
||||
|
||||
@ -1,6 +1,3 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
@ -30,20 +27,3 @@ class EndNode(BaseNode[EndNodeData]):
|
||||
inputs=outputs,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: EndNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
@ -64,7 +64,7 @@ class EndStreamGeneratorRouter:
|
||||
node_type = node.get("data", {}).get("type")
|
||||
if (
|
||||
variable_selector.value_selector not in value_selectors
|
||||
and (node_type in (NodeType.LLM.value, NodeType.AGENT.value))
|
||||
and node_type == NodeType.LLM.value
|
||||
and variable_selector.value_selector[1] == "text"
|
||||
):
|
||||
value_selectors.append(list(variable_selector.value_selector))
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
@ -88,23 +87,6 @@ class IfElseNode(BaseNode[IfElseNodeData]):
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IfElseNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
|
||||
@deprecated("This function is deprecated. You should use the new cases structure.")
|
||||
def _should_not_use_old_function(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user