Compare commits

79 Commits

Author SHA1 Message Date
1e73f63ff8 chore: update version to 0.15.2 in packaging and docker configurations (#12940)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-22 16:40:44 +08:00
d167d5b1be feat(ark): support doubao 1.5 series of models (#12935) 2025-01-22 15:25:57 +08:00
71fa14f791 fix: resolve clipboard.writeText failure under HTTP protocol (#12936) 2025-01-22 15:18:23 +08:00
8dd1873e76 feat: workflow note dark theme (#12932) 2025-01-22 14:22:33 +08:00
f91f5c7401 fix(batch_create_segment_to_index_task): count max_position in memory. (#12929) 2025-01-22 13:39:02 +08:00
c62b7cc679 chore(build): bump poetry from 1.x to 2.x (#12369) 2025-01-22 13:38:24 +08:00
3ee213ddca add milvus full text search setting (#12930) 2025-01-22 13:36:39 +08:00
8429877b02 fix: Agent is configured for ReAct inference mode, an error is reported when viewing the agent log (#12920)
Co-authored-by: crazywoola <427733928@qq.com>
2025-01-22 13:20:32 +08:00
05a0faff6a fix: app token's last_used_at can't be updated when last_used_at is null (#12770) 2025-01-22 11:01:45 +08:00
e09f6e4987 feat: support config chunk length by env (#12925) 2025-01-22 10:43:40 +08:00
e23f4b0265 feat: add gemini-2.0-flash-thinking-exp-01-21 (#12924) 2025-01-22 10:14:37 +08:00
f582d4a13e feat: Add ability to change profile avatar (#12642) 2025-01-22 10:11:31 +08:00
2f41bd495d fix:Fix a bug that returns null when the passed path is a file. (#12775)
Co-authored-by: 刘江波 <jiangbo721@163.com>
2025-01-22 10:10:03 +08:00
162a8c4393 fix update segment keyword with same content (#12908) 2025-01-21 19:19:32 +08:00
3d1ce4c53f bug: fixed bedrock rerank bug (#12774)
Co-authored-by: hobo.l <hobo.l@binance.com>
2025-01-21 19:09:36 +08:00
6db3ae9b8e chore: remove webapp ga (#12909) 2025-01-21 18:38:33 +08:00
6d0cb9dc33 fix: variable panel scrollable (#12769)
Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com>
2025-01-21 17:50:42 +08:00
46e95e8309 fix: OpenAI o1 Bad Request Error (#12839) 2025-01-21 15:29:13 +08:00
a7b9375877 Update deepseek model configuration (#12899) 2025-01-21 15:28:11 +08:00
0c6a8a130e fix: external dataset hit test display issue(#12564) (#12612)
Co-authored-by: zhuxinliang <zhuxinliang@didiglobal.com>
2025-01-21 14:31:45 +08:00
9903f1e703 add deepseek-reasoner (#12898) 2025-01-21 12:40:58 +08:00
6fad719e42 chore(fix): Invalid quotes for using Array[String] in HTTP request node as JSON body (#12761) 2025-01-21 10:38:44 +08:00
9aaee8ee47 fix: Issues related to the deletion of conversation_id (#12488) (#12665) 2025-01-21 10:25:35 +08:00
166221d784 chore(lint): fix quotes for f-string formatting by bumping ruff to 0.9.x (#12702) 2025-01-21 10:12:29 +08:00
925d69a2ee feat:Support Minimax-Text-01 (#12763) 2025-01-21 10:08:53 +08:00
5ff08e241a fix: serply credential check query might return empty records (#12784) 2025-01-21 09:38:56 +08:00
3defd24087 feat: allow updating chunk settings for the existing documents (#12833) 2025-01-21 09:25:40 +08:00
9d86147d20 fix: SparkLite API Auth error (#12781) (#12790) 2025-01-20 22:21:21 +08:00
80801ac4ab fix: "parmas" spelling mistake. (#12875) 2025-01-20 22:18:30 +08:00
210926cd91 Fix suggested_question_prompt (#12738) 2025-01-20 22:16:30 +08:00
677a69deed fix(i18n): correct typo in zh-Hant translation (#12852) 2025-01-20 22:15:41 +08:00
8dfdee21ce chore: fix chinese translation for 'recall' (#12772)
Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com>
2025-01-20 22:15:26 +08:00
6ea77ab4cd fix: DeepSeek API Error with response format active (text and json_object) (#12747) 2025-01-20 22:04:18 +08:00
e3c996688d feat: enhance credential extraction logic based on configurate method (#12853) 2025-01-20 21:59:22 +08:00
bc3a570dda fix: Fix rerank model switching issue (#12721)
ok
2025-01-14 15:42:45 +08:00
0800021a2d chore: translate i18n files (#12708)
Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com>
2025-01-14 13:35:23 +08:00
435eddd867 Feat: copyright modification (#12707) 2025-01-14 10:00:57 +08:00
6e0fb055d1 chore: bump version to 0.15.1 (#12690)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-13 19:21:06 +08:00
eux
1e9ac7ffeb feat: add table of contents to Knowledge API doc (#12688) 2025-01-13 18:31:43 +08:00
b4873ecb43 [fix] support feature restore (#12563) 2025-01-13 18:29:06 +08:00
mbo
1859d57784 api tool support multiple env url (#12249)
Co-authored-by: mabo <mabo@aeyes.ai>
2025-01-13 17:49:30 +08:00
69d58fbb50 Add new integration with Opik Tracking tool (#11501) 2025-01-13 17:41:44 +08:00
cb34991663 fix: add type hints for App model and improve error handling in audio services (#12677)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-13 15:55:16 +08:00
c700364e1c fix: Update variable handling in VariableAssignerNode and clean up app_dsl_service (#12672)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-13 15:54:26 +08:00
9a6b1dc3a1 Revert "Feat/new saas billing" (#12673) 2025-01-13 15:17:43 +08:00
54b5b80a07 fix(workflow): fix answer node stream processing in conditional branches (#12510) 2025-01-13 14:54:21 +08:00
831459b895 fix: ruff with statements (#12578)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-01-13 09:55:55 +08:00
4e101604c3 fix: ruff check for True if ... else (#12576)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2025-01-13 09:38:48 +08:00
a6455269f0 chore: Adjust translations to align with Taiwanese Mandarin conventions (#12633) 2025-01-13 09:12:43 +08:00
cd257b91c5 Fix pandas indexing method for knowledge base imports (#12637) (#12638)
Co-authored-by: CN-P5 <heibai2006@qq.com>
2025-01-13 09:06:59 +08:00
d8f57bf899 Feat/new saas billing (#12591) 2025-01-12 14:50:46 +08:00
989fb11fd7 improve the readability of the function generate_api_key (#12552) 2025-01-09 21:30:17 +08:00
140965b738 chore: translate i18n files (#12543)
Co-authored-by: WTW0313 <30284043+WTW0313@users.noreply.github.com>
2025-01-09 20:30:06 +08:00
14ee51aead Feat/add knowledge include all filter (#12537) 2025-01-09 20:21:25 +08:00
2e97ba5700 fix: Add datasets list access control and fix datasets config display issue (#12533)
Co-authored-by: nite-knite <nkCoding@gmail.com>
2025-01-09 17:44:11 +08:00
f549d53b68 fix: sum costs return error value on overview page (#12534) 2025-01-09 16:04:14 +08:00
a085ad4719 feat: show workflow running status (#12531) 2025-01-09 15:36:13 +08:00
f230a9232e fix: Parsing OpenAPI spec for external tools (#12518) (#12530) 2025-01-09 15:30:43 +08:00
e84bf35e2a fix: same chunk insert deadlock (#12502)
Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com>
2025-01-09 15:16:41 +08:00
eux
20f090537f feat: add GET upload file API endpoint to dataset service api (#11899) 2025-01-09 14:52:09 +08:00
dbe7a7c4fd Fix: Add a INFO-level log when fallback to gpt2tokenizer (#12508) 2025-01-09 14:37:46 +08:00
b7a4e3903e fix: add last_refresh_time to track the validity of is_other_tab_refreshing (#12517) 2025-01-09 10:40:45 +08:00
b4c1c2f731 fix: Reverse sync docker-compose-template.yaml (#12509) 2025-01-09 10:21:22 +08:00
1b940e7daa feat: add ci job to test template for docker compose (#12514) 2025-01-09 00:04:58 +08:00
f4ee50a7ad chore: improve app doc (#12490) 2025-01-08 18:37:12 +08:00
bee32d960a fix #12453 #12482 (#12495) 2025-01-08 18:26:05 +08:00
040a3b782c FEAT: support milvus to full text search (#11430)
Signed-off-by: YoungLH <974840768@qq.com>
2025-01-08 17:39:53 +08:00
d649037c3e feat: support single run doc extractor node (#11318) 2025-01-08 15:20:15 +08:00
0a49d3dd52 fix: tiktoken cannot be loaded without internet (#12478)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-08 14:49:44 +08:00
53bb37b749 fix: fix the incorrect plaintext file key when saving (#10429) 2025-01-08 12:52:45 +08:00
d2586278d6 Feat elasticsearch japanese (#12194) 2025-01-08 12:35:41 +08:00
6635c393e9 fix: adjust opacity for model selector based on readonly state (#12472) 2025-01-08 12:11:45 +08:00
6222179a57 Revert "fix:deepseek tool call not working correctly" (#12463) 2025-01-08 10:50:34 +08:00
05bda6f38d add tidb on qdrant redis lock (#12462) 2025-01-08 08:55:44 +08:00
4295cefeb1 fix: allow fallback to remote_url when url is not provided (#12455) 2025-01-07 22:33:25 +08:00
67228c9b26 fix: url with variable not work (#12452) 2025-01-07 21:55:51 +08:00
fd2bfff023 remove knowledge admin role (#12450) 2025-01-07 21:30:23 +08:00
4e6c86341d Add 'document' feature to Sonnet 3.5 through OpenRouter (#12444) 2025-01-07 19:51:38 +08:00
2a14c67edc Fix #12448 - update bedrock retrieve tool, support hybrid search type and re… (#12446)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
2025-01-07 19:51:23 +08:00
313 changed files with 5498 additions and 1373 deletions

View File

@ -8,7 +8,7 @@ inputs:
poetry-version:
description: Poetry version to set up
required: true
default: '1.8.4'
default: '2.0.1'
poetry-lockfile:
description: Path to the Poetry lockfile to restore cache from
required: true

View File

@ -42,25 +42,23 @@ jobs:
run: poetry install -C api --with dev
- name: Check dependencies in pyproject.toml
run: poetry run -C api bash dev/pytest/pytest_artifacts.sh
run: poetry run -P api bash dev/pytest/pytest_artifacts.sh
- name: Run Unit tests
run: poetry run -C api bash dev/pytest/pytest_unit_tests.sh
run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh
- name: Run ModelRuntime
run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh
run: poetry run -P api bash dev/pytest/pytest_model_runtime.sh
- name: Run dify config tests
run: poetry run -C api python dev/pytest/pytest_config_tests.py
run: poetry run -P api python dev/pytest/pytest_config_tests.py
- name: Run Tool
run: poetry run -C api bash dev/pytest/pytest_tools.sh
run: poetry run -P api bash dev/pytest/pytest_tools.sh
- name: Run mypy
run: |
pushd api
poetry run python -m mypy --install-types --non-interactive .
popd
poetry run -C api python -m mypy --install-types --non-interactive .
- name: Set up dotenvs
run: |
@ -80,4 +78,4 @@ jobs:
ssrf_proxy
- name: Run Workflow
run: poetry run -C api bash dev/pytest/pytest_workflow.sh
run: poetry run -P api bash dev/pytest/pytest_workflow.sh

View File

@ -38,12 +38,12 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: |
poetry run -C api ruff --version
poetry run -C api ruff check ./api
poetry run -C api ruff format --check ./api
poetry run -C api ruff check ./
poetry run -C api ruff format --check ./
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
run: poetry run -P api dotenv-linter ./api/.env.example ./web/.env.example
- name: Lint hints
if: failure()
@ -82,6 +82,33 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: yarn run lint
docker-compose-template:
name: Docker Compose Template
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v45
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- name: Generate Docker Compose
if: steps.changed-files.outputs.any_changed == 'true'
run: |
cd docker
./generate_docker_compose
- name: Check for changes
if: steps.changed-files.outputs.any_changed == 'true'
run: git diff --exit-code
superlinter:
name: SuperLinter

View File

@ -70,4 +70,4 @@ jobs:
tidb
- name: Test Vector Stores
run: poetry run -C api bash dev/pytest/pytest_vdb.sh
run: poetry run -P api bash dev/pytest/pytest_vdb.sh

View File

@ -53,10 +53,12 @@ ignore = [
"FURB152", # math-constant
"UP007", # non-pep604-annotation
"UP032", # f-string
"UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
"B903", # class-as-data-structure
"B904", # raise-without-from-inside-except
"B905", # zip-without-explicit-strict
"N806", # non-lowercase-variable-in-function

View File

@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base
WORKDIR /app/api
# Install Poetry
ENV POETRY_VERSION=1.8.4
ENV POETRY_VERSION=2.0.1
# if you located in China, you can use aliyun mirror to speed up
# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/

View File

@ -79,5 +79,5 @@
2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
```bash
poetry run -C api bash dev/pytest/pytest_all_tests.sh
poetry run -P api bash dev/pytest/pytest_all_tests.sh
```

View File

@ -146,7 +146,7 @@ class EndpointConfig(BaseSettings):
)
CONSOLE_WEB_URL: str = Field(
description="Base URL for the console web interface," "used for frontend references and CORS configuration",
description="Base URL for the console web interface,used for frontend references and CORS configuration",
default="",
)

View File

@ -181,7 +181,7 @@ class HostedFetchAppTemplateConfig(BaseSettings):
"""
HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
description="Mode for fetching app templates: remote, db, or builtin" " default to remote,",
description="Mode for fetching app templates: remote, db, or builtin default to remote,",
default="remote",
)

View File

@ -33,3 +33,9 @@ class MilvusConfig(BaseSettings):
description="Name of the Milvus database to connect to (default is 'default')",
default="default",
)
MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with "
"older versions",
default=True,
)

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.15.0",
default="0.15.2",
)
COMMIT_SHA: str = Field(

View File

@ -56,7 +56,7 @@ class InsertExploreAppListApi(Resource):
app = App.query.filter(App.id == args["app_id"]).first()
if not app:
raise NotFound(f'App \'{args["app_id"]}\' is not found')
raise NotFound(f"App '{args['app_id']}' is not found")
site = app.site
if not site:
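
Note on the quote changes in this hunk and the similar ones later in the diff: they come from the formatter bump (ruff 0.9.x, #12702) and appear to be purely cosmetic. A minimal check, using a made-up `args` value, that both spellings build the same string:

```python
# Hypothetical request arguments, for illustration only.
args = {"app_id": "demo-app"}

# Old spelling: single-quoted f-string with escaped quotes in the literal text.
old_style = f'App \'{args["app_id"]}\' is not found'

# New spelling: double quotes outside, single quotes inside the replacement field.
new_style = f"App '{args['app_id']}' is not found"

print(old_style == new_style)
```

Using different quote characters inside and outside the f-string keeps the expression valid on Python versions before 3.12 as well.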

View File

@ -22,7 +22,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from models.model import AppMode
from models import App, AppMode
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@ -79,7 +79,7 @@ class ChatMessageTextApi(Resource):
@login_required
@account_initialization_required
@get_app_model
def post(self, app_model):
def post(self, app_model: App):
from werkzeug.exceptions import InternalServerError
try:
@ -98,9 +98,13 @@ class ChatMessageTextApi(Resource):
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
if text_to_speech is None:
raise ValueError("TTS is not enabled")
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
if app_model.app_model_config is None:
raise ValueError("AppModelConfig not found")
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None

View File

@ -52,12 +52,12 @@ class DatasetListApi(Resource):
# provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true"
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(
page, limit, current_user.current_tenant_id, current_user, search, tag_ids
page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
)
# check embedding setting
@ -457,7 +457,7 @@ class DatasetIndexingEstimateApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -619,8 +619,7 @@ class DatasetRetrievalSettingApi(Resource):
vector_type = dify_config.VECTOR_STORE
match vector_type:
case (
VectorType.MILVUS
| VectorType.RELYT
VectorType.RELYT
| VectorType.PGVECTOR
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
@ -640,10 +639,12 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.PGVECTOR
| VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM
| VectorType.COUCHBASE
| VectorType.MILVUS
):
return {
"retrieval_method": [
@ -683,6 +684,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.LINDORM

View File

@ -257,7 +257,8 @@ class DatasetDocumentListApi(Resource):
parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
@ -349,8 +350,7 @@ class DatasetInitApi(Resource):
)
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -525,8 +525,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
return response.model_dump(), 200
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)

View File

@ -168,8 +168,7 @@ class DatasetDocumentSegmentApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -217,8 +216,7 @@ class DatasetDocumentSegmentAddApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -267,8 +265,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -368,9 +365,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
result = []
for index, row in df.iterrows():
if document.doc_form == "qa_model":
data = {"content": row[0], "answer": row[1]}
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row[0]}
data = {"content": row.iloc[0]}
result.append(data)
if len(result) == 0:
raise ValueError("The CSV file is empty.")
@ -437,8 +434,7 @@ class ChildChunkAddApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)

View File

@ -32,7 +32,7 @@ class ConversationListApi(InstalledAppResource):
pinned = None
if "pinned" in args and args["pinned"] is not None:
pinned = True if args["pinned"] == "true" else False
pinned = args["pinned"] == "true"
try:
with Session(db.engine) as session:

View File

@ -7,4 +7,4 @@ api = ExternalApi(bp)
from . import index
from .app import app, audio, completion, conversation, file, message, workflow
from .dataset import dataset, document, hit_testing, segment
from .dataset import dataset, document, hit_testing, segment, upload_file

View File

@ -31,8 +31,11 @@ class DatasetListApi(DatasetApiResource):
# provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true"
datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids)
datasets, total = DatasetService.get_datasets(
page, limit, tenant_id, current_user, search, tag_ids, include_all
)
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)

View File

@ -53,8 +53,7 @@ class SegmentApi(DatasetApiResource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -95,8 +94,7 @@ class SegmentApi(DatasetApiResource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -175,8 +173,7 @@ class DatasetSegmentApi(DatasetApiResource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)

View File

@ -0,0 +1,54 @@
from werkzeug.exceptions import NotFound
from controllers.service_api import api
from controllers.service_api.wraps import (
DatasetApiResource,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import UploadFile
from services.dataset_service import DocumentService
class UploadFileApi(DatasetApiResource):
def get(self, tenant_id, dataset_id, document_id):
"""Get upload file."""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
# check upload file
if document.data_source_type != "upload_file":
raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("UploadFile not found.")
else:
raise ValueError("Upload file id not found in document data source info.")
url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"url": url,
"download_url": f"{url}&as_attachment=true",
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at.timestamp(),
}, 200
api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")
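
For orientation, a client request for the new route might be built as follows. This is a sketch only: the base URL, API key, and IDs are placeholders, `build_upload_file_request` is a hypothetical helper, and the `/v1` prefix and Bearer-token header are assumptions about how the service API is mounted and authenticated; neither is shown in this diff.

```python
import urllib.request
import uuid


def build_upload_file_request(
    base_url: str, api_key: str, dataset_id: str, document_id: str
) -> urllib.request.Request:
    """Build (but do not send) a GET request for a document's upload file."""
    # The route uses <uuid:...> converters, so reject malformed IDs early.
    dataset = uuid.UUID(dataset_id)
    document = uuid.UUID(document_id)
    url = f"{base_url.rstrip('/')}/datasets/{dataset}/documents/{document}/upload-file"
    return urllib.request.Request(
        url,
        headers={"Authorization": f"Bearer {api_key}"},  # assumed auth scheme
        method="GET",
    )


request = build_upload_file_request(
    "http://localhost/v1",  # placeholder base URL
    "dataset-api-key",  # placeholder key
    "123e4567-e89b-12d3-a456-426614174000",
    "123e4567-e89b-12d3-a456-426614174001",
)
print(request.full_url)
```

Per the handler above, a successful response carries `id`, `name`, `size`, `extension`, `url`, `download_url`, `mime_type`, `created_by`, and `created_at`.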

View File

@ -195,7 +195,11 @@ def validate_and_get_api_token(scope: str | None = None):
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
update(ApiToken)
.where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
.where(
ApiToken.token == auth_token,
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
ApiToken.type == scope,
)
.values(last_used_at=current_time)
.returning(ApiToken)
)
@ -236,7 +240,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=True if user_id == "DEFAULT-USER" else False,
is_anonymous=user_id == "DEFAULT-USER",
session_id=user_id,
)
db.session.add(end_user)
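
Background on the `last_used_at` change in the first hunk of this file (#12770): in SQL, comparing `NULL` with `<` yields `NULL` rather than true, so a never-used token could not match the old `WHERE` clause and its timestamp was never set. The sketch below reproduces this with the standard-library `sqlite3` module; the table name, columns, and sample rows are made up for illustration and are not the project's schema.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE api_tokens (token TEXT PRIMARY KEY, last_used_at TEXT)")
conn.executemany(
    "INSERT INTO api_tokens VALUES (?, ?)",
    [
        ("never-used", None),  # last_used_at is NULL
        ("stale", "2025-01-01 00:00:00"),
        ("fresh", "2025-01-22 12:00:00"),
    ],
)
cutoff = "2025-01-22 00:00:00"

# Old predicate: the NULL row is silently skipped.
old_matches = [
    row[0]
    for row in conn.execute(
        "SELECT token FROM api_tokens WHERE last_used_at < ? ORDER BY token",
        (cutoff,),
    )
]

# New predicate: NULL is handled explicitly, mirroring `.is_(None) | (... < cutoff)`.
new_matches = [
    row[0]
    for row in conn.execute(
        "SELECT token FROM api_tokens "
        "WHERE last_used_at IS NULL OR last_used_at < ? ORDER BY token",
        (cutoff,),
    )
]

print(old_matches, new_matches)
```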

View File

@ -39,7 +39,7 @@ class ConversationListApi(WebApiResource):
pinned = None
if "pinned" in args and args["pinned"] is not None:
pinned = True if args["pinned"] == "true" else False
pinned = args["pinned"] == "true"
try:
with Session(db.engine) as session:

View File

@ -172,7 +172,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=scratchpad.action.action_name if scratchpad.action else "",
tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
tool_invoke_meta={},
thought=scratchpad.thought or "",

View File

@ -167,8 +167,7 @@ class AppQueueManager:
else:
if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances "
"that cause thread safety issues is not allowed."
"Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
)

View File

@ -89,6 +89,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
Conversation.status == "normal",
Conversation.is_deleted.is_(False),
]
if isinstance(user, Account):

View File

@ -145,7 +145,7 @@ class MessageCycleManage:
# get extension
if "." in message_file.url:
extension = f'.{message_file.url.split(".")[-1]}'
extension = f".{message_file.url.split('.')[-1]}"
if len(extension) > 10:
extension = ".bin"
else:

View File

@ -62,8 +62,9 @@ class ApiExternalDataTool(ExternalDataTool):
if not api_based_extension:
raise ValueError(
"[External data tool] API query failed, variable: {}, "
"error: api_based_extension_id is invalid".format(self.variable)
"[External data tool] API query failed, variable: {}, error: api_based_extension_id is invalid".format(
self.variable
)
)
# decrypt api_key

View File

@ -90,7 +90,7 @@ class File(BaseModel):
def markdown(self) -> str:
url = self.generate_url()
if self.type == FileType.IMAGE:
text = f'![{self.filename or ""}]({url})'
text = f"![{self.filename or ''}]({url})"
else:
text = f"[{self.filename or url}]({url})"

View File

@ -530,7 +530,6 @@ class IndexingRunner:
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
tokens = 0
chunk_size = 10
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
# create keyword index
create_keyword_thread = threading.Thread(
@ -539,11 +538,22 @@ class IndexingRunner:
)
create_keyword_thread.start()
max_workers = 10
if dataset.indexing_technique == "high_quality":
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for i in range(0, len(documents), chunk_size):
chunk_documents = documents[i : i + chunk_size]
# Distribute documents into multiple groups based on the hash values of page_content
# This is done to prevent multiple threads from processing the same document,
# Thereby avoiding potential database insertion deadlocks
document_groups: list[list[Document]] = [[] for _ in range(max_workers)]
for document in documents:
hash = helper.generate_text_hash(document.page_content)
group_index = int(hash, 16) % max_workers
document_groups[group_index].append(document)
for chunk_documents in document_groups:
if len(chunk_documents) == 0:
continue
futures.append(
executor.submit(
self._process_chunk,
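
The grouping introduced in this hunk (#12502) can be illustrated in isolation. In the sketch below, `hashlib.sha256` stands in for `helper.generate_text_hash`, whose implementation is not part of this diff, and `group_by_content_hash` is a hypothetical name. The property that matters is that identical content always lands in the same group, so no two workers insert the same chunk concurrently.

```python
import hashlib


def group_by_content_hash(texts: list[str], max_workers: int) -> list[list[str]]:
    """Bucket texts so that equal texts always share a bucket."""
    groups: list[list[str]] = [[] for _ in range(max_workers)]
    for text in texts:
        # Stand-in hash; assumed to behave like the project's helper (hex digest).
        digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
        groups[int(digest, 16) % max_workers].append(text)
    return groups


groups = group_by_content_hash(["alpha", "beta", "alpha", "gamma", "alpha"], 4)
buckets_with_alpha = [g for g in groups if "alpha" in g]
print(len(buckets_with_alpha))
```

One trade-off of this scheme is that group sizes depend on the hash distribution, so work may be spread less evenly than with fixed-size slices.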

View File

@ -131,7 +131,7 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, "
"and keeping each question under 20 characters.\n"
"MAKE SURE your output is the SAME language as the Assistant's latest response"
"MAKE SURE your output is the SAME language as the Assistant's latest response. "
"The output must be an array in JSON format following the specified schema:\n"
'["question1","question2","question3"]\n'
)

View File

@ -1,7 +1,8 @@
import logging
from threading import Lock
from typing import Any
import tiktoken
logger = logging.getLogger(__name__)
_tokenizer: Any = None
_lock = Lock()
@ -33,9 +34,18 @@ class GPT2Tokenizer:
if _tokenizer is None:
# Try to use tiktoken to get the tokenizer because it is faster
#
_tokenizer = tiktoken.get_encoding("gpt2")
# base_path = abspath(__file__)
# gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
# _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
try:
import tiktoken
_tokenizer = tiktoken.get_encoding("gpt2")
except Exception:
from os.path import abspath, dirname, join
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
logger.info("Fallback to Transformers' GPT-2 tokenizer from tiktoken")
return _tokenizer
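
The fallback pattern in this hunk (#12478, #12508) generalizes to any "fast path that may fail, slower local path that should not" situation. A minimal sketch, assuming the two loaders are passed in as callables; `load_with_fallback` and the sample loaders are hypothetical names, not code from this change:

```python
import logging
from typing import Any, Callable

logger = logging.getLogger(__name__)


def load_with_fallback(primary: Callable[[], Any], fallback: Callable[[], Any]) -> Any:
    """Return primary() if it succeeds, otherwise log at INFO level and use fallback()."""
    try:
        # Deferring the import or download into the callable means a failure
        # there (for example, no network access) is also caught here.
        return primary()
    except Exception:
        logger.info("Primary loader failed; falling back to the local loader")
        return fallback()


def offline_primary() -> str:
    # Simulates a loader that needs network access and cannot get it.
    raise OSError("network unreachable")


tokenizer_name = load_with_fallback(offline_primary, lambda: "local-gpt2")
print(tokenizer_name)
```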

View File

@ -108,7 +108,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not ai_model_entity:
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
raise CredentialsValidateFailedError(f"Base Model Name {credentials['base_model_name']} is invalid")
try:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))

View File

@ -130,7 +130,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
raise CredentialsValidateFailedError("Base Model Name is required")
if not self._get_ai_model_entity(credentials["base_model_name"], model):
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
raise CredentialsValidateFailedError(f"Base Model Name {credentials['base_model_name']} is invalid")
try:
credentials_kwargs = self._to_credential_kwargs(credentials)

View File

@ -70,7 +70,7 @@ class BedrockRerankModel(RerankModel):
rerankingConfiguration = {
"type": "BEDROCK_RERANKING_MODEL",
"bedrockRerankingConfiguration": {
"numberOfResults": top_n,
"numberOfResults": min(top_n, len(text_sources)),
"modelConfiguration": {
"modelArn": model_package_arn,
},

View File

@ -1,2 +1,3 @@
- deepseek-chat
- deepseek-coder
- deepseek-reasoner

View File

@ -10,7 +10,7 @@ features:
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
context_size: 64000
parameter_rules:
- name: temperature
use_template: temperature

View File

@ -10,7 +10,7 @@ features:
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
context_size: 64000
parameter_rules:
- name: temperature
use_template: temperature

View File

@ -0,0 +1,21 @@
model: deepseek-reasoner
label:
zh_Hans: deepseek-reasoner
en_US: deepseek-reasoner
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 64000
parameter_rules:
- name: max_tokens
use_template: max_tokens
min: 1
max: 8192
default: 4096
pricing:
input: "4"
output: "16"
unit: "0.000001"
currency: RMB

View File

@ -24,9 +24,6 @@ class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
# {"response_format": "xx"} need convert to {"response_format": {"type": "xx"}}
if "response_format" in model_parameters:
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
def validate_credentials(self, model: str, credentials: dict) -> None:

View File

@ -1,5 +1,6 @@
- gemini-2.0-flash-exp
- gemini-2.0-flash-thinking-exp-1219
- gemini-2.0-flash-thinking-exp-01-21
- gemini-1.5-pro
- gemini-1.5-pro-latest
- gemini-1.5-pro-001

View File

@ -0,0 +1,39 @@
model: gemini-2.0-flash-thinking-exp-01-21
label:
en_US: Gemini 2.0 Flash Thinking Exp 01-21
model_type: llm
features:
- agent-thought
- vision
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@ -162,9 +162,9 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
@staticmethod
def _check_endpoint_url_model_repository_name(credentials: dict, model_name: str):
try:
url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
url = f"{HUGGINGFACE_ENDPOINT_API}{credentials['huggingface_namespace']}"
headers = {
"Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}',
"Authorization": f"Bearer {credentials['huggingfacehub_api_token']}",
"Content-Type": "application/json",
}

View File

@ -34,6 +34,7 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage
class MinimaxLargeLanguageModel(LargeLanguageModel):
model_apis = {
"minimax-text-01": MinimaxChatCompletionPro,
"abab7-chat-preview": MinimaxChatCompletionPro,
"abab6.5t-chat": MinimaxChatCompletionPro,
"abab6.5s-chat": MinimaxChatCompletionPro,

View File

@ -0,0 +1,46 @@
model: minimax-text-01
label:
en_US: Minimax-Text-01
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 1000192
parameter_rules:
- name: temperature
use_template: temperature
min: 0.01
max: 1
default: 0.1
- name: top_p
use_template: top_p
min: 0.01
max: 1
default: 0.95
- name: max_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 1000192
- name: mask_sensitive_info
type: boolean
default: true
label:
zh_Hans: 隐私保护
en_US: Moderate
help:
zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码目前包括但不限于邮箱、域名、链接、证件号、家庭住址等默认true即开启打码
en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/id..
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
pricing:
input: '0.001'
output: '0.008'
unit: '0.001'
currency: RMB

View File

@ -44,9 +44,6 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
self._add_custom_parameters(credentials)
self._add_function_call(model, credentials)
user = user[:32] if user else None
# {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
if "response_format" in model_parameters:
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:

View File

@ -1,5 +1,6 @@
import json
import logging
import re
from collections.abc import Generator
from typing import Any, Optional, Union, cast
@ -621,11 +622,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
# o1 compatibility
block_as_stream = False
if model.startswith("o1"):
if "max_tokens" in model_parameters:
model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
del model_parameters["max_tokens"]
if re.match(r"^o1(-\d{4}-\d{2}-\d{2})?$", model):
if stream:
block_as_stream = True
stream = False
if "stream_options" in extra_model_kwargs:
del extra_model_kwargs["stream_options"]
if "stop" in extra_model_kwargs:
del extra_model_kwargs["stop"]
@ -642,7 +651,45 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
if block_as_stream:
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
return block_result
def _handle_chat_block_as_stream_response(
self,
block_result: LLMResult,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
) -> Generator[LLMResultChunk, None, None]:
"""
Wrap a blocking chat response as a single-chunk stream
:param block_result: blocking llm result
:param prompt_messages: prompt messages
:param stop: stop words
:return: llm response chunk generator
"""
text = block_result.message.content
text = cast(str, text)
if stop:
text = self.enforce_stop_tokens(text, stop)
yield LLMResultChunk(
model=block_result.model,
prompt_messages=prompt_messages,
system_fingerprint=block_result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=block_result.message,
finish_reason="stop",
usage=block_result.usage,
),
)
def _handle_chat_generate_response(
self,
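
The regular expression in this hunk matches only the bare `o1` model name or `o1` followed by a date suffix, so variants such as `o1-mini` keep their normal streaming path. A small sketch of that check, assuming a hypothetical helper name `needs_block_as_stream` (the pattern itself is copied from the hunk):

```python
import re

# Same pattern as in the hunk: "o1" alone, or "o1-YYYY-MM-DD".
O1_BLOCKING_PATTERN = re.compile(r"^o1(-\d{4}-\d{2}-\d{2})?$")


def needs_block_as_stream(model: str, stream: bool) -> bool:
    """True when a streaming request has to be served from a blocking call."""
    return stream and O1_BLOCKING_PATTERN.match(model) is not None
```

When this returns true, the hunk switches `stream` off for the API call and later wraps the blocking result in a one-chunk generator.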

View File

@ -377,10 +377,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
for tool in tools:
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
if prompt_messages[-1].role.value == "tool":
data["tools"] = None
else:
data["tools"] = formatted_tools
data["tools"] = formatted_tools
if stop:
data["stop"] = stop

View File

@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 200000

View File

@ -29,9 +29,6 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel):
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
# {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
if "response_format" in model_parameters:
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
def validate_credentials(self, model: str, credentials: dict) -> None:

View File

@ -21,7 +21,7 @@ class SparkLLMClient:
domain = api_domain
model_api_configs = {
"spark-lite": {"version": "v1.1", "chat_domain": "general"},
"spark-lite": {"version": "v1.1", "chat_domain": "lite"},
"spark-pro": {"version": "v3.1", "chat_domain": "generalv3"},
"spark-pro-128k": {"version": "pro-128k", "chat_domain": "pro-128k"},
"spark-max": {"version": "v3.5", "chat_domain": "generalv3.5"},

View File

@ -257,8 +257,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
for index, response in enumerate(responses):
if response.status_code not in {200, HTTPStatus.OK}:
raise ServiceUnavailableError(
f"Failed to invoke model {model}, status code: {response.status_code}, "
f"message: {response.message}"
f"Failed to invoke model {model}, status code: {response.status_code}, message: {response.message}"
)
resp_finish_reason = response.output.choices[0].finish_reason

View File

@ -146,7 +146,7 @@ class TritonInferenceAILargeLanguageModel(LargeLanguageModel):
elif credentials["completion_type"] == "completion":
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
raise ValueError(f"completion_type {credentials['completion_type']} is not supported")
entity = AIModelEntity(
model=model,

View File

@ -18,72 +18,93 @@ class ModelConfig(BaseModel):
configs: dict[str, ModelConfig] = {
"Doubao-1.5-vision-pro-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=12288, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.VISION],
),
"Doubao-1.5-pro-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=12288, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
),
"Doubao-1.5-lite-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=12288, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
),
"Doubao-1.5-pro-256k": ModelConfig(
properties=ModelProperties(context_size=262144, max_tokens=12288, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
),
"Doubao-vision-pro-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.VISION],
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.VISION],
),
"Doubao-vision-lite-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.VISION],
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.VISION],
),
"Doubao-pro-4k": ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL],
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
),
"Doubao-lite-4k": ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL],
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
),
"Doubao-pro-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL],
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
),
"Doubao-lite-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL],
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
),
"Doubao-pro-256k": ModelConfig(
properties=ModelProperties(context_size=262144, max_tokens=4096, mode=LLMMode.CHAT),
features=[],
features=[ModelFeature.AGENT_THOUGHT],
),
"Doubao-pro-128k": ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL],
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
),
"Doubao-lite-128k": ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), features=[]
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
),
"Skylark2-pro-4k": ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), features=[]
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
),
"Llama3-8B": ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[]
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
),
"Llama3-70B": ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[]
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
),
"Moonshot-v1-8k": ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL],
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
),
"Moonshot-v1-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL],
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
),
"Moonshot-v1-128k": ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL],
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
),
"GLM3-130B": ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL],
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
),
"GLM3-130B-Fin": ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL],
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
),
"Mistral-7B": ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT), features=[]
properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
),
}

View File

@ -118,6 +118,30 @@ model_credential_schema:
type: select
required: true
options:
- label:
en_US: Doubao-1.5-vision-pro-32k
value: Doubao-1.5-vision-pro-32k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-1.5-pro-32k
value: Doubao-1.5-pro-32k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-1.5-lite-32k
value: Doubao-1.5-lite-32k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-1.5-pro-256k
value: Doubao-1.5-pro-256k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-vision-pro-32k
value: Doubao-vision-pro-32k

View File

@ -41,15 +41,15 @@ class BaiduAccessToken:
resp = response.json()
if "error" in resp:
if resp["error"] == "invalid_client":
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
raise InvalidAPIKeyError(f"Invalid API key or secret key: {resp['error_description']}")
elif resp["error"] == "unknown_error":
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
raise InternalServerError(f"Internal server error: {resp['error_description']}")
elif resp["error"] == "invalid_request":
raise BadRequestError(f'Bad request: {resp["error_description"]}')
raise BadRequestError(f"Bad request: {resp['error_description']}")
elif resp["error"] == "rate_limit_exceeded":
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
raise RateLimitReachedError(f"Rate limit reached: {resp['error_description']}")
else:
raise Exception(f'Unknown error: {resp["error_description"]}')
raise Exception(f"Unknown error: {resp['error_description']}")
return resp["access_token"]

View File

@ -406,7 +406,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
elif credentials["completion_type"] == "completion":
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
raise ValueError(f"completion_type {credentials['completion_type']} is not supported")
else:
extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials["server_url"],
@ -472,7 +472,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
api_key = credentials.get("api_key") or "abc"
client = OpenAI(
base_url=f'{credentials["server_url"]}/v1',
base_url=f"{credentials['server_url']}/v1",
api_key=api_key,
max_retries=int(credentials.get("max_retries") or DEFAULT_MAX_RETRIES),
timeout=int(credentials.get("invoke_timeout") or DEFAULT_INVOKE_TIMEOUT),

View File

@ -87,6 +87,6 @@ class CommonValidator:
if value.lower() not in {"true", "false"}:
raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")
value = True if value.lower() == "true" else False
value = value.lower() == "true"
return value

View File

@ -6,6 +6,7 @@ from pydantic import BaseModel, ValidationInfo, field_validator
class TracingProviderEnum(Enum):
LANGFUSE = "langfuse"
LANGSMITH = "langsmith"
OPIK = "opik"
class BaseTracingConfig(BaseModel):
@ -56,5 +57,36 @@ class LangSmithConfig(BaseTracingConfig):
return v
class OpikConfig(BaseTracingConfig):
"""
Model class for Opik tracing config.
"""
api_key: str | None = None
project: str | None = None
workspace: str | None = None
url: str = "https://www.comet.com/opik/api/"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "Default Project"
return v
@field_validator("url")
@classmethod
def url_validator(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://www.comet.com/opik/api/"
if not v.startswith(("https://", "http://")):
raise ValueError("url must start with https:// or http://")
if not v.endswith("/api/"):
raise ValueError("url should end with /api/")
return v
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
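
The `url_validator` rules in `OpikConfig` can be exercised without pydantic. The sketch below restates them as a plain function; `normalize_opik_url` and `DEFAULT_OPIK_URL` are hypothetical names introduced here for illustration, while the default value and the checks are taken from the hunk.

```python
from typing import Optional

DEFAULT_OPIK_URL = "https://www.comet.com/opik/api/"


def normalize_opik_url(url: Optional[str]) -> str:
    """Apply the same defaulting and checks as the validator in the hunk."""
    if url is None or url == "":
        url = DEFAULT_OPIK_URL
    if not url.startswith(("https://", "http://")):
        raise ValueError("url must start with https:// or http://")
    if not url.endswith("/api/"):
        raise ValueError("url should end with /api/")
    return url
```

Note that the trailing slash is required: a value ending in `/api` is rejected.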

View File

@ -0,0 +1,469 @@
import json
import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import Optional, cast
from opik import Opik, Trace
from opik.id_helpers import uuid4_to_uuid7
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__)
def wrap_dict(key_name, data):
"""Make sure that the input data is a dict"""
if not isinstance(data, dict):
return {key_name: data}
return data
def wrap_metadata(metadata, **kwargs):
"""Add common metadata to all Traces and Spans"""
metadata["created_from"] = "dify"
metadata.update(kwargs)
return metadata
def prepare_opik_uuid(user_datetime: Optional[datetime], user_uuid: Optional[str]):
"""Opik needs UUIDv7 identifiers, while Dify uses UUIDv4 for most
messages and objects. The type hints of BaseTraceInfo indicate that
start_time and message_id can be None, in which case there is nothing
to map to a UUIDv7. Since such an object cannot be identified
uniquely, generate a new random UUIDv7 instead.
"""
if user_datetime is None:
user_datetime = datetime.now()
if user_uuid is None:
user_uuid = str(uuid.uuid4())
return uuid4_to_uuid7(user_datetime, user_uuid)
class OpikDataTrace(BaseTraceInstance):
def __init__(
self,
opik_config: OpikConfig,
):
super().__init__(opik_config)
self.opik_client = Opik(
project_name=opik_config.project,
workspace=opik_config.workspace,
host=opik_config.url,
api_key=opik_config.api_key,
)
self.project = opik_config.project
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
def trace(self, trace_info: BaseTraceInfo):
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
dify_trace_id = trace_info.workflow_run_id
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
workflow_metadata = wrap_metadata(
trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
)
root_span_id = None
if trace_info.message_id:
dify_trace_id = trace_info.message_id
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
trace_data = {
"id": opik_trace_id,
"name": TraceTaskName.MESSAGE_TRACE.value,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": workflow_metadata,
"input": wrap_dict("input", trace_info.workflow_run_inputs),
"output": wrap_dict("output", trace_info.workflow_run_outputs),
"tags": ["message", "workflow"],
"project_name": self.project,
}
self.add_trace(trace_data)
root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
span_data = {
"id": root_span_id,
"parent_span_id": None,
"trace_id": opik_trace_id,
"name": TraceTaskName.WORKFLOW_TRACE.value,
"input": wrap_dict("input", trace_info.workflow_run_inputs),
"output": wrap_dict("output", trace_info.workflow_run_outputs),
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": workflow_metadata,
"tags": ["workflow"],
"project_name": self.project,
}
self.add_span(span_data)
else:
trace_data = {
"id": opik_trace_id,
"name": TraceTaskName.MESSAGE_TRACE.value,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": workflow_metadata,
"input": wrap_dict("input", trace_info.workflow_run_inputs),
"output": wrap_dict("output", trace_info.workflow_run_outputs),
"tags": ["workflow"],
"project_name": self.project,
}
self.add_trace(trace_data)
# fetch all node executions for this workflow_run_id
workflow_nodes_execution_id_records = (
db.session.query(WorkflowNodeExecution.id)
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
.all()
)
for node_execution_id_record in workflow_nodes_execution_id_records:
node_execution = (
db.session.query(
WorkflowNodeExecution.id,
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
)
if not node_execution:
continue
node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id
app_id = node_execution.app_id
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == "llm":
inputs = (
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
)
else:
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = (
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
)
metadata = execution_metadata.copy()
metadata.update(
{
"workflow_run_id": trace_info.workflow_run_id,
"node_execution_id": node_execution_id,
"tenant_id": tenant_id,
"app_id": app_id,
"app_name": node_name,
"node_type": node_type,
"status": status,
}
)
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
provider = None
model = None
total_tokens = 0
completion_tokens = 0
prompt_tokens = 0
if process_data and process_data.get("model_mode") == "chat":
run_type = "llm"
provider = process_data.get("model_provider", None)
model = process_data.get("model_name", "")
metadata.update(
{
"ls_provider": provider,
"ls_model_name": model,
}
)
try:
if outputs.get("usage"):
total_tokens = outputs["usage"].get("total_tokens", 0)
prompt_tokens = outputs["usage"].get("prompt_tokens", 0)
completion_tokens = outputs["usage"].get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
else:
run_type = "tool"
parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
if not total_tokens:
total_tokens = execution_metadata.get("total_tokens", 0)
span_data = {
"trace_id": opik_trace_id,
"id": prepare_opik_uuid(created_at, node_execution_id),
"parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id),
"name": node_type,
"type": run_type,
"start_time": created_at,
"end_time": finished_at,
"metadata": wrap_metadata(metadata),
"input": wrap_dict("input", inputs),
"output": wrap_dict("output", outputs),
"tags": ["node_execution"],
"project_name": self.project,
"usage": {
"total_tokens": total_tokens,
"completion_tokens": completion_tokens,
"prompt_tokens": prompt_tokens,
},
"model": model,
"provider": provider,
}
self.add_span(span_data)
def message_trace(self, trace_info: MessageTraceInfo):
# get message file data
file_list = cast(list[str], trace_info.file_list) or []
message_file_data: Optional[MessageFile] = trace_info.message_file_data
if message_file_data is not None:
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
file_list.append(file_url)
message_data = trace_info.message_data
if message_data is None:
return
metadata = trace_info.metadata
message_id = trace_info.message_id
user_id = message_data.from_account_id
metadata["user_id"] = user_id
metadata["file_list"] = file_list
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id
metadata["end_user_id"] = end_user_id
trace_data = {
"id": prepare_opik_uuid(trace_info.start_time, message_id),
"name": TraceTaskName.MESSAGE_TRACE.value,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(metadata),
"input": trace_info.inputs,
"output": message_data.answer,
"tags": ["message", str(trace_info.conversation_mode)],
"project_name": self.project,
}
trace = self.add_trace(trace_data)
span_data = {
"trace_id": trace.id,
"name": "llm",
"type": "llm",
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(metadata),
"input": {"input": trace_info.inputs},
"output": {"output": message_data.answer},
"tags": ["llm", str(trace_info.conversation_mode)],
"usage": {
"completion_tokens": trace_info.answer_tokens,
"prompt_tokens": trace_info.message_tokens,
"total_tokens": trace_info.total_tokens,
},
"project_name": self.project,
}
self.add_span(span_data)
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
"name": TraceTaskName.MODERATION_TRACE.value,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
"metadata": wrap_metadata(trace_info.metadata),
"input": wrap_dict("input", trace_info.inputs),
"output": {
"action": trace_info.action,
"flagged": trace_info.flagged,
"preset_response": trace_info.preset_response,
"inputs": trace_info.inputs,
},
"tags": ["moderation"],
}
self.add_span(span_data)
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_data = trace_info.message_data
if message_data is None:
return
start_time = trace_info.start_time or message_data.created_at
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
"name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or message_data.updated_at,
"metadata": wrap_metadata(trace_info.metadata),
"input": wrap_dict("input", trace_info.inputs),
"output": wrap_dict("output", trace_info.suggested_question),
"tags": ["suggested_question"],
}
self.add_span(span_data)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
"name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
"metadata": wrap_metadata(trace_info.metadata),
"input": wrap_dict("input", trace_info.inputs),
"output": {"documents": trace_info.documents},
"tags": ["dataset_retrieval"],
}
self.add_span(span_data)
def tool_trace(self, trace_info: ToolTraceInfo):
span_data = {
"trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
"name": trace_info.tool_name,
"type": "tool",
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(trace_info.metadata),
"input": wrap_dict("input", trace_info.tool_inputs),
"output": wrap_dict("output", trace_info.tool_outputs),
"tags": ["tool", trace_info.tool_name],
}
self.add_span(span_data)
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
trace_data = {
"id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(trace_info.metadata),
"input": trace_info.inputs,
"output": trace_info.outputs,
"tags": ["generate_name"],
"project_name": self.project,
}
trace = self.add_trace(trace_data)
span_data = {
"trace_id": trace.id,
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(trace_info.metadata),
"input": wrap_dict("input", trace_info.inputs),
"output": wrap_dict("output", trace_info.outputs),
"tags": ["generate_name"],
}
self.add_span(span_data)
def add_trace(self, opik_trace_data: dict) -> Trace:
try:
trace = self.opik_client.trace(**opik_trace_data)
logger.debug("Opik Trace created successfully")
return trace
except Exception as e:
raise ValueError(f"Opik Failed to create trace: {str(e)}")
def add_span(self, opik_span_data: dict):
try:
self.opik_client.span(**opik_span_data)
logger.debug("Opik Span created successfully")
except Exception as e:
raise ValueError(f"Opik Failed to create span: {str(e)}")
def api_check(self):
try:
self.opik_client.auth_check()
return True
except Exception as e:
logger.info(f"Opik API check failed: {str(e)}", exc_info=True)
raise ValueError(f"Opik API check failed: {str(e)}")
def get_project_url(self):
try:
return self.opik_client.get_project_url(project_name=self.project)
except Exception as e:
logger.info(f"Opik get project url failed: {str(e)}", exc_info=True)
raise ValueError(f"Opik get project url failed: {str(e)}")
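
The trace-then-span flow used by `generate_name_trace` can be sketched in isolation. This is a minimal illustration only: `FakeClient`, `FakeTrace`, and `record_named_trace` are hypothetical stand-ins invented for this example and are not part of the Opik SDK or of the diff above. The sketch also chains the original exception with `from e`, which is an assumption about how one might keep the root cause visible, not what the diff does.

```python
import uuid
from dataclasses import dataclass, field
from typing import Any


@dataclass
class FakeTrace:
    id: str
    name: str


@dataclass
class FakeClient:
    """Hypothetical stand-in for a tracing client; not the real Opik SDK."""

    traces: list = field(default_factory=list)
    spans: list = field(default_factory=list)

    def trace(self, **data: Any) -> FakeTrace:
        # A missing "name" raises KeyError, which add_trace wraps below.
        created = FakeTrace(id=data.get("id") or str(uuid.uuid4()), name=data["name"])
        self.traces.append(created)
        return created

    def span(self, **data: Any) -> None:
        self.spans.append(data)


def add_trace(client: FakeClient, data: dict) -> FakeTrace:
    try:
        return client.trace(**data)
    except Exception as e:
        # "from e" keeps the original exception available as __cause__.
        raise ValueError(f"Failed to create trace: {e}") from e


def record_named_trace(client: FakeClient, name: str, inputs: dict, outputs: dict) -> FakeTrace:
    # Create the trace first, then attach a span that points back to it.
    trace = add_trace(client, {"name": name, "input": inputs, "output": outputs})
    client.span(trace_id=trace.id, name=name, input=inputs, output=outputs)
    return trace
```

The span is linked to its parent only through `trace_id`, which is why the trace has to be created (and its ID known) before the span is added.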

@ -17,6 +17,7 @@ from core.ops.entities.config_entity import (
OPS_FILE_PATH,
LangfuseConfig,
LangSmithConfig,
OpikConfig,
TracingProviderEnum,
)
from core.ops.entities.trace_entity import (
@ -32,6 +33,7 @@ from core.ops.entities.trace_entity import (
)
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from core.ops.opik_trace.opik_trace import OpikDataTrace
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
@ -52,6 +54,12 @@ provider_config_map: dict[str, dict[str, Any]] = {
"other_keys": ["project", "endpoint"],
"trace_instance": LangSmithDataTrace,
},
TracingProviderEnum.OPIK.value: {
"config_class": OpikConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "url", "workspace"],
"trace_instance": OpikDataTrace,
},
}
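
As a rough illustration of why each provider entry separates `secret_keys` from `other_keys`, the sketch below masks the secret fields of a config while passing the rest through. Everything here is hypothetical: `PROVIDER_KEYS`, `mask_secret`, and `obfuscate_config` are invented for the example, and the `"opik"` key is an assumed value for `TracingProviderEnum.OPIK.value`.

```python
from typing import Any

# Hypothetical, trimmed-down registry mirroring the shape of provider_config_map.
PROVIDER_KEYS: dict[str, dict[str, list[str]]] = {
    "opik": {
        "secret_keys": ["api_key"],
        "other_keys": ["project", "url", "workspace"],
    },
}


def mask_secret(value: str, visible: int = 4) -> str:
    """Replace all but the last `visible` characters with asterisks."""
    if len(value) <= visible:
        return "*" * len(value)
    return "*" * (len(value) - visible) + value[-visible:]


def obfuscate_config(provider: str, config: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of the config that is safe to display."""
    entry = PROVIDER_KEYS[provider]
    result: dict[str, Any] = {}
    for key in entry["secret_keys"]:
        result[key] = mask_secret(str(config.get(key, "")))
    for key in entry["other_keys"]:
        result[key] = config.get(key, "")
    return result
```

Keeping the key lists in the registry means a new provider such as Opik only needs a new map entry, with no provider-specific masking code.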

@ -22,7 +22,12 @@ from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.helper.position_helper import is_filtered
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
from core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
CredentialFormSchema,
FormType,
ProviderEntity,
)
from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider
from extensions.ext_database import db
@ -835,11 +840,18 @@ class ProviderManager:
:return:
"""
# Get provider model credential secret variables
model_credential_secret_variables = self._extract_secret_variables(
provider_entity.model_credential_schema.credential_form_schemas
if provider_entity.model_credential_schema
else []
)
if ConfigurateMethod.PREDEFINED_MODEL in provider_entity.configurate_methods:
model_credential_secret_variables = self._extract_secret_variables(
provider_entity.provider_credential_schema.credential_form_schemas
if provider_entity.provider_credential_schema
else []
)
else:
model_credential_secret_variables = self._extract_secret_variables(
provider_entity.model_credential_schema.credential_form_schemas
if provider_entity.model_credential_schema
else []
)
model_settings: list[ModelSettings] = []
if not provider_model_settings:

@ -0,0 +1,104 @@
import json
import logging
from typing import Any, Optional
from flask import current_app
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchConfig,
ElasticSearchVector,
ElasticSearchVectorFactory,
)
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class ElasticSearchJaVector(ElasticSearchVector):
def create_collection(
self,
embeddings: list[list[float]],
metadatas: Optional[list[dict[Any, Any]]] = None,
index_params: Optional[dict] = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return
if not self._client.indices.exists(index=self._collection_name):
dim = len(embeddings[0])
settings = {
"analysis": {
"analyzer": {
"ja_analyzer": {
"type": "custom",
"char_filter": [
"icu_normalizer",
"kuromoji_iteration_mark",
],
"tokenizer": "kuromoji_tokenizer",
"filter": [
"kuromoji_baseform",
"kuromoji_part_of_speech",
"ja_stop",
"kuromoji_number",
"kuromoji_stemmer",
],
}
}
}
}
mappings = {
"properties": {
Field.CONTENT_KEY.value: {
"type": "text",
"analyzer": "ja_analyzer",
"search_analyzer": "ja_analyzer",
},
Field.VECTOR.value: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"index": True,
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
},
},
}
}
self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
config = current_app.config
return ElasticSearchJaVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get("ELASTICSEARCH_HOST", "localhost"),
port=config.get("ELASTICSEARCH_PORT", 9200),
username=config.get("ELASTICSEARCH_USERNAME", ""),
password=config.get("ELASTICSEARCH_PASSWORD", ""),
),
attributes=[],
)

@ -6,6 +6,8 @@ class Field(Enum):
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = "vector"
# Sparse vector field, used to support full-text search
SPARSE_VECTOR = "sparse_vector"
TEXT_KEY = "text"
PRIMARY_KEY = "id"
DOC_ID = "metadata.doc_id"

@ -258,7 +258,7 @@ class LindormVectorStore(BaseVector):
hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
ivfpq_m = kwargs.pop("ivfpq_m", dimension)
nlist = kwargs.pop("nlist", 1000)
centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False)
centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", nlist >= 5000)
centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
@ -305,7 +305,7 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dic
if method_name == "ivfpq":
ivfpq_m = kwargs["ivfpq_m"]
nlist = kwargs["nlist"]
centroids_use_hnsw = True if nlist > 10000 else False
centroids_use_hnsw = nlist > 10000
centroids_hnsw_m = 24
centroids_hnsw_ef_construct = 500
centroids_hnsw_ef_search = 100

@ -2,6 +2,7 @@ import json
import logging
from typing import Any, Optional
from packaging import version
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException # type: ignore
from pymilvus.milvus_client import IndexParams # type: ignore
@ -20,16 +21,25 @@ logger = logging.getLogger(__name__)
class MilvusConfig(BaseModel):
uri: str
token: Optional[str] = None
user: str
password: str
batch_size: int = 100
database: str = "default"
"""
Configuration class for Milvus connection.
"""
uri: str # Milvus server URI
token: Optional[str] = None # Optional token for authentication
user: str # Username for authentication
password: str # Password for authentication
batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""
Validate the configuration values.
Raises ValueError if required fields are missing.
"""
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get("user"):
@ -39,6 +49,9 @@ class MilvusConfig(BaseModel):
return values
def to_milvus_params(self):
"""
Convert the configuration to a dictionary of Milvus connection parameters.
"""
return {
"uri": self.uri,
"token": self.token,
@ -49,26 +62,57 @@ class MilvusConfig(BaseModel):
class MilvusVector(BaseVector):
"""
Milvus vector storage implementation.
"""
def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = "Session"
self._fields: list[str] = []
self._consistency_level = "Session" # Consistency level for Milvus operations
self._fields: list[str] = [] # List of fields in the collection
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
def _check_hybrid_search_support(self) -> bool:
"""
Check if the current Milvus version supports hybrid search.
Returns True if the version is >= 2.5.0, otherwise False.
"""
if not self._client_config.enable_hybrid_search:
return False
try:
milvus_version = self._client.get_server_version()
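# Compare parsed versions rather than base_version strings: as strings, "2.10.0" sorts before "2.5.0".
return version.parse(version.parse(milvus_version).base_version) >= version.parse("2.5.0")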
except Exception as e:
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
return False
def get_type(self) -> str:
"""
Get the type of vector storage (Milvus).
"""
return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""
Create a collection and add texts with embeddings.
"""
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""
Add texts and their embeddings to the collection.
"""
insert_dict_list = []
for i in range(len(documents)):
insert_dict = {
# The sparse_vector field does not need to be inserted explicitly: the text_bm25_emb
# function converts the raw text into a sparse vector automatically.
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
@ -76,12 +120,11 @@ class MilvusVector(BaseVector):
insert_dict_list.append(insert_dict)
# Total insert count
total_count = len(insert_dict_list)
pks: list[str] = []
for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i : i + 1000]
# Insert into the collection.
batch_insert_list = insert_dict_list[i : i + 1000]
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
@ -91,6 +134,9 @@ class MilvusVector(BaseVector):
return pks
def get_ids_by_metadata_field(self, key: str, value: str):
"""
Get document IDs by metadata field key and value.
"""
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
)
@ -100,12 +146,18 @@ class MilvusVector(BaseVector):
return None
def delete_by_metadata_field(self, key: str, value: str):
"""
Delete documents by metadata field key and value.
"""
if self._client.has_collection(self._collection_name):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, ids: list[str]) -> None:
"""
Delete documents by their IDs.
"""
if self._client.has_collection(self._collection_name):
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
@ -115,10 +167,16 @@ class MilvusVector(BaseVector):
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete(self) -> None:
"""
Delete the entire collection.
"""
if self._client.has_collection(self._collection_name):
self._client.drop_collection(self._collection_name, None)
def text_exists(self, id: str) -> bool:
"""
Check if a text with the given ID exists in the collection.
"""
if not self._client.has_collection(self._collection_name):
return False
@ -128,32 +186,80 @@ class MilvusVector(BaseVector):
return len(result) > 0
def field_exists(self, field: str) -> bool:
"""
Check if a field exists in the collection.
"""
return field in self._fields
def _process_search_results(
self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0
) -> list[Document]:
"""
Common method to process search results
:param results: Search results
:param output_fields: Fields to be output
:param score_threshold: Score threshold for filtering
:return: List of documents
"""
docs = []
for result in results[0]:
metadata = result["entity"].get(output_fields[1], {})
metadata["score"] = result["distance"]
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata)
docs.append(doc)
return docs
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters.
"""
Search for documents by vector similarity.
"""
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
anns_field=Field.VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
docs = []
for result in results[0]:
metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
return []
"""
Search for documents by full-text search (if hybrid search is enabled).
"""
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
logger.warning("Full-text search is unavailable: hybrid search is disabled or unsupported (requires Milvus >= 2.5.0)")
return []
results = self._client.search(
collection_name=self._collection_name,
data=[query],
anns_field=Field.SPARSE_VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
"""
Create a new collection in Milvus with the specified schema and index parameters.
"""
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
@ -161,7 +267,7 @@ class MilvusVector(BaseVector):
return
# Create the collection only if it does not already exist
if not self._client.has_collection(self._collection_name):
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
# Determine embedding dim
@ -170,16 +276,36 @@ class MilvusVector(BaseVector):
if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
# Create the text field. enable_analyzer is set to True when hybrid search is enabled so that Milvus
# can convert the text to a sparse vector automatically; see https://milvus.io/docs/full-text-search.md
fields.append(
FieldSchema(
Field.CONTENT_KEY.value,
DataType.VARCHAR,
max_length=65_535,
enable_analyzer=self._hybrid_search_enabled,
)
)
# Create the primary key field
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
# Add the sparse vector field used for full-text search
if self._hybrid_search_enabled:
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
# Create the schema for the collection
schema = CollectionSchema(fields)
# Add a BM25 function that converts the text field into a sparse vector
if self._hybrid_search_enabled:
bm25_function = Function(
name="text_bm25_emb",
input_field_names=[Field.CONTENT_KEY.value],
output_field_names=[Field.SPARSE_VECTOR.value],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)
for x in schema.fields:
self._fields.append(x.name)
# Since primary field is auto-id, no need to track it
@ -189,10 +315,15 @@ class MilvusVector(BaseVector):
index_params_obj = IndexParams()
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
index_params_obj.add_index(
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
)
# Create the collection
collection_name = self._collection_name
self._client.create_collection(
collection_name=collection_name,
collection_name=self._collection_name,
schema=schema,
index_params=index_params_obj,
consistency_level=self._consistency_level,
@ -200,12 +331,22 @@ class MilvusVector(BaseVector):
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient:
"""
Initialize and return a Milvus client.
"""
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
return client
class MilvusVectorFactory(AbstractVectorFactory):
"""
Factory class for creating MilvusVector instances.
"""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
"""
Initialize a MilvusVector instance for the given dataset.
"""
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
@ -222,5 +363,6 @@ class MilvusVectorFactory(AbstractVectorFactory):
user=dify_config.MILVUS_USER or "",
password=dify_config.MILVUS_PASSWORD or "",
database=dify_config.MILVUS_DATABASE or "",
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
),
)
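
A note on the `>= 2.5.0` gate in `_check_hybrid_search_support`: version strings compare lexicographically, so a numeric comparison is the safer choice for this kind of check. The sketch below uses only the standard library to show the difference. `parse_release` and `supports_full_text_search` are hypothetical helpers written for this example, and the accepted version formats are an assumption. `packaging.version.Version` objects also compare numerically and would be the more natural choice where `packaging` is already imported.

```python
import re


def parse_release(version_string: str, width: int = 3) -> tuple[int, ...]:
    """Extract the numeric release segment, e.g. 'v2.10.1-dev' -> (2, 10, 1)."""
    match = re.match(r"v?(\d+(?:\.\d+)*)", version_string.strip())
    if not match:
        raise ValueError(f"Unrecognized version string: {version_string!r}")
    parts = tuple(int(part) for part in match.group(1).split("."))
    # Pad short versions such as "2.5" so they compare equal to "2.5.0".
    return parts + (0,) * max(0, width - len(parts))


def supports_full_text_search(server_version: str, minimum: tuple[int, ...] = (2, 5, 0)) -> bool:
    """Return True when the numeric release is at least `minimum`."""
    return parse_release(server_version) >= minimum
```

With plain strings, `"2.10.0" >= "2.5.0"` is `False` because `"1"` sorts before `"5"`. The tuple comparison treats 10 as greater than 5 and so accepts 2.10.0.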

@ -409,27 +409,27 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
if not tidb_auth_binding:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
else:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID or "",
@ -451,7 +451,6 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
db.session.add(new_tidb_auth_binding)
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
else:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
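
The reordering in this hunk moves the idle-binding lookup inside the lock and re-checks the tenant's binding after acquiring it. That pattern can be sketched as follows. It is an in-memory illustration under stated assumptions: `BindingPool` is a hypothetical stand-in for the `TidbAuthBinding` table, and a `threading.Lock` replaces the Redis lock.

```python
import threading


class BindingPool:
    """Hypothetical in-memory stand-in for the tenant-to-cluster bindings."""

    def __init__(self, idle_accounts: list[str]):
        self._lock = threading.Lock()
        self._idle = list(idle_accounts)
        self._by_tenant: dict[str, str] = {}
        self.clusters_created = 0

    def get_or_assign(self, tenant_id: str) -> str:
        # Fast path: the tenant already has a binding, no lock needed.
        account = self._by_tenant.get(tenant_id)
        if account is not None:
            return account
        with self._lock:
            # Re-check under the lock: another caller may have bound this
            # tenant while we were waiting.
            account = self._by_tenant.get(tenant_id)
            if account is not None:
                return account
            if self._idle:
                # Claim an idle binding while still holding the lock.
                account = self._idle.pop()
            else:
                # Otherwise provision a new cluster (simulated here).
                self.clusters_created += 1
                account = f"new-cluster-{self.clusters_created}"
            self._by_tenant[tenant_id] = account
            return account
```

Claiming the idle binding inside the lock is the important part: if two requests for different tenants both read the same idle row before either commits, they would otherwise end up sharing one binding.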

@ -90,6 +90,12 @@ class Vector:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
return ElasticSearchVectorFactory
case VectorType.ELASTICSEARCH_JA:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
ElasticSearchJaVectorFactory,
)
return ElasticSearchJaVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory

@ -16,6 +16,7 @@ class VectorType(StrEnum):
TENCENT = "tencent"
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"
ELASTICSEARCH_JA = "elasticsearch-ja"
LINDORM = "lindorm"
COUCHBASE = "couchbase"
BAIDU = "baidu"

@ -31,7 +31,7 @@ class FirecrawlApp:
"markdown": data.get("markdown"),
}
else:
raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}')
raise Exception(f"Failed to scrape URL. Error: {response_data['error']}")
elif response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")

@ -358,8 +358,7 @@ class NotionExtractor(BaseExtractor):
if not data_source_binding:
raise Exception(
f"No notion data source binding found for tenant {tenant_id} "
f"and notion workspace {notion_workspace_id}"
f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
)
return cast(str, data_source_binding.access_token)

@ -23,7 +23,6 @@ class PdfExtractor(BaseExtractor):
self._file_cache_key = file_cache_key
def extract(self) -> list[Document]:
plaintext_file_key = ""
plaintext_file_exists = False
if self._file_cache_key:
try:
@ -39,8 +38,8 @@ class PdfExtractor(BaseExtractor):
text = "\n\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode("utf-8"))
if not plaintext_file_exists and self._file_cache_key:
storage.save(self._file_cache_key, text.encode("utf-8"))
return documents
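
The caching behavior this hunk touches (save the extracted text under the cache key, using the same key that was used for the lookup) follows a cache-aside pattern. A minimal sketch, assuming a dictionary-backed store: `DictStorage` and `extract_with_cache` are hypothetical names, and the real storage extension may signal a missing key differently.

```python
from typing import Callable, Optional


class DictStorage:
    """Hypothetical stand-in for a blob storage backend."""

    def __init__(self) -> None:
        self._blobs: dict[str, bytes] = {}

    def load(self, key: str) -> bytes:
        if key not in self._blobs:
            raise FileNotFoundError(key)
        return self._blobs[key]

    def save(self, key: str, data: bytes) -> None:
        self._blobs[key] = data


def extract_with_cache(storage: DictStorage, cache_key: Optional[str], parse: Callable[[], str]) -> str:
    """Return cached text if present; otherwise parse and cache the result."""
    if cache_key:
        try:
            return storage.load(cache_key).decode("utf-8")
        except FileNotFoundError:
            pass  # Cache miss: fall through to parsing.
    text = parse()
    if cache_key:
        # Save under the same key that the lookup used.
        storage.save(cache_key, text.encode("utf-8"))
    return text
```

Reading and writing with one key is what keeps the cache effective. A write under a different or empty key would never be found by the next lookup.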

@ -112,7 +112,7 @@ class QAIndexProcessor(BaseIndexProcessor):
df = pd.read_csv(file)
text_docs = []
for index, row in df.iterrows():
data = Document(page_content=row[0], metadata={"answer": row[1]})
data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]})
text_docs.append(data)
if len(text_docs) == 0:
raise ValueError("The CSV file is empty.")

@ -127,7 +127,7 @@ class AIPPTGenerateToolAdapter:
response = response.json()
if response.get("code") != 0:
raise Exception(f'Failed to create task: {response.get("msg")}')
raise Exception(f"Failed to create task: {response.get('msg')}")
return response.get("data", {}).get("id")
@ -222,7 +222,7 @@ class AIPPTGenerateToolAdapter:
elif model == "wenxin":
response = response.json()
if response.get("code") != 0:
raise Exception(f'Failed to generate content: {response.get("msg")}')
raise Exception(f"Failed to generate content: {response.get('msg')}")
return response.get("data", "")
@ -254,7 +254,7 @@ class AIPPTGenerateToolAdapter:
response = response.json()
if response.get("code") != 0:
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
raise Exception(f"Failed to generate ppt: {response.get('msg')}")
id = response.get("data", {}).get("id")
cover_url = response.get("data", {}).get("cover_url")
@ -270,7 +270,7 @@ class AIPPTGenerateToolAdapter:
response = response.json()
if response.get("code") != 0:
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
raise Exception(f"Failed to generate ppt: {response.get('msg')}")
export_code = response.get("data")
if not export_code:
@ -290,7 +290,7 @@ class AIPPTGenerateToolAdapter:
response = response.json()
if response.get("code") != 0:
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
raise Exception(f"Failed to generate ppt: {response.get('msg')}")
if response.get("msg") == "导出中":
current_iteration += 1
@ -343,7 +343,7 @@ class AIPPTGenerateToolAdapter:
raise Exception(f"Failed to connect to aippt: {response.text}")
response = response.json()
if response.get("code") != 0:
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
raise Exception(f"Failed to connect to aippt: {response.get('msg')}")
token = response.get("data", {}).get("token")
expire = response.get("data", {}).get("time_expire")
@ -379,7 +379,7 @@ class AIPPTGenerateToolAdapter:
if cls._style_cache[key]["expire"] < now:
del cls._style_cache[key]
key = f'{credentials["aippt_access_key"]}#@#{user_id}'
key = f"{credentials['aippt_access_key']}#@#{user_id}"
if key in cls._style_cache:
return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"]
@ -396,11 +396,11 @@ class AIPPTGenerateToolAdapter:
response = response.json()
if response.get("code") != 0:
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
raise Exception(f"Failed to connect to aippt: {response.get('msg')}")
colors = [
{
"id": f'id-{item.get("id")}',
"id": f"id-{item.get('id')}",
"name": item.get("name"),
"en_name": item.get("en_name", item.get("name")),
}
@ -408,7 +408,7 @@ class AIPPTGenerateToolAdapter:
]
styles = [
{
"id": f'id-{item.get("id")}',
"id": f"id-{item.get('id')}",
"name": item.get("title"),
}
for item in response.get("data", {}).get("suit_style") or []
@ -454,7 +454,7 @@ class AIPPTGenerateToolAdapter:
response = response.json()
if response.get("code") != 0:
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
raise Exception(f"Failed to connect to aippt: {response.get('msg')}")
if len(response.get("data", {}).get("list") or []) > 0:
return response.get("data", {}).get("list")[0].get("id")

@ -14,14 +14,38 @@ class BedrockRetrieveTool(BuiltinTool):
topk: int = None
def _bedrock_retrieve(
self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
self,
query_input: str,
knowledge_base_id: str,
num_results: int,
search_type: str,
rerank_model_id: str,
metadata_filter: Optional[dict] = None,
):
try:
retrieval_query = {"text": query_input}
retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}
if search_type not in ["HYBRID", "SEMANTIC"]:
raise ValueError("search_type should be HYBRID or SEMANTIC")
# Add metadata filter to retrieval configuration if present
retrieval_configuration = {
"vectorSearchConfiguration": {"numberOfResults": num_results, "overrideSearchType": search_type}
}
if rerank_model_id != "default":
model_for_rerank_arn = f"arn:aws:bedrock:us-west-2::foundation-model/{rerank_model_id}"
rerankingConfiguration = {
"bedrockRerankingConfiguration": {
"numberOfRerankedResults": num_results,
"modelConfiguration": {"modelArn": model_for_rerank_arn},
},
"type": "BEDROCK_RERANKING_MODEL",
}
retrieval_configuration["vectorSearchConfiguration"]["rerankingConfiguration"] = rerankingConfiguration
retrieval_configuration["vectorSearchConfiguration"]["numberOfResults"] = num_results * 5
# If a metadata filter is provided, add it to the retrieval configuration
if metadata_filter:
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter
@ -77,15 +101,20 @@ class BedrockRetrieveTool(BuiltinTool):
if not query:
return self.create_text_message("Please input query")
# Get metadata filter conditions (if they exist)
# Get the metadata filter conditions (if present)
metadata_filter_str = tool_parameters.get("metadata_filter")
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None
search_type = tool_parameters.get("search_type")
rerank_model_id = tool_parameters.get("rerank_model_id")
line = 4
retrieved_docs = self._bedrock_retrieve(
query_input=query,
knowledge_base_id=self.knowledge_base_id,
num_results=self.topk,
search_type=search_type,
rerank_model_id=rerank_model_id,
metadata_filter=metadata_filter,
)
@ -109,7 +138,7 @@ class BedrockRetrieveTool(BuiltinTool):
if not parameters.get("query"):
raise ValueError("query is required")
# Optional: Validate if metadata filter is a valid JSON string (if provided)
# Optional: validate that the metadata filter is a valid JSON string (if provided)
metadata_filter_str = parameters.get("metadata_filter")
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
raise ValueError("metadata_filter must be a valid JSON object")
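
The retrieval configuration assembled in `_bedrock_retrieve` can be expressed as a pure function, which makes the search-type validation and the reranking branch easy to test without calling AWS. This is a sketch under assumptions: `build_retrieval_configuration` is a hypothetical helper, and the `region` parameter is added for illustration (the diff hard-codes `us-west-2` in the model ARN).

```python
from typing import Any, Optional

ALLOWED_SEARCH_TYPES = {"HYBRID", "SEMANTIC"}


def build_retrieval_configuration(
    num_results: int,
    search_type: str,
    rerank_model_id: str = "default",
    metadata_filter: Optional[dict] = None,
    region: str = "us-west-2",  # Assumption: parameterized here, hard-coded in the diff.
) -> dict[str, Any]:
    if search_type not in ALLOWED_SEARCH_TYPES:
        raise ValueError("search_type should be HYBRID or SEMANTIC")

    vector_config: dict[str, Any] = {
        "numberOfResults": num_results,
        "overrideSearchType": search_type,
    }

    if rerank_model_id != "default":
        model_arn = f"arn:aws:bedrock:{region}::foundation-model/{rerank_model_id}"
        vector_config["rerankingConfiguration"] = {
            "type": "BEDROCK_RERANKING_MODEL",
            "bedrockRerankingConfiguration": {
                "numberOfRerankedResults": num_results,
                "modelConfiguration": {"modelArn": model_arn},
            },
        }
        # Over-fetch candidates so the reranker has more to choose from.
        vector_config["numberOfResults"] = num_results * 5

    if metadata_filter:
        vector_config["filter"] = metadata_filter

    return {"vectorSearchConfiguration": vector_config}
```

When a rerank model is selected, the function retrieves five times as many candidates and lets the reranker cut them back to `num_results`, matching the multiplier used in the diff.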

@ -59,6 +59,57 @@ parameters:
max: 10
default: 5
- name: search_type
type: select
required: false
label:
en_US: search type
zh_Hans: 搜索类型
pt_BR: search type
human_description:
en_US: search type
zh_Hans: 搜索类型
pt_BR: search type
llm_description: search type
default: SEMANTIC
options:
- value: SEMANTIC
label:
en_US: SEMANTIC
zh_Hans: 语义搜索
- value: HYBRID
label:
en_US: HYBRID
zh_Hans: 混合搜索
form: form
- name: rerank_model_id
type: select
required: false
label:
en_US: rerank model id
zh_Hans: 重拍模型ID
pt_BR: rerank model id
human_description:
en_US: rerank model id
zh_Hans: 重拍模型ID
pt_BR: rerank model id
llm_description: rerank model id
options:
- value: default
label:
en_US: default
zh_Hans: 默认
- value: cohere.rerank-v3-5:0
label:
en_US: cohere.rerank-v3-5:0
zh_Hans: cohere.rerank-v3-5:0
- value: amazon.rerank-v1:0
label:
en_US: amazon.rerank-v1:0
zh_Hans: amazon.rerank-v1:0
form: form
- name: aws_region
type: string
required: false

@ -229,8 +229,7 @@ class NovaReelTool(BuiltinTool):
if async_mode:
return self.create_text_message(
f"Video generation started.\nInvocation ARN: {invocation_arn}\n"
f"Video will be available at: {video_uri}"
f"Video generation started.\nInvocation ARN: {invocation_arn}\nVideo will be available at: {video_uri}"
)
return self._wait_for_completion(bedrock, s3_client, invocation_arn)

@ -65,7 +65,7 @@ class BaiduFieldTranslateTool(BuiltinTool, BaiduTranslateToolBase):
if "trans_result" in result:
result_text = result["trans_result"][0]["dst"]
else:
result_text = f'{result["error_code"]}: {result["error_msg"]}'
result_text = f"{result['error_code']}: {result['error_msg']}"
return self.create_text_message(str(result_text))
except requests.RequestException as e:

@ -52,7 +52,7 @@ class BaiduLanguageTool(BuiltinTool, BaiduTranslateToolBase):
result_text = ""
if result["error_code"] != 0:
result_text = f'{result["error_code"]}: {result["error_msg"]}'
result_text = f"{result['error_code']}: {result['error_msg']}"
else:
result_text = result["data"]["src"]
result_text = self.mapping_result(description_language, result_text)

@ -58,7 +58,7 @@ class BaiduTranslateTool(BuiltinTool, BaiduTranslateToolBase):
if "trans_result" in result:
result_text = result["trans_result"][0]["dst"]
else:
result_text = f'{result["error_code"]}: {result["error_msg"]}'
result_text = f"{result['error_code']}: {result['error_msg']}"
return self.create_text_message(str(result_text))
except requests.RequestException as e:

@ -30,7 +30,7 @@ class BingSearchTool(BuiltinTool):
headers = {"Ocp-Apim-Subscription-Key": subscription_key, "Accept-Language": accept_language}
query = quote(query)
server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
server_url = f"{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={','.join(filters)}"
response = get(server_url, headers=headers)
if response.status_code != 200:
@ -47,23 +47,23 @@ class BingSearchTool(BuiltinTool):
results = []
if search_results:
for result in search_results:
url = f': {result["url"]}' if "url" in result else ""
results.append(self.create_text_message(text=f'{result["name"]}{url}'))
url = f": {result['url']}" if "url" in result else ""
results.append(self.create_text_message(text=f"{result['name']}{url}"))
if entities:
for entity in entities:
url = f': {entity["url"]}' if "url" in entity else ""
results.append(self.create_text_message(text=f'{entity.get("name", "")}{url}'))
url = f": {entity['url']}" if "url" in entity else ""
results.append(self.create_text_message(text=f"{entity.get('name', '')}{url}"))
if news:
for news_item in news:
url = f': {news_item["url"]}' if "url" in news_item else ""
results.append(self.create_text_message(text=f'{news_item.get("name", "")}{url}'))
url = f": {news_item['url']}" if "url" in news_item else ""
results.append(self.create_text_message(text=f"{news_item.get('name', '')}{url}"))
if related_searches:
for related in related_searches:
url = f': {related["displayText"]}' if "displayText" in related else ""
results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}'))
url = f": {related['displayText']}" if "displayText" in related else ""
results.append(self.create_text_message(text=f"{related.get('displayText', '')}{url}"))
return results
elif result_type == "json":
@ -106,29 +106,29 @@ class BingSearchTool(BuiltinTool):
text = ""
if search_results:
for i, result in enumerate(search_results):
text += f'{i + 1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
text += f"{i + 1}: {result.get('name', '')} - {result.get('snippet', '')}\n"
if computation and "expression" in computation and "value" in computation:
text += "\nComputation:\n"
text += f'{computation["expression"]} = {computation["value"]}\n'
text += f"{computation['expression']} = {computation['value']}\n"
if entities:
text += "\nEntities:\n"
for entity in entities:
url = f'- {entity["url"]}' if "url" in entity else ""
text += f'{entity.get("name", "")}{url}\n'
url = f"- {entity['url']}" if "url" in entity else ""
text += f"{entity.get('name', '')}{url}\n"
if news:
text += "\nNews:\n"
for news_item in news:
url = f'- {news_item["url"]}' if "url" in news_item else ""
text += f'{news_item.get("name", "")}{url}\n'
url = f"- {news_item['url']}" if "url" in news_item else ""
text += f"{news_item.get('name', '')}{url}\n"
if related_searches:
text += "\n\nRelated Searches:\n"
for related in related_searches:
url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else ""
text += f'{related.get("displayText", "")}{url}\n'
url = f"- {related['webSearchUrl']}" if "webSearchUrl" in related else ""
text += f"{related.get('displayText', '')}{url}\n"
return self.create_text_message(text=self.summary(user_id=user_id, content=text))

@ -83,5 +83,5 @@ class DIDApp:
if status["status"] == "done":
return status
elif status["status"] == "error" or status["status"] == "rejected":
raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error", {}).get("description")}')
raise HTTPError(f"Talks {id} failed: {status['status']} {status.get('error', {}).get('description')}")
time.sleep(poll_interval)

@ -20,33 +20,33 @@ class SendEmailToolParameters(BaseModel):
encrypt_method: str
def send_mail(parmas: SendEmailToolParameters):
def send_mail(params: SendEmailToolParameters):
timeout = 60
msg = MIMEMultipart("alternative")
msg["From"] = parmas.email_account
msg["To"] = parmas.sender_to
msg["Subject"] = parmas.subject
msg.attach(MIMEText(parmas.email_content, "plain"))
msg.attach(MIMEText(parmas.email_content, "html"))
msg["From"] = params.email_account
msg["To"] = params.sender_to
msg["Subject"] = params.subject
msg.attach(MIMEText(params.email_content, "plain"))
msg.attach(MIMEText(params.email_content, "html"))
ctx = ssl.create_default_context()
if parmas.encrypt_method.upper() == "SSL":
if params.encrypt_method.upper() == "SSL":
try:
with smtplib.SMTP_SSL(parmas.smtp_server, parmas.smtp_port, context=ctx, timeout=timeout) as server:
server.login(parmas.email_account, parmas.email_password)
server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string())
with smtplib.SMTP_SSL(params.smtp_server, params.smtp_port, context=ctx, timeout=timeout) as server:
server.login(params.email_account, params.email_password)
server.sendmail(params.email_account, params.sender_to, msg.as_string())
return True
except Exception as e:
logging.exception("send email failed")
return False
else: # NONE or TLS
try:
with smtplib.SMTP(parmas.smtp_server, parmas.smtp_port, timeout=timeout) as server:
if parmas.encrypt_method.upper() == "TLS":
with smtplib.SMTP(params.smtp_server, params.smtp_port, timeout=timeout) as server:
if params.encrypt_method.upper() == "TLS":
server.starttls(context=ctx)
server.login(parmas.email_account, parmas.email_password)
server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string())
server.login(params.email_account, params.email_password)
server.sendmail(params.email_account, params.sender_to, msg.as_string())
return True
except Exception as e:
logging.exception("send email failed")

@ -74,7 +74,7 @@ class FirecrawlApp:
if response is None:
raise HTTPError("Failed to initiate crawl after multiple retries")
elif response.get("success") == False:
raise HTTPError(f'Failed to crawl: {response.get("error")}')
raise HTTPError(f"Failed to crawl: {response.get('error')}")
job_id: str = response["id"]
if wait:
return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval)
@ -100,7 +100,7 @@ class FirecrawlApp:
if status["status"] == "completed":
return status
elif status["status"] == "failed":
raise HTTPError(f'Job {job_id} failed: {status["error"]}')
raise HTTPError(f"Job {job_id} failed: {status['error']}")
time.sleep(poll_interval)

@ -37,8 +37,9 @@ class GaodeRepositoriesTool(BuiltinTool):
CityCode = City_data["districts"][0]["adcode"]
weatherInfo_response = s.request(
method="GET",
url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json"
"".format(url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")),
url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json".format(
url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")
),
)
weatherInfo_data = weatherInfo_response.json()
if weatherInfo_response.status_code == 200 and weatherInfo_data.get("info") == "OK":

@ -11,19 +11,21 @@ class GitlabFilesTool(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
project = tool_parameters.get("project", "")
repository = tool_parameters.get("repository", "")
project = tool_parameters.get("project", "")
branch = tool_parameters.get("branch", "")
path = tool_parameters.get("path", "")
file_path = tool_parameters.get("file_path", "")
if not project and not repository:
return self.create_text_message("Either project or repository is required")
if not repository and not project:
return self.create_text_message("Either repository or project is required")
if not branch:
return self.create_text_message("Branch is required")
if not path:
return self.create_text_message("Path is required")
if not path and not file_path:
return self.create_text_message("Either path or file_path is required")
access_token = self.runtime.credentials.get("access_tokens")
headers = {"PRIVATE-TOKEN": access_token}
site_url = self.runtime.credentials.get("site_url")
if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"):
@ -31,33 +33,45 @@ class GitlabFilesTool(BuiltinTool):
if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"):
site_url = "https://gitlab.com"
# Get file content
if repository:
result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True)
# URL encode the repository path
identifier = urllib.parse.quote(repository, safe="")
else:
result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False)
identifier = self.get_project_id(site_url, access_token, project)
if not identifier:
raise Exception(f"Project '{project}' not found.")
return [self.create_json_message(item) for item in result]
# Get file content
if path:
results = self.fetch_files(site_url, headers, identifier, branch, path)
return [self.create_json_message(item) for item in results]
else:
result = self.fetch_file(site_url, headers, identifier, branch, file_path)
return [self.create_json_message(result)]
@staticmethod
def fetch_file(
site_url: str,
headers: dict[str, str],
identifier: str,
branch: str,
path: str,
) -> dict[str, Any]:
encoded_file_path = urllib.parse.quote(path, safe="")
file_url = f"{site_url}/api/v4/projects/{identifier}/repository/files/{encoded_file_path}/raw?ref={branch}"
file_response = requests.get(file_url, headers=headers)
file_response.raise_for_status()
file_content = file_response.text
return {"path": path, "branch": branch, "content": file_content}
def fetch_files(
self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool
self, site_url: str, headers: dict[str, str], identifier: str, branch: str, path: str
) -> list[dict[str, Any]]:
domain = site_url
headers = {"PRIVATE-TOKEN": access_token}
results = []
try:
if is_repository:
# URL encode the repository path
encoded_identifier = urllib.parse.quote(identifier, safe="")
tree_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}"
else:
# Get project ID from project name
project_id = self.get_project_id(site_url, access_token, identifier)
if not project_id:
return self.create_text_message(f"Project '{identifier}' not found.")
tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
tree_url = f"{site_url}/api/v4/projects/{identifier}/repository/tree?path={path}&ref={branch}"
response = requests.get(tree_url, headers=headers)
response.raise_for_status()
items = response.json()
@ -65,26 +79,10 @@ class GitlabFilesTool(BuiltinTool):
for item in items:
item_path = item["path"]
if item["type"] == "tree": # It's a directory
results.extend(
self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository)
)
results.extend(self.fetch_files(site_url, headers, identifier, branch, item_path))
else: # It's a file
encoded_item_path = urllib.parse.quote(item_path, safe="")
if is_repository:
file_url = (
f"{domain}/api/v4/projects/{encoded_identifier}/repository/files"
f"/{encoded_item_path}/raw?ref={branch}"
)
else:
file_url = (
f"{domain}/api/v4/projects/{project_id}/repository/files"
f"{encoded_item_path}/raw?ref={branch}"
)
file_response = requests.get(file_url, headers=headers)
file_response.raise_for_status()
file_content = file_response.text
results.append({"path": item_path, "branch": branch, "content": file_content})
result = self.fetch_file(site_url, headers, identifier, branch, item_path)
results.append(result)
except requests.RequestException as e:
print(f"Error fetching data from GitLab: {e}")
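Review note: the new `fetch_file` helper builds a raw-file URL from an already-resolved project identifier and a percent-encoded file path. The sketch below isolates that URL construction so the encoding can be checked without a network call. The function name `build_raw_file_url` and the sample group, project, and file names are hypothetical; only the URL shape and the `quote(..., safe="")` call come from the diff.

```python
import urllib.parse


def build_raw_file_url(site_url: str, identifier: str, branch: str, path: str) -> str:
    """Build the raw-file URL the same way the diff does (hypothetical helper name)."""
    # safe="" forces "/" to be encoded as %2F so the whole path is one URL segment.
    encoded_file_path = urllib.parse.quote(path, safe="")
    return f"{site_url}/api/v4/projects/{identifier}/repository/files/{encoded_file_path}/raw?ref={branch}"


# A repository path is encoded once up front, mirroring the `if repository:` branch.
identifier = urllib.parse.quote("my-group/my-project", safe="")
url = build_raw_file_url("https://gitlab.com", identifier, "main", "src/app/main.py")
print(url)
```

As in the diff, the branch name is interpolated unencoded; a branch containing characters such as `#` or `&` would likely need `urllib.parse.quote` as well.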

@ -29,7 +29,7 @@ parameters:
zh_Hans: 项目
human_description:
en_US: project
zh_Hans: 项目
zh_Hans: 项目(和仓库路径二选一,都填写以仓库路径优先)
llm_description: Project for GitLab
form: llm
- name: branch
@ -45,12 +45,21 @@ parameters:
form: llm
- name: path
type: string
required: true
label:
en_US: path
zh_Hans: 文件路径
zh_Hans: 文件
human_description:
en_US: path
zh_Hans: 文件夹
llm_description: Dir path for GitLab
form: llm
- name: file_path
type: string
label:
en_US: file_path
zh_Hans: 文件路径
human_description:
en_US: file_path
zh_Hans: 文件路径(和文件夹二选一,都填写以文件夹优先)
llm_description: File path for GitLab
form: llm

@ -110,7 +110,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
result["rows"].append(self.get_row_field_value(row, schema))
return self.create_text_message(json.dumps(result, ensure_ascii=False))
else:
result_text = f"Found {result['total']} rows in worksheet \"{worksheet_name}\"."
result_text = f'Found {result["total"]} rows in worksheet "{worksheet_name}".'
if result["total"] > 0:
result_text += (
f" The following are {min(limit, result['total'])}"

@ -28,4 +28,4 @@ class BaseStabilityAuthorization:
"""
This method is responsible for generating the authorization headers.
"""
return {"Authorization": f'Bearer {credentials.get("api_key", "")}'}
return {"Authorization": f"Bearer {credentials.get('api_key', '')}"}

@ -38,7 +38,7 @@ class VannaProvider(BuiltinToolProviderController):
tool_parameters={
"model": "chinook",
"db_type": "SQLite",
"url": f'{self._get_protocol_and_main_domain(credentials["base_url"])}/Chinook.sqlite',
"url": f"{self._get_protocol_and_main_domain(credentials['base_url'])}/Chinook.sqlite",
"query": "What are the top 10 customers by sales?",
},
)

@ -43,7 +43,7 @@ class SerplyApi:
def parse_results(res: dict) -> str:
"""Process response from Serply Job Search."""
jobs = res.get("jobs", [])
if not jobs:
if not res or "jobs" not in res:
raise ValueError(f"Got error from Serply: {res}")
string = []

@ -43,7 +43,7 @@ class SerplyApi:
def parse_results(res: dict) -> str:
"""Process response from Serply News Search."""
news = res.get("entries", [])
if not news:
if not res or "entries" not in res:
raise ValueError(f"Got error from Serply: {res}")
string = []

@ -43,7 +43,7 @@ class SerplyApi:
def parse_results(res: dict) -> str:
"""Process response from Serply News Search."""
articles = res.get("articles", [])
if not articles:
if not res or "articles" not in res:
raise ValueError(f"Got error from Serply: {res}")
string = []

@ -42,7 +42,7 @@ class SerplyApi:
def parse_results(res: dict) -> str:
"""Process response from Serply Web Search."""
results = res.get("results", [])
if not results:
if not res or "results" not in res:
raise ValueError(f"Got error from Serply: {res}")
string = []
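Review note: the four Serply hunks change when `parse_results` raises. The old guard treated an empty result list as an error; the new guard raises only when the response is empty or the key is missing. The sketch below restates the two conditions as standalone predicates (the names `raises_old` and `raises_new` are hypothetical) so the difference is easy to see.

```python
def raises_old(res: dict, key: str) -> bool:
    """Old guard: raise whenever the list under `key` is missing or empty."""
    return not res.get(key, [])


def raises_new(res: dict, key: str) -> bool:
    """New guard: raise only for an empty response or a missing key."""
    return not res or key not in res


# The only case where the two guards disagree is a present-but-empty list.
empty_hit = {"jobs": []}
print(raises_old(empty_hit, "jobs"), raises_new(empty_hit, "jobs"))
```

With the new guard, a search that legitimately returns no hits falls through to the formatting code instead of surfacing as a Serply error.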

@ -84,9 +84,9 @@ class ApiTool(Tool):
if "api_key_header_prefix" in credentials:
api_key_header_prefix = credentials["api_key_header_prefix"]
if api_key_header_prefix == "basic" and credentials["api_key_value"]:
credentials["api_key_value"] = f'Basic {credentials["api_key_value"]}'
credentials["api_key_value"] = f"Basic {credentials['api_key_value']}"
elif api_key_header_prefix == "bearer" and credentials["api_key_value"]:
credentials["api_key_value"] = f'Bearer {credentials["api_key_value"]}'
credentials["api_key_value"] = f"Bearer {credentials['api_key_value']}"
elif api_key_header_prefix == "custom":
pass

@ -29,7 +29,7 @@ class ToolFileMessageTransformer:
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_url=message.message
)
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}"
result.append(
ToolInvokeMessage(
@ -122,4 +122,4 @@ class ToolFileMessageTransformer:
@classmethod
def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
return f'/files/tools/{tool_file_id}{extension or ".bin"}'
return f"/files/tools/{tool_file_id}{extension or '.bin'}"

@ -5,6 +5,7 @@ from json import loads as json_loads
from json.decoder import JSONDecodeError
from typing import Optional
from flask import request
from requests import get
from yaml import YAMLError, safe_load # type: ignore
@ -29,6 +30,10 @@ class ApiBasedToolSchemaParser:
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
server_url = openapi["servers"][0]["url"]
request_env = request.headers.get("X-Request-Env")
if request_env:
matched_servers = [server["url"] for server in openapi["servers"] if server.get("env") == request_env]
server_url = matched_servers[0] if matched_servers else server_url
# list all interfaces
interfaces = []
@ -112,7 +117,7 @@ class ApiBasedToolSchemaParser:
llm_description=property.get("description", ""),
default=property.get("default", None),
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
)
@ -144,7 +149,7 @@ class ApiBasedToolSchemaParser:
if not path:
path = str(uuid.uuid4())
interface["operation"]["operationId"] = f'{path}_{interface["method"]}'
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
bundles.append(
ApiToolBundle(
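Review note: the `X-Request-Env` hunk picks a server URL by matching a per-server `env` key and falls back to the first server otherwise. The sketch below is a standalone restatement, assuming server entries are plain dicts; the function name `select_server_url` and the sample URLs are hypothetical. It reads the key with `.get("env")` so that entries without it are simply skipped, which is a choice of this sketch since `env` is not a standard OpenAPI server field.

```python
from typing import Optional


def select_server_url(servers: list[dict], request_env: Optional[str]) -> str:
    """Return the URL of the first server tagged with request_env, else the first server."""
    default_url = servers[0]["url"]
    if not request_env:
        return default_url
    # Entries without an "env" key never match and are skipped.
    matched = [server["url"] for server in servers if server.get("env") == request_env]
    return matched[0] if matched else default_url


servers = [
    {"url": "https://api.example.com"},  # no "env" key
    {"url": "https://staging.example.com", "env": "staging"},
]
print(select_server_url(servers, "staging"))
```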

@ -134,6 +134,10 @@ class ArrayStringSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_STRING
value: Sequence[str]
@property
def text(self) -> str:
return json.dumps(self.value)
class ArrayNumberSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_NUMBER
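Review note: the new `text` property on `ArrayStringSegment` serializes the value with `json.dumps`. The sketch below uses a plain list (the sample values are made up) to show what that produces and how it differs from Python's own `str()` of a list, which is not valid JSON.

```python
import json

values = ["alpha", "beta"]

# json.dumps yields double-quoted, JSON-parseable text.
as_json = json.dumps(values)

# str() yields a Python repr with single quotes, which a JSON parser rejects.
as_repr = str(values)

# The JSON form round-trips back to the original list.
assert json.loads(as_json) == values
print(as_json)
print(as_repr)
```

Note that `json.dumps` escapes non-ASCII characters by default; passing `ensure_ascii=False` would keep them readable if that matters for display.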

@ -1,6 +1,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Optional
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
@ -48,25 +49,35 @@ class StreamProcessor(ABC):
# Removing these nodes may short-circuit the answer node, so this code is commented out for now.
# Disabling it has no effect on the answer node or the workflow; it can be re-enabled
# once a better solution is found. Issues: #11542 #9560 #10638 #10564
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
if "answer" in ids:
continue
else:
reachable_node_ids.extend(ids)
# ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
# if "answer" in ids:
# continue
# else:
# reachable_node_ids.extend(ids)
# The branch_identify parameter is added to ensure that
# only nodes in the correct logical branch are included.
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
reachable_node_ids.extend(ids)
else:
unreachable_first_node_ids.append(edge.target_node_id)
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
# Only follow edges that match the branch_identify or have no run_condition
if edge.run_condition and edge.run_condition.branch_identify:
if not branch_identify or edge.run_condition.branch_identify != branch_identify:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
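Review note: the traversal change above follows a conditional edge only when its branch label matches the one passed in, while unconditional edges are always followed. The sketch below reproduces that filter on a toy graph. The `Edge` dataclass, the `reachable` function, and the node names are hypothetical stand-ins for the real graph types; like the diff, it does not guard against cycles other than edges back to the root.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Edge:
    target: str
    branch: Optional[str] = None  # None means the edge is unconditional


def reachable(
    edge_mapping: dict[str, list[Edge]],
    node_id: str,
    branch: Optional[str] = None,
    root: Optional[str] = None,
) -> list[str]:
    """Collect downstream node ids, skipping conditional edges on other branches."""
    node_ids: list[str] = []
    for edge in edge_mapping.get(node_id, []):
        if edge.target == root:
            continue
        # A labeled edge is followed only when a branch is given and it matches.
        if edge.branch and (not branch or edge.branch != branch):
            continue
        node_ids.append(edge.target)
        node_ids.extend(reachable(edge_mapping, edge.target, branch, root))
    return node_ids


# Toy graph: an if/else node with two labeled edges and one unconditional edge.
graph = {
    "if": [Edge("a", "true"), Edge("b", "false")],
    "a": [Edge("end")],
}
print(reachable(graph, "if", "true"))
```

Because the same branch label is passed down the recursion, a nested conditional further downstream is filtered by the upstream label too; whether that is intended is worth confirming in review.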

Some files were not shown because too many files have changed in this diff.