Compare commits

...

46 Commits

Author SHA1 Message Date
83b6abf4ad Update parse.py to handle empty list result (#10915)
Co-authored-by: crazywoola <427733928@qq.com>
2024-11-21 14:14:07 +08:00
ea0ebc020c fix: chat history might be empty in log detail view (#10905) 2024-11-21 14:12:01 +08:00
f358db9f02 feat : Add Japanese translations for API documentation: chat, advanced-chat, completion, and workflow (#10927) 2024-11-21 14:02:46 +08:00
94c9cadbd8 fix image files not deleted on indexing_estimate #9541 (#10798)
Co-authored-by: root <root@localhost.localdomain>
2024-11-21 13:03:16 +08:00
2ae6460f46 Add googlenews tools from rapidapi (#10877)
Co-authored-by: steven <sunzwj@digitalchina.com>
2024-11-21 10:39:49 +08:00
0067b16d1e fix: refactor all 'or []' and 'or {}' logic to make code more clear (#10883)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-11-21 10:34:43 +08:00
ec9f6220c9 doc: fix better doc for api develop, droping dead hint (#10906)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-11-21 10:34:23 +08:00
af53e2b6b0 Fix : Add a process to fetch the mime type from the file name for signed url in remote_url #10872 version2 (#10908) 2024-11-20 22:57:49 +08:00
b42b333a72 fix: handle redis authentication for healthcheck command (#10907) 2024-11-20 20:10:51 +08:00
99b0369f1b Gitee AI embedding tool (#10903) 2024-11-20 17:40:34 +08:00
d6ea1e2f12 fix: explicitly use new token when retrying ssePost after refresh (#10864)
Co-authored-by: liusurong.lsr <liusurong.lsr@alibaba-inc.com>
2024-11-20 16:11:33 +08:00
4d6b45427c Support streaming output for OpenAI o1-preview and o1-mini (#10890) 2024-11-20 15:10:41 +08:00
1be8365684 Fix/input-value-type-in-moderation (#10893) 2024-11-20 15:10:12 +08:00
c3d11c8ff6 fix: aws presign url is not workable remote url (#10884)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
2024-11-20 14:24:41 +08:00
8ff65abbc6 ext_redis.py support redis clusters --- Fixes #9538 (#9789)
Signed-off-by: root <root@localhost.localdomain>
Co-authored-by: root <root@localhost.localdomain>
Co-authored-by: Bowen Liang <bowenliang@apache.org>
2024-11-20 13:44:35 +08:00
bf4b6e5f80 feat: support custom tool upload file (#10796) 2024-11-20 13:26:42 +08:00
25fda7adc5 fix(http_request): allow content type application/x-javascript (#10862) 2024-11-20 12:55:06 +08:00
f3af7b5f35 fix: tool's file input display string (#10887) 2024-11-20 12:54:24 +08:00
33cfc56ad0 fix: update email validation regex to allow periods in local part (#10868) 2024-11-20 12:33:02 +08:00
464cc26ccf Fix : Add a process to fetch the mime type from the file name for signed url in remote_url (#10872) 2024-11-20 12:30:25 +08:00
d18754afdd feat: admin can also change member role (#10651) 2024-11-20 11:29:49 +08:00
beb7953d38 feat: enhance the custom note (#8885) 2024-11-20 11:24:45 +08:00
fbfc811a44 feat: support function call for ollama block chat api (#10784) 2024-11-20 11:15:19 +08:00
7e66e5a713 feat: make toc panel can collapse (#10875) 2024-11-20 10:07:30 +08:00
07b5bbae06 feat: add a minimal separator between pinned apps and unpinned apps in the explore page (#10871) 2024-11-20 09:32:59 +08:00
3087913b74 Fix the situation where output_tokens/input_tokens may be None in response.usage (#10728) 2024-11-19 21:19:13 +08:00
904ea05bf6 fix: download some remote files raise error (#10781) 2024-11-19 21:18:53 +08:00
6f4885d86d Encode invitee email in the invitation link (#10842) 2024-11-19 21:08:37 +08:00
Joe 2dc29cfee3 Feat/add langsmith dotted order (#10856) 2024-11-19 21:08:23 +08:00
bd05df5cc5 fix tongyi embedding endpoint return None output (#10857) 2024-11-19 21:04:17 +08:00
ee1f14621a fix httpx doesn't support stream parameter (#10859) 2024-11-19 21:03:01 +08:00
58a9d9eb9a fix: better WeightRerankRunner run logic use O(1) and delete unused code (#10849)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-11-19 20:12:13 +08:00
bc1013dacf feat: support json schema for gemini models (#10835) 2024-11-19 17:49:58 +08:00
9f195df103 Support Video Proxy and TED Embedding (#10819) 2024-11-19 17:49:14 +08:00
1cc7dc6360 style: refactor fetch and context (#10795) 2024-11-19 17:16:06 +08:00
328965ed7c Fix: crash of workflow file upload (#10831)
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2024-11-19 14:15:18 +08:00
133de9a087 fix: upload file component support multiple (#10817) 2024-11-19 14:00:54 +08:00
7261384655 fix: close child modal on log drawer close (#10839) 2024-11-19 12:09:55 +08:00
4718071cbb feat: Knowledge-base-api-get-post-method-text-error-#10836 (#10837) 2024-11-19 12:08:10 +08:00
22be0816aa feat: add TOC to app develop doc (#10799) 2024-11-19 09:06:12 +08:00
49e88322de doc: add clarification for length limit of init password (#10824) 2024-11-19 09:05:05 +08:00
14f3d44c37 refactor: improve handling of leading punctuation removal (#10761) 2024-11-18 21:32:33 +08:00
0ba17ec116 fix: correct typo in ETL type comment in .env.example (#10822) 2024-11-18 20:58:43 +08:00
79d59c004b chore: update .gitignore to include mise.toml (#10778) 2024-11-18 19:35:12 +08:00
873e9720e9 feat: AnalyticDB vector store supports invocation via SQL. (#10802)
Co-authored-by: 璟义 <yangshangpo.ysp@alibaba-inc.com>
2024-11-18 19:29:54 +08:00
de6d3e493c fix: script rendering in message (#10807)
Co-authored-by: crazywoola <427733928@qq.com>
2024-11-18 19:19:10 +08:00
126 changed files with 5667 additions and 741 deletions

View File

@@ -42,6 +42,11 @@ REDIS_SENTINEL_USERNAME=
REDIS_SENTINEL_PASSWORD=
REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
# redis Cluster configuration.
REDIS_USE_CLUSTERS=false
REDIS_CLUSTERS=
REDIS_CLUSTERS_PASSWORD=
# PostgreSQL database configuration
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
@@ -234,6 +239,10 @@ ANALYTICDB_ACCOUNT=testaccount
ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
ANALYTICDB_HOST=gp-test.aliyuncs.com
ANALYTICDB_PORT=5432
ANALYTICDB_MIN_CONNECTION=1
ANALYTICDB_MAX_CONNECTION=5
# OpenSearch configuration
OPENSEARCH_HOST=127.0.0.1

View File

@@ -18,12 +18,17 @@
```
2. Copy `.env.example` to `.env`
```cli
cp .env.example .env
```
3. Generate a `SECRET_KEY` in the `.env` file.
bash for Linux
```bash for Linux
sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
```
bash for Mac
```bash for Mac
secret_key=$(openssl rand -base64 42)
sed -i '' "/^SECRET_KEY=/c\\
@@ -41,14 +46,6 @@
poetry install
```
In case of contributors missing to update dependencies for `pyproject.toml`, you can perform the following shell instead.
```bash
poetry shell # activate current environment
poetry add $(cat requirements.txt) # install dependencies of production and update pyproject.toml
poetry add $(cat requirements-dev.txt) --group dev # install dependencies of development and update pyproject.toml
```
6. Run migrate
Before the first launch, migrate the database to the latest version.

View File

@@ -68,3 +68,18 @@ class RedisConfig(BaseSettings):
description="Socket timeout in seconds for Redis Sentinel connections",
default=0.1,
)
REDIS_USE_CLUSTERS: bool = Field(
description="Enable Redis Clusters mode for high availability",
default=False,
)
REDIS_CLUSTERS: Optional[str] = Field(
description="Comma-separated list of Redis Clusters nodes (host:port)",
default=None,
)
REDIS_CLUSTERS_PASSWORD: Optional[str] = Field(
description="Password for Redis Clusters authentication (if required)",
default=None,
)
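For context, a minimal sketch of how these three settings can become a cluster client, assuming redis-py >= 4.1 and the comma-separated `host:port` format implied by the field description above; this is an illustration, not Dify's exact wiring in `ext_redis.py`:
```python
from redis.cluster import ClusterNode, RedisCluster

def build_cluster_client(clusters: str, password: str | None) -> "RedisCluster":
    # Parse "host1:6379,host2:6379" (assumed format) into startup nodes.
    nodes = [
        ClusterNode(host, int(port))
        for host, port in (item.split(":") for item in clusters.split(","))
    ]
    # Note: RedisCluster contacts the cluster on construction to map slots.
    return RedisCluster(startup_nodes=nodes, password=password)

# client = build_cluster_client("10.0.0.1:6379,10.0.0.2:6379", "secret")
```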

View File

@@ -1,6 +1,6 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PositiveInt
class AnalyticdbConfig(BaseModel):
@@ -40,3 +40,11 @@ class AnalyticdbConfig(BaseModel):
description="The password for accessing the specified namespace within the AnalyticDB instance"
" (if namespace feature is enabled).",
)
ANALYTICDB_HOST: Optional[str] = Field(
default=None, description="The host of the AnalyticDB instance you want to connect to."
)
ANALYTICDB_PORT: PositiveInt = Field(
default=5432, description="The port of the AnalyticDB instance you want to connect to."
)
ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1, description="Min connection of the AnalyticDB database.")
ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5, description="Max connection of the AnalyticDB database.")

View File

@@ -45,7 +45,7 @@ class RemoteFileUploadApi(Resource):
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3)
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
resp.raise_for_status()
file_info = helpers.guess_file_info_from_response(resp)
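For illustration, the HEAD-then-GET fallback above in plain httpx (omitting Dify's `ssrf_proxy` wrapper): `follow_redirects=True` is the fix, since some remote files answer the GET with a redirect that would otherwise surface as a 3xx instead of the file itself.
```python
import httpx

def probe(url: str) -> httpx.Response:
    resp = httpx.head(url)
    if resp.status_code != httpx.codes.OK:
        # Server rejects or mishandles HEAD: fall back to a short GET and
        # follow redirects so signed/redirecting URLs resolve to the file.
        resp = httpx.get(url, timeout=3, follow_redirects=True)
    resp.raise_for_status()
    return resp
```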

View File

@@ -1,3 +1,5 @@
from urllib import parse
from flask_login import current_user
from flask_restful import Resource, abort, marshal_with, reqparse
@@ -57,11 +59,12 @@ class MemberInviteEmailApi(Resource):
token = RegisterService.invite_new_member(
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
)
encoded_invitee_email = parse.quote(invitee_email)
invitation_results.append(
{
"status": "success",
"email": invitee_email,
"url": f"{console_web_url}/activate?email={invitee_email}&token={token}",
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
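Why the `parse.quote` call matters: `+` and `@` are significant inside a query string, so an unencoded address like `bob+test@example.com` would decode with the `+` turned into a space.
```python
from urllib import parse

email = "bob+test@example.com"
print(parse.quote(email))  # bob%2Btest%40example.com
```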

View File

@@ -114,16 +114,9 @@ class BaseAgentRunner(AppRunner):
# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
self.stream_tool_call = True
else:
self.stream_tool_call = False
# check if model supports vision
if model_schema and ModelFeature.VISION in (model_schema.features or []):
self.files = application_generate_entity.files
else:
self.files = []
features = model_schema.features if model_schema and model_schema.features else []
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
self.query = None
self._current_thoughts: list[PromptMessage] = []
@@ -250,7 +243,7 @@ class BaseAgentRunner(AppRunner):
update prompt message tool
"""
# try to get tool runtime parameters
tool_runtime_parameters = tool.get_runtime_parameters() or []
tool_runtime_parameters = tool.get_runtime_parameters()
for parameter in tool_runtime_parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:

View File

@@ -16,9 +16,7 @@ class FileUploadConfigManager:
file_upload_dict = config.get("file_upload")
if file_upload_dict:
if file_upload_dict.get("enabled"):
transform_methods = file_upload_dict.get("allowed_file_upload_methods") or file_upload_dict.get(
"allowed_upload_methods", []
)
transform_methods = file_upload_dict.get("allowed_file_upload_methods", [])
data = {
"image_config": {
"number_limits": file_upload_dict["number_limits"],

View File

@@ -33,8 +33,8 @@ class BaseAppGenerator:
tenant_id=app_config.tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
),
)
for k, v in user_inputs.items()
@@ -47,8 +47,8 @@ class BaseAppGenerator:
tenant_id=app_config.tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
),
)
for k, v in user_inputs.items()

View File

@@ -381,7 +381,7 @@ class WorkflowCycleManage:
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
inputs=workflow_run.inputs_dict or {},
inputs=workflow_run.inputs_dict,
created_at=int(workflow_run.created_at.timestamp()),
),
)
@@ -428,7 +428,7 @@ class WorkflowCycleManage:
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
),
)

View File

@@ -28,8 +28,8 @@ class FileUploadConfig(BaseModel):
image_config: Optional[ImageConfig] = None
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_extensions: Sequence[str] = Field(default_factory=list)
allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
number_limits: int = 0

View File

@@ -39,6 +39,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
)
retries = 0
stream = kwargs.pop("stream", False)
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:
@@ -52,6 +53,8 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
response = client.request(method=method, url=url, **kwargs)
if response.status_code not in STATUS_FORCELIST:
if stream:
return response.iter_bytes()
return response
else:
logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list")

View File

@@ -29,6 +29,8 @@ from core.rag.splitter.fixed_text_splitter import (
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@@ -278,6 +280,19 @@ class IndexingRunner:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
# delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
try:
storage.delete(image_file.key)
except Exception:
logging.exception(
"Delete image_files failed while indexing_estimate, \
image_upload_file_is: {}".format(upload_file_id)
)
db.session.delete(image_file)
if doc_form and doc_form == "qa_model":
if len(preview_texts) > 0:
# qa model document
@@ -500,11 +515,7 @@ class IndexingRunner:
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith(""):
page_content = page_content[1:]
else:
page_content = page_content
document_node.page_content = page_content
document_node.page_content = remove_leading_symbols(page_content)
if document_node.page_content:
split_documents.append(document_node)
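The old code only stripped a single literal leading "."; `remove_leading_symbols` generalizes that. A plausible regex-based sketch for illustration only (the real helper lives in `core.tools.utils.text_processing_utils` and its exact pattern may differ):
```python
import re

def remove_leading_symbols(text: str) -> str:
    # Strip any run of leading ASCII punctuation/symbols, keeping the rest.
    return re.sub(r"^[!-/:-@\[-`{-~]+", "", text)

print(remove_leading_symbols("...Hello"))  # Hello
```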

View File

@@ -325,14 +325,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
assistant_prompt_message.tool_calls.append(tool_call)
# calculate num tokens
if response.usage:
# transform usage
prompt_tokens = response.usage.input_tokens
completion_tokens = response.usage.output_tokens
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
prompt_tokens = (response.usage and response.usage.input_tokens) or self.get_num_tokens(
model, credentials, prompt_messages
)
completion_tokens = (response.usage and response.usage.output_tokens) or self.get_num_tokens(
model, credentials, [assistant_prompt_message]
)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
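One subtlety of the `x or y` fallback used here: a genuine `0` token count is falsy and would also trigger the local estimate. Harmless for a degenerate empty response, but worth noting before copying the pattern.
```python
usage_tokens = 0          # a real (if degenerate) count from the API
estimated = 42
print(usage_tokens or estimated)  # 42 -- the 0 is silently replaced
```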

View File

@@ -2,13 +2,11 @@
import base64
import json
import logging
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast
# 3rd import
import boto3
import requests
from botocore.config import Config
from botocore.exceptions import (
ClientError,
@@ -439,22 +437,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
url = message_content.data
image_content = requests.get(url).content
if "?" in url:
url = url.split("?")[0]
mime_type, _ = mimetypes.guess_type(url)
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
image_content = base64.b64decode(base64_data)
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
image_content = base64.b64decode(base64_data)
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(
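The deleted branch fetched image bytes over HTTP; the retained branch assumes `message_content.data` is already a base64 data URI, which splits cleanly into MIME type and payload:
```python
import base64

data = "data:image/png;base64,aGVsbG8="  # placeholder payload
head, b64 = data.split(";base64,")
mime_type = head.replace("data:", "")    # "image/png"
image_content = base64.b64decode(b64)    # b"hello"
```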

View File

@@ -691,8 +691,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
base_model_schema = cast(AIModelEntity, base_model_schema)
base_model_schema_features = base_model_schema.features or []
base_model_schema_model_properties = base_model_schema.model_properties or {}
base_model_schema_parameters_rules = base_model_schema.parameter_rules or []
base_model_schema_model_properties = base_model_schema.model_properties
base_model_schema_parameters_rules = base_model_schema.parameter_rules
entity = AIModelEntity(
model=model,

View File

@@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@@ -32,3 +32,4 @@ pricing:
output: '0.00'
unit: '0.000001'
currency: USD
deprecated: true

View File

@@ -36,3 +36,4 @@ pricing:
output: '0.00'
unit: '0.000001'
currency: USD
deprecated: true

View File

@@ -1,7 +1,6 @@
import base64
import io
import json
import logging
from collections.abc import Generator
from typing import Optional, Union, cast
@@ -36,17 +35,6 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
""" # noqa: E501
class GoogleLargeLanguageModel(LargeLanguageModel):
def _invoke(
@@ -155,7 +143,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
try:
ping_message = SystemPromptMessage(content="ping")
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@@ -184,7 +172,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
config_kwargs = model_parameters.copy()
config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
if schema := config_kwargs.pop("json_schema", None):
try:
schema = json.loads(schema)
except:
raise exceptions.InvalidArgument("Invalid JSON Schema")
if tools:
raise exceptions.InvalidArgument("gemini not support use Tools and JSON Schema at same time")
config_kwargs["response_schema"] = schema
config_kwargs["response_mime_type"] = "application/json"
if stop:
config_kwargs["stop_sequences"] = stop

View File

@@ -22,6 +22,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
@@ -86,6 +87,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
@@ -153,6 +155,7 @@
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
@@ -196,6 +199,8 @@
if completion_type is LLMMode.CHAT:
endpoint_url = urljoin(endpoint_url, "api/chat")
data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
if tools:
data["tools"] = [self._convert_prompt_message_tool_to_dict(tool) for tool in tools]
else:
endpoint_url = urljoin(endpoint_url, "api/generate")
first_prompt_message = prompt_messages[0]
@@ -232,7 +237,7 @@
if stream:
return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages, tools)
def _handle_generate_response(
self,
@@ -241,6 +246,7 @@
completion_type: LLMMode,
response: requests.Response,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]],
) -> LLMResult:
"""
Handle llm completion response
@@ -253,14 +259,16 @@
:return: llm result
"""
response_json = response.json()
tool_calls = []
if completion_type is LLMMode.CHAT:
message = response_json.get("message", {})
response_content = message.get("content", "")
response_tool_calls = message.get("tool_calls", [])
tool_calls = [self._extract_response_tool_call(tool_call) for tool_call in response_tool_calls]
else:
response_content = response_json["response"]
assistant_message = AssistantPromptMessage(content=response_content)
assistant_message = AssistantPromptMessage(content=response_content, tool_calls=tool_calls)
if "prompt_eval_count" in response_json and "eval_count" in response_json:
# transform usage
@@ -405,9 +413,28 @@
chunk_index += 1
def _convert_prompt_message_tool_to_dict(self, tool: PromptMessageTool) -> dict:
"""
Convert PromptMessageTool to dict for Ollama API
:param tool: tool
:return: tool dict
"""
return {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
},
}
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict for Ollama API
:param message: prompt message
:return: message dict
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
@@ -432,6 +459,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"role": "tool", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
@@ -452,6 +482,29 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
return num_tokens
def _extract_response_tool_call(self, response_tool_call: dict) -> AssistantPromptMessage.ToolCall:
"""
Extract response tool call
"""
tool_call = None
if response_tool_call and "function" in response_tool_call:
# Convert arguments to JSON string if it's a dict
arguments = response_tool_call.get("function").get("arguments")
if isinstance(arguments, dict):
arguments = json.dumps(arguments)
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.get("function").get("name"),
arguments=arguments,
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.get("function").get("name"),
type="function",
function=function,
)
return tool_call
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
Get customizable model schema.
@@ -461,10 +514,15 @@
:return: model schema
"""
extras = {}
extras = {
"features": [],
}
if "vision_support" in credentials and credentials["vision_support"] == "true":
extras["features"] = [ModelFeature.VISION]
extras["features"].append(ModelFeature.VISION)
if "function_call_support" in credentials and credentials["function_call_support"] == "true":
extras["features"].append(ModelFeature.TOOL_CALL)
extras["features"].append(ModelFeature.MULTI_TOOL_CALL)
entity = AIModelEntity(
model=model,
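For reference, the shape `_extract_response_tool_call` above consumes: Ollama returns `arguments` as an object, hence the `json.dumps` re-serialization, and it supplies no call id, which is why the function name doubles as the id. Values below are illustrative.
```python
import json

# Example Ollama chat response fragment.
response_tool_call = {
    "function": {
        "name": "get_weather",
        "arguments": {"city": "Berlin", "unit": "celsius"},
    }
}
arguments = response_tool_call["function"]["arguments"]
if isinstance(arguments, dict):
    arguments = json.dumps(arguments)  # '{"city": "Berlin", "unit": "celsius"}'
```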

View File

@@ -96,3 +96,22 @@ model_credential_schema:
label:
en_US: 'No'
zh_Hans: 否
- variable: function_call_support
label:
zh_Hans: 是否支持函数调用
en_US: Function call support
show_on:
- variable: __model_type
value: llm
default: 'false'
type: radio
required: false
options:
- value: 'true'
label:
en_US: 'Yes'
zh_Hans: 是
- value: 'false'
label:
en_US: 'No'
zh_Hans: 否

View File

@@ -615,19 +615,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
# o1 compatibility
block_as_stream = False
if model.startswith("o1"):
if "max_tokens" in model_parameters:
model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
del model_parameters["max_tokens"]
if stream:
block_as_stream = True
stream = False
if "stream_options" in extra_model_kwargs:
del extra_model_kwargs["stream_options"]
if "stop" in extra_model_kwargs:
del extra_model_kwargs["stop"]
@@ -644,47 +636,7 @@
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
if block_as_stream:
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
return block_result
def _handle_chat_block_as_stream_response(
self,
block_result: LLMResult,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
) -> Generator[LLMResultChunk, None, None]:
"""
Handle llm chat response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:param stop: stop words
:return: llm response chunk generator
"""
text = block_result.message.content
text = cast(str, text)
if stop:
text = self.enforce_stop_tokens(text, stop)
yield LLMResultChunk(
model=block_result.model,
prompt_messages=prompt_messages,
system_fingerprint=block_result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=text),
finish_reason="stop",
usage=block_result.usage,
),
)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
def _handle_chat_generate_response(
self,
@@ -1178,8 +1130,8 @@
base_model_schema = model_map[base_model]
base_model_schema_features = base_model_schema.features or []
base_model_schema_model_properties = base_model_schema.model_properties or {}
base_model_schema_parameters_rules = base_model_schema.parameter_rules or []
base_model_schema_model_properties = base_model_schema.model_properties
base_model_schema_parameters_rules = base_model_schema.parameter_rules
entity = AIModelEntity(
model=model,

View File

@@ -37,13 +37,14 @@ class OpenLLMGenerateMessage:
class OpenLLMGenerate:
def generate(
self,
*,
server_url: str,
model_name: str,
stream: bool,
model_parameters: dict[str, Any],
stop: list[str],
stop: list[str] | None = None,
prompt_messages: list[OpenLLMGenerateMessage],
user: str,
user: str | None = None,
) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]:
if not server_url:
raise InvalidAuthenticationError("Invalid server URL")
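The bare `*` added to `generate()`'s signature makes every parameter keyword-only, so call sites cannot silently mis-order the long argument list; `stop` and `user` also become optional. A small demonstration of the mechanism:
```python
def generate(*, server_url: str, stream: bool = False,
             stop: list[str] | None = None) -> str:
    return f"{server_url} stream={stream} stop={stop}"

generate(server_url="http://127.0.0.1:3333")  # OK
# generate("http://127.0.0.1:3333")           # TypeError: takes 0 positional arguments
```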

View File

@@ -45,19 +45,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._update_credential(model, credentials)
block_as_stream = False
if model.startswith("openai/o1"):
block_as_stream = True
stop = None
# invoke block as stream
if stream and block_as_stream:
return self._generate_block_as_stream(
model, credentials, prompt_messages, model_parameters, tools, stop, user
)
else:
return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _generate_block_as_stream(
self,
@@ -69,9 +57,7 @@
stop: Optional[list[str]] = None,
user: Optional[str] = None,
) -> Generator:
resp: LLMResult = super()._generate(
model, credentials, prompt_messages, model_parameters, tools, stop, False, user
)
resp = super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, False, user)
yield LLMResultChunk(
model=model,

View File

@@ -65,6 +65,8 @@ class GTERerankModel(RerankModel):
)
rerank_documents = []
if not response.output:
return RerankResult(model=model, docs=rerank_documents)
for _, result in enumerate(response.output.results):
# format document
rerank_document = RerankDocument(

View File

@@ -1,3 +1,6 @@
from collections.abc import Sequence
from typing import Any
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
@@ -62,5 +65,5 @@ class KeywordsModeration(Moderation):
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
def _check_keywords_in_value(self, keywords_list, value) -> bool:
return any(keyword.lower() in value.lower() for keyword in keywords_list)
def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:
return any(keyword.lower() in str(value).lower() for keyword in keywords_list)
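The `str(value)` coercion is the substantive fix: moderation inputs are not always strings (numbers, booleans, or lists can arrive from input variables), and calling `.lower()` on those previously raised `AttributeError`.
```python
keywords = ["forbidden"]
value = 12345  # a numeric input variable
print(any(k.lower() in str(value).lower() for k in keywords))  # False, no crash
```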

View File

@@ -49,6 +49,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run")
input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
@field_validator("inputs", "outputs")
@classmethod

View File

@@ -25,7 +25,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunType,
LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values
from core.ops.utils import filter_none_values, generate_dotted_order
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
@@ -62,6 +62,16 @@ class LangSmithDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.message_id or trace_info.workflow_app_log_id or trace_info.workflow_run_id
message_dotted_order = (
generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
)
workflow_dotted_order = generate_dotted_order(
trace_info.workflow_app_log_id or trace_info.workflow_run_id,
trace_info.workflow_data.created_at,
message_dotted_order,
)
if trace_info.message_id:
message_run = LangSmithRunModel(
id=trace_info.message_id,
@@ -76,6 +86,8 @@
},
tags=["message", "workflow"],
error=trace_info.error,
trace_id=trace_id,
dotted_order=message_dotted_order,
)
self.add_run(message_run)
@@ -95,6 +107,8 @@
error=trace_info.error,
tags=["workflow"],
parent_run_id=trace_info.message_id or None,
trace_id=trace_id,
dotted_order=workflow_dotted_order,
)
self.add_run(langsmith_run)
@@ -177,6 +191,7 @@
else:
run_type = LangSmithRunType.tool
node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order)
langsmith_run = LangSmithRunModel(
total_tokens=node_total_tokens,
name=node_type,
@@ -191,6 +206,9 @@
},
parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
tags=["node_execution"],
id=node_execution_id,
trace_id=trace_id,
dotted_order=node_dotted_order,
)
self.add_run(langsmith_run)

View File

@@ -1,5 +1,6 @@
from contextlib import contextmanager
from datetime import datetime
from typing import Optional, Union
from extensions.ext_database import db
from models.model import Message
@@ -43,3 +44,19 @@ def replace_text_with_content(data):
return [replace_text_with_content(item) for item in data]
else:
return data
def generate_dotted_order(
run_id: str, start_time: Union[str, datetime], parent_dotted_order: Optional[str] = None
) -> str:
"""
generate dotted_order for langsmith
"""
start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"
current_segment = f"{timestamp}{run_id}"
if parent_dotted_order is None:
return current_segment
return f"{parent_dotted_order}.{current_segment}"

View File

@@ -1,310 +1,62 @@
import json
from typing import Any
from pydantic import BaseModel
_import_err_msg = (
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from configs import dify_config
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig,
)
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
class AnalyticdbConfig(BaseModel):
access_key_id: str
access_key_secret: str
region_id: str
instance_id: str
account: str
account_password: str
namespace: str = ("dify",)
namespace_password: str = (None,)
metrics: str = ("cosine",)
read_timeout: int = 60000
def to_analyticdb_client_params(self):
return {
"access_key_id": self.access_key_id,
"access_key_secret": self.access_key_secret,
"region_id": self.region_id,
"read_timeout": self.read_timeout,
}
class AnalyticdbVector(BaseVector):
def __init__(self, collection_name: str, config: AnalyticdbConfig):
self._collection_name = collection_name.lower()
try:
from alibabacloud_gpdb20160503.client import Client
from alibabacloud_tea_openapi import models as open_api_models
except:
raise ImportError(_import_err_msg)
self.config = config
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
self._client = Client(self._client_config)
self._initialize()
def _initialize(self) -> None:
cache_key = f"vector_indexing_{self.config.instance_id}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self.config.instance_id}"
if redis_client.get(collection_exist_cache_key):
return
self._initialize_vector_database()
self._create_namespace_if_not_exists()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.init_vector_database(request)
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.describe_namespace(request)
except TeaException as e:
if e.statusCode == 404:
request = gpdb_20160503_models.CreateNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
)
self._client.create_namespace(request)
else:
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
try:
request = gpdb_20160503_models.DescribeCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
)
self._client.describe_collection(request)
except TeaException as e:
if e.statusCode == 404:
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
full_text_retrieval_fields = "page_content"
request = gpdb_20160503_models.CreateCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
collection=self._collection_name,
dimension=embedding_dimension,
metrics=self.config.metrics,
metadata=metadata,
full_text_retrieval_fields=full_text_retrieval_fields,
)
self._client.create_collection(request)
else:
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def __init__(
self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
):
super().__init__(collection_name)
if api_config is not None:
self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
else:
self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
def get_type(self) -> str:
return VectorType.ANALYTICDB
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection_if_not_exists(dimension)
self.add_texts(texts, embeddings)
self.analyticdb_vector._create_collection_if_not_exists(dimension)
self.analyticdb_vector.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
)
)
request = gpdb_20160503_models.UpsertCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
rows=rows,
)
self._client.upsert_collection_data(request)
def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.analyticdb_vector.add_texts(texts, embeddings)
def text_exists(self, id: str) -> bool:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'",
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
return self.analyticdb_vector.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
ids_str = f"({ids_str})"
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
self.analyticdb_vector.delete_by_ids(ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
self.analyticdb_vector.delete_by_metadata_field(key, value)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
return self.analyticdb_vector.search_by_vector(query_vector)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None,
content=query,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.metadata.get("vector"),
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
return self.analyticdb_vector.search_by_full_text(query, **kwargs)
def delete(self) -> None:
try:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
collection=self._collection_name,
dbinstance_id=self.config.instance_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
region_id=self.config.region_id,
)
self._client.delete_collection(request)
except Exception as e:
raise e
self.analyticdb_vector.delete()
class AnalyticdbVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
@@ -313,26 +65,9 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))
# handle optional params
if dify_config.ANALYTICDB_KEY_ID is None:
raise ValueError("ANALYTICDB_KEY_ID should not be None")
if dify_config.ANALYTICDB_KEY_SECRET is None:
raise ValueError("ANALYTICDB_KEY_SECRET should not be None")
if dify_config.ANALYTICDB_REGION_ID is None:
raise ValueError("ANALYTICDB_REGION_ID should not be None")
if dify_config.ANALYTICDB_INSTANCE_ID is None:
raise ValueError("ANALYTICDB_INSTANCE_ID should not be None")
if dify_config.ANALYTICDB_ACCOUNT is None:
raise ValueError("ANALYTICDB_ACCOUNT should not be None")
if dify_config.ANALYTICDB_PASSWORD is None:
raise ValueError("ANALYTICDB_PASSWORD should not be None")
if dify_config.ANALYTICDB_NAMESPACE is None:
raise ValueError("ANALYTICDB_NAMESPACE should not be None")
if dify_config.ANALYTICDB_NAMESPACE_PASSWORD is None:
raise ValueError("ANALYTICDB_NAMESPACE_PASSWORD should not be None")
return AnalyticdbVector(
collection_name,
AnalyticdbConfig(
if dify_config.ANALYTICDB_HOST is None:
# implemented through OpenAPI
apiConfig = AnalyticdbVectorOpenAPIConfig(
access_key_id=dify_config.ANALYTICDB_KEY_ID,
access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
region_id=dify_config.ANALYTICDB_REGION_ID,
@@ -341,5 +76,22 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
account_password=dify_config.ANALYTICDB_PASSWORD,
namespace=dify_config.ANALYTICDB_NAMESPACE,
namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
),
)
sqlConfig = None
else:
# implemented through sql
sqlConfig = AnalyticdbVectorBySqlConfig(
host=dify_config.ANALYTICDB_HOST,
port=dify_config.ANALYTICDB_PORT,
account=dify_config.ANALYTICDB_ACCOUNT,
account_password=dify_config.ANALYTICDB_PASSWORD,
min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
namespace=dify_config.ANALYTICDB_NAMESPACE,
)
apiConfig = None
return AnalyticdbVector(
collection_name,
apiConfig,
sqlConfig,
)

View File

@@ -0,0 +1,309 @@
import json
from typing import Any
from pydantic import BaseModel, model_validator
_import_err_msg = (
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
class AnalyticdbVectorOpenAPIConfig(BaseModel):
access_key_id: str
access_key_secret: str
region_id: str
instance_id: str
account: str
account_password: str
namespace: str = "dify"
namespace_password: str = (None,)
metrics: str = "cosine"
read_timeout: int = 60000
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["access_key_id"]:
raise ValueError("config ANALYTICDB_KEY_ID is required")
if not values["access_key_secret"]:
raise ValueError("config ANALYTICDB_KEY_SECRET is required")
if not values["region_id"]:
raise ValueError("config ANALYTICDB_REGION_ID is required")
if not values["instance_id"]:
raise ValueError("config ANALYTICDB_INSTANCE_ID is required")
if not values["account"]:
raise ValueError("config ANALYTICDB_ACCOUNT is required")
if not values["account_password"]:
raise ValueError("config ANALYTICDB_PASSWORD is required")
if not values["namespace_password"]:
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
return values
def to_analyticdb_client_params(self):
return {
"access_key_id": self.access_key_id,
"access_key_secret": self.access_key_secret,
"region_id": self.region_id,
"read_timeout": self.read_timeout,
}
class AnalyticdbVectorOpenAPI:
def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
try:
from alibabacloud_gpdb20160503.client import Client
from alibabacloud_tea_openapi import models as open_api_models
except:
raise ImportError(_import_err_msg)
self._collection_name = collection_name.lower()
self.config = config
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
self._client = Client(self._client_config)
self._initialize()
def _initialize(self) -> None:
cache_key = f"vector_initialize_{self.config.instance_id}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
database_exist_cache_key = f"vector_initialize_{self.config.instance_id}"
if redis_client.get(database_exist_cache_key):
return
self._initialize_vector_database()
self._create_namespace_if_not_exists()
redis_client.set(database_exist_cache_key, 1, ex=3600)
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.init_vector_database(request)
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.describe_namespace(request)
except TeaException as e:
if e.statusCode == 404:
request = gpdb_20160503_models.CreateNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
)
self._client.create_namespace(request)
else:
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
try:
request = gpdb_20160503_models.DescribeCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
)
self._client.describe_collection(request)
except TeaException as e:
if e.statusCode == 404:
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
full_text_retrieval_fields = "page_content"
request = gpdb_20160503_models.CreateCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
collection=self._collection_name,
dimension=embedding_dimension,
metrics=self.config.metrics,
metadata=metadata,
full_text_retrieval_fields=full_text_retrieval_fields,
)
self._client.create_collection(request)
else:
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
)
)
request = gpdb_20160503_models.UpsertCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
rows=rows,
)
self._client.upsert_collection_data(request)
def text_exists(self, id: str) -> bool:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'",
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
ids_str = f"({ids_str})"
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
def delete_by_metadata_field(self, key: str, value: str) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.values.value,
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None,
content=query,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.values.value,
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
def delete(self) -> None:
try:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
collection=self._collection_name,
dbinstance_id=self.config.instance_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
region_id=self.config.region_id,
)
self._client.delete_collection(request)
except Exception as e:
raise e
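
A minimal sketch of how the OpenAPI config above can be instantiated; every credential value here is a placeholder, not a real key.

# Hypothetical values for illustration only.
config = AnalyticdbVectorOpenAPIConfig(
    access_key_id="ak-xxx",
    access_key_secret="sk-xxx",
    region_id="cn-hangzhou",
    instance_id="gp-xxx",
    account="account",
    account_password="account-password",
    namespace="dify",
    namespace_password="namespace-password",
)
# to_analyticdb_client_params() produces the kwargs handed to the OpenAPI client.
assert config.to_analyticdb_client_params()["region_id"] == "cn-hangzhou"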

View File

@ -0,0 +1,245 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
class AnalyticdbVectorBySqlConfig(BaseModel):
host: str
port: int
account: str
account_password: str
min_connection: int
max_connection: int
namespace: str = "dify"
metrics: str = "cosine"
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config ANALYTICDB_HOST is required")
if not values["port"]:
raise ValueError("config ANALYTICDB_PORT is required")
if not values["account"]:
raise ValueError("config ANALYTICDB_ACCOUNT is required")
if not values["account_password"]:
raise ValueError("config ANALYTICDB_PASSWORD is required")
if not values["min_connection"]:
raise ValueError("config ANALYTICDB_MIN_CONNECTION is required")
if not values["max_connection"]:
raise ValueError("config ANALYTICDB_MAX_CONNECTION is required")
if values["min_connection"] > values["max_connection"]:
raise ValueError("config ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION")
return values
class AnalyticdbVectorBySql:
def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig):
self._collection_name = collection_name.lower()
self.databaseName = "knowledgebase"
self.config = config
self.table_name = f"{self.config.namespace}.{self._collection_name}"
self.pool = None
self._initialize()
if not self.pool:
self.pool = self._create_connection_pool()
def _initialize(self) -> None:
cache_key = f"vector_initialize_{self.config.host}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
database_exist_cache_key = f"vector_initialize_{self.config.host}"
if redis_client.get(database_exist_cache_key):
return
self._initialize_vector_database()
redis_client.set(database_exist_cache_key, 1, ex=3600)
def _create_connection_pool(self):
return psycopg2.pool.SimpleConnectionPool(
self.config.min_connection,
self.config.max_connection,
host=self.config.host,
port=self.config.port,
user=self.config.account,
password=self.config.account_password,
database=self.databaseName,
)
@contextmanager
def _get_cursor(self):
conn = self.pool.getconn()
cur = conn.cursor()
try:
yield cur
finally:
cur.close()
conn.commit()
self.pool.putconn(conn)
def _initialize_vector_database(self) -> None:
conn = psycopg2.connect(
host=self.config.host,
port=self.config.port,
user=self.config.account,
password=self.config.account_password,
database="postgres",
)
conn.autocommit = True
cur = conn.cursor()
try:
cur.execute(f"CREATE DATABASE {self.databaseName}")
except Exception as e:
if "already exists" in str(e):
return
raise e
finally:
cur.close()
conn.close()
self.pool = self._create_connection_pool()
with self._get_cursor() as cur:
try:
cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
except Exception as e:
if "already exists" not in str(e):
raise e
cur.execute(
"CREATE OR REPLACE FUNCTION "
"public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) "
"RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ "
"SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) "
"FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) "
"AS words_only;$function$"
)
cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
with self._get_cursor() as cur:
cur.execute(
f"CREATE TABLE IF NOT EXISTS {self.table_name}("
f"id text PRIMARY KEY,"
f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
f"to_tsvector TSVECTOR"
f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
)
if embedding_dimension is not None:
index_name = f"{self._collection_name}_embedding_idx"
cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
cur.execute(
f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
f"pq_enable=0, external_storage=0)"
)
cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
id_prefix = str(uuid.uuid4()) + "_"
sql = f"""
INSERT INTO {self.table_name}
(id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
"""
for i, doc in enumerate(documents):
values.append(
(
id_prefix + str(i),
doc.metadata.get("doc_id", str(uuid.uuid4())),
embeddings[i],
doc.page_content,
json.dumps(doc.metadata),
doc.page_content,
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_batch(cur, sql, values)
def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,))
return cur.fetchone() is not None
def delete_by_ids(self, ids: list[str]) -> None:
with self._get_cursor() as cur:
try:
cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),))
except Exception as e:
if "does not exist" not in str(e):
raise e
def delete_by_metadata_field(self, key: str, value: str) -> None:
with self._get_cursor() as cur:
try:
cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value))
except Exception as e:
if "does not exist" not in str(e):
raise e
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
with self._get_cursor() as cur:
query_vector_str = json.dumps(query_vector)
query_vector_str = "{" + query_vector_str[1:-1] + "}"
cur.execute(
f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
f"t.page_content as page_content, t.metadata_ AS metadata_ "
f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
(query_vector_str,),
)
documents = []
for record in cur:
id, vector, score, page_content, metadata = record
if score > score_threshold:
metadata["score"] = score
doc = Document(
page_content=page_content,
vector=vector,
metadata=metadata,
)
documents.append(doc)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
with self._get_cursor() as cur:
cur.execute(
f"""SELECT id, vector, page_content, metadata_,
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
FROM {self.table_name}
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
ORDER BY score DESC
LIMIT {top_k}""",
(f"'{query}'", f"'{query}'"),
)
documents = []
for record in cur:
id, vector, page_content, metadata, score = record
metadata["score"] = score
doc = Document(
page_content=page_content,
vector=vector,
metadata=metadata,
)
documents.append(doc)
return documents
def delete(self) -> None:
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View File

@ -11,6 +11,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
@ -43,11 +44,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith("。"):
page_content = page_content[1:].strip()
else:
page_content = page_content
page_content = remove_leading_symbols(document_node.page_content).strip()
if len(page_content) > 0:
document_node.page_content = page_content
split_documents.append(document_node)

View File

@ -18,6 +18,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
@ -53,11 +54,7 @@ class QAIndexProcessor(BaseIndexProcessor):
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith("。"):
page_content = page_content[1:]
else:
page_content = page_content
document_node.page_content = page_content
document_node.page_content = remove_leading_symbols(page_content)
split_documents.append(document_node)
all_documents.extend(split_documents)
for i in range(0, len(all_documents), 10):

View File

@ -36,23 +36,21 @@ class WeightRerankRunner(BaseRerankRunner):
:return:
"""
docs = []
doc_id = []
unique_documents = []
doc_ids = set()
for document in documents:
if document.metadata["doc_id"] not in doc_id:
doc_id.append(document.metadata["doc_id"])
docs.append(document.page_content)
doc_id = document.metadata.get("doc_id")
if doc_id not in doc_ids:
doc_ids.add(doc_id)
unique_documents.append(document)
documents = unique_documents
rerank_documents = []
query_scores = self._calculate_keyword_score(query, documents)
query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)
rerank_documents = []
for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
# format document
score = (
self.weights.vector_setting.vector_weight * query_vector_score
+ self.weights.keyword_setting.keyword_weight * query_score
@ -61,7 +59,8 @@ class WeightRerankRunner(BaseRerankRunner):
continue
document.metadata["score"] = score
rerank_documents.append(document)
rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True)
rerank_documents.sort(key=lambda x: x.metadata["score"], reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents
def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, ClassVar
from duckduckgo_search import DDGS
@ -11,6 +11,17 @@ class DuckDuckGoVideoSearchTool(BuiltinTool):
Tool for performing a video search using DuckDuckGo search engine.
"""
IFRAME_TEMPLATE: ClassVar[str] = """
<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden; \
max-width: 100%; border-radius: 8px;">
<iframe
style="position: absolute; top: 0; left: 0; width: 100%; height: 100%;"
src="{src}"
frameborder="0"
allowfullscreen>
</iframe>
</div>"""
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
query_dict = {
"keywords": tool_parameters.get("query"),
@ -26,6 +37,9 @@ class DuckDuckGoVideoSearchTool(BuiltinTool):
# Remove None values to use API defaults
query_dict = {k: v for k, v in query_dict.items() if v is not None}
# Get proxy URL from parameters
proxy_url = tool_parameters.get("proxy_url", "").strip()
response = DDGS().videos(**query_dict)
# Create HTML result with embedded iframes
@ -36,20 +50,21 @@ class DuckDuckGoVideoSearchTool(BuiltinTool):
title = res.get("title", "")
embed_html = res.get("embed_html", "")
description = res.get("description", "")
content_url = res.get("content", "")
# Modify iframe to be responsive
if embed_html:
# Replace fixed dimensions with responsive wrapper and iframe
embed_html = """
<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden; \
max-width: 100%; border-radius: 8px;">
<iframe
style="position: absolute; top: 0; left: 0; width: 100%; height: 100%;"
src="{src}"
frameborder="0"
allowfullscreen>
</iframe>
</div>""".format(src=res.get("embed_url", ""))
# Handle TED.com videos
if not embed_html and "ted.com/talks" in content_url:
embed_url = content_url.replace("www.ted.com", "embed.ted.com")
if proxy_url:
embed_url = f"{proxy_url}{embed_url}"
embed_html = self.IFRAME_TEMPLATE.format(src=embed_url)
# Original YouTube/other platform handling
elif embed_html:
embed_url = res.get("embed_url", "")
if proxy_url and embed_url:
embed_url = f"{proxy_url}{embed_url}"
embed_html = self.IFRAME_TEMPLATE.format(src=embed_url)
markdown_result += f"{title}\n\n"
markdown_result += f"{embed_html}\n\n"

View File

@ -1,40 +1,43 @@
identity:
name: ddgo_video
author: Assistant
author: Tao Wang
label:
en_US: DuckDuckGo Video Search
zh_Hans: DuckDuckGo 视频搜索
description:
human:
en_US: Perform video searches on DuckDuckGo and get results with embedded videos.
zh_Hans: 在 DuckDuckGo 上进行视频搜索并获取可嵌入视频结果。
llm: Perform video searches on DuckDuckGo and get results with embedded videos.
en_US: Search and embed videos.
zh_Hans: 搜索并嵌入视频
llm: Search videos on DuckDuckGo and embed the results in iframes
parameters:
- name: query
type: string
required: true
label:
en_US: Query String
zh_Hans: 查询语句
type: string
required: true
human_description:
en_US: Search Query
zh_Hans: 搜索查询语句
zh_Hans: 搜索查询语句
llm_description: Key words for searching
form: llm
- name: max_results
label:
en_US: Max Results
zh_Hans: 最大结果数量
type: number
required: true
default: 3
minimum: 1
maximum: 10
label:
en_US: Max Results
zh_Hans: 最大结果数量
human_description:
en_US: The max results (1-10).
zh_Hans: 最大结果数量1-10
en_US: The max results (1-10)
zh_Hans: 最大结果数量1-10
form: form
- name: timelimit
label:
en_US: Result Time Limit
zh_Hans: 结果时间限制
type: select
required: false
options:
@ -54,14 +57,14 @@ parameters:
label:
en_US: Current Year
zh_Hans: 今年
label:
en_US: Result Time Limit
zh_Hans: 结果时间限制
human_description:
en_US: Use when querying results within a specific time range only.
en_US: Query results within a specific time range only
zh_Hans: 只查询一定时间范围内的结果时使用
form: form
- name: duration
label:
en_US: Video Duration
zh_Hans: 视频时长
type: select
required: false
options:
@ -77,10 +80,18 @@ parameters:
label:
en_US: Long (>20 minutes)
zh_Hans: 长视频(>20分钟
label:
en_US: Video Duration
zh_Hans: 视频时长
human_description:
en_US: Filter videos by duration
zh_Hans: 按时长筛选视频
form: form
- name: proxy_url
label:
en_US: Proxy URL
zh_Hans: 视频代理地址
type: string
required: false
default: ""
human_description:
en_US: Proxy URL
zh_Hans: 视频代理地址
form: form

View File

@ -17,7 +17,7 @@ class SendMailTool(BuiltinTool):
invoke tools
"""
sender = self.runtime.credentials.get("email_account", "")
email_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
email_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
password = self.runtime.credentials.get("email_password", "")
smtp_server = self.runtime.credentials.get("smtp_server", "")
if not smtp_server:

View File

@ -18,7 +18,7 @@ class SendMailTool(BuiltinTool):
invoke tools
"""
sender = self.runtime.credentials.get("email_account", "")
email_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
email_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
password = self.runtime.credentials.get("email_password", "")
smtp_server = self.runtime.credentials.get("smtp_server", "")
if not smtp_server:

View File

@ -0,0 +1,25 @@
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class GiteeAIToolEmbedding(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
headers = {
"content-type": "application/json",
"authorization": f"Bearer {self.runtime.credentials['api_key']}",
}
payload = {"inputs": tool_parameters.get("inputs")}
model = tool_parameters.get("model", "bge-m3")
url = f"https://ai.gitee.com/api/serverless/{model}/embeddings"
response = requests.post(url, json=payload, headers=headers)
if response.status_code != 200:
return self.create_text_message(f"Got Error Response:{response.text}")
return [self.create_text_message(response.content.decode("utf-8"))]
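
A sketch of the raw request the tool above issues; the bearer token is a placeholder.

import requests

headers = {
    "content-type": "application/json",
    "authorization": "Bearer <api_key>",  # placeholder credential
}
resp = requests.post(
    "https://ai.gitee.com/api/serverless/bge-m3/embeddings",
    json={"inputs": "hello world"},
    headers=headers,
)
print(resp.status_code, resp.text[:200])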

View File

@ -0,0 +1,37 @@
identity:
name: embedding
author: gitee_ai
label:
en_US: embedding
icon: icon.svg
description:
human:
en_US: Generate word embeddings using Serverless-supported models (compatible with OpenAI)
llm: This tool is used to generate word embeddings from text input.
parameters:
- name: model
type: string
required: true
in: path
description:
en_US: Supported Embedding (compatible with OpenAI) interface models
enum:
- bge-m3
- bge-large-zh-v1.5
- bge-small-zh-v1.5
label:
en_US: Service Model
zh_Hans: 服务模型
default: bge-m3
form: form
- name: inputs
type: string
required: true
label:
en_US: Input Text
zh_Hans: 输入文本
human_description:
en_US: The text input used to generate embeddings.
zh_Hans: 用于生成词向量的输入文本。
llm_description: This text input will be used to generate embeddings.
form: llm

View File

@ -6,7 +6,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class GiteeAITool(BuiltinTool):
class GiteeAIToolText2Image(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:

View File

@ -40,6 +40,9 @@ class JSONParseTool(BuiltinTool):
expr = parse(json_filter)
result = [match.value for match in expr.find(input_data)]
if not result:
return ""
if len(result) == 1:
result = result[0]
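
Assuming the parse above comes from jsonpath_ng (which the call shape suggests), the new behavior in isolation: an empty match list yields an empty string, and a single match is unwrapped from its list.

from jsonpath_ng import parse

data = {"items": [{"id": 1}]}
result = [m.value for m in parse("$.items[*].id").find(data)]
result = result[0] if len(result) == 1 else result
print(result)  # 1, unwrapped from the one-element list

empty = [m.value for m in parse("$.missing[*]").find(data)]
print("" if not empty else empty)  # "", instead of an empty list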

Binary file not shown. (new image, 62 KiB)

View File

@ -0,0 +1,22 @@
from typing import Any
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.rapidapi.tools.google_news import GooglenewsTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class RapidapiProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
GooglenewsTool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id="",
tool_parameters={
"language_region": "en-US",
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,39 @@
identity:
name: rapidapi
author: Steven Sun
label:
en_US: RapidAPI
zh_Hans: RapidAPI
description:
en_US: RapidAPI is the world's largest API marketplace with over 1,000,000 developers and 10,000 APIs.
zh_Hans: RapidAPI是全球最大的API市场拥有超过100万开发人员和10000个API。
icon: rapidapi.png
tags:
- news
credentials_for_provider:
x-rapidapi-host:
type: text-input
required: true
label:
en_US: x-rapidapi-host
zh_Hans: x-rapidapi-host
placeholder:
en_US: Please input your x-rapidapi-host
zh_Hans: 请输入你的 x-rapidapi-host
help:
en_US: Get your x-rapidapi-host from RapidAPI.
zh_Hans: 从 RapidAPI 获取您的 x-rapidapi-host。
url: https://rapidapi.com/
x-rapidapi-key:
type: secret-input
required: true
label:
en_US: x-rapidapi-key
zh_Hans: x-rapidapi-key
placeholder:
en_US: Please input your x-rapidapi-key
zh_Hans: 请输入你的 x-rapidapi-key
help:
en_US: Get your x-rapidapi-key from RapidAPI.
zh_Hans: 从 RapidAPI 获取您的 x-rapidapi-key。
url: https://rapidapi.com/

View File

@ -0,0 +1,33 @@
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolInvokeError, ToolProviderCredentialValidationError
from core.tools.tool.builtin_tool import BuiltinTool
class GooglenewsTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
key = self.runtime.credentials.get("x-rapidapi-key", "")
host = self.runtime.credentials.get("x-rapidapi-host", "")
if not all([key, host]):
raise ToolProviderCredentialValidationError("Please input correct x-rapidapi-key and x-rapidapi-host")
headers = {"x-rapidapi-key": key, "x-rapidapi-host": host}
lr = tool_parameters.get("language_region", "")
url = f"https://{host}/latest?lr={lr}"
response = requests.get(url, headers=headers)
if response.status_code != 200:
raise ToolInvokeError(f"Error {response.status_code}: {response.text}")
return self.create_text_message(response.text)
def validate_credentials(self, parameters: dict[str, Any]) -> None:
parameters["validate"] = True
self._invoke(user_id="", tool_parameters=parameters)
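
A sketch of the underlying request; both the host and the key are placeholder RapidAPI credentials.

import requests

host = "google-news.example.p.rapidapi.com"  # hypothetical x-rapidapi-host
headers = {"x-rapidapi-key": "<key>", "x-rapidapi-host": host}
resp = requests.get(f"https://{host}/latest?lr=en-US", headers=headers)
print(resp.status_code)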

View File

@ -0,0 +1,24 @@
identity:
name: google_news
author: Steven Sun
label:
en_US: GoogleNews
zh_Hans: 谷歌新闻
description:
human:
en_US: Google News is a news aggregator service developed by Google. It presents a continuous, customizable flow of articles organized from thousands of publishers and magazines.
zh_Hans: 谷歌新闻是由谷歌开发的新闻聚合服务。它提供了一个持续的、可定制的文章流,这些文章是从成千上万的出版商和杂志中整理出来的。
llm: A tool to get the latest news from Google News.
parameters:
- name: language_region
type: string
required: true
label:
en_US: Language and Region
zh_Hans: 语言和地区
human_description:
en_US: The language and region determine the language and region of the search results, and its value is assigned according to the "National Language Code Comparison Table", such as en-US, which stands for English (United States); zh-CN, stands for Chinese (Simplified).
zh_Hans: 语言和地区决定了搜索结果的语言和地区,其赋值按照《国家语言代码对照表》形如en-US,代表英语美国zh-CN代表中文简体
llm_description: The language and region determine the language and region of the search results, and its value is assigned according to the "National Language Code Comparison Table", such as en-US, which stands for English (United States); zh-CN, stands for Chinese (Simplified).
default: en-US
form: llm

View File

@ -5,6 +5,7 @@ from urllib.parse import urlencode
import httpx
from core.file.file_manager import download
from core.helper import ssrf_proxy
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
@ -138,6 +139,7 @@ class ApiTool(Tool):
path_params = {}
body = {}
cookies = {}
files = []
# check parameters
for parameter in self.api_bundle.openapi.get("parameters", []):
@ -166,8 +168,12 @@ class ApiTool(Tool):
properties = body_schema.get("properties", {})
for name, property in properties.items():
if name in parameters:
# convert type
body[name] = self._convert_body_property_type(property, parameters[name])
if property.get("format") == "binary":
f = parameters[name]
files.append((name, (f.filename, download(f), f.mime_type)))
else:
# convert type
body[name] = self._convert_body_property_type(property, parameters[name])
elif name in required:
raise ToolParameterValidationError(
f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
@ -182,7 +188,7 @@ class ApiTool(Tool):
for name, value in path_params.items():
url = url.replace(f"{{{name}}}", f"{value}")
# parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored
# parse http body data if needed
if "Content-Type" in headers:
if headers["Content-Type"] == "application/json":
body = json.dumps(body)
@ -198,6 +204,7 @@ class ApiTool(Tool):
headers=headers,
cookies=cookies,
data=body,
files=files,
timeout=API_TOOL_DEFAULT_TIMEOUT,
follow_redirects=True,
)
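
The tuples appended to files above follow httpx's multipart convention of (field_name, (filename, content, mime_type)); a minimal sketch against a hypothetical endpoint:

import httpx

files = [("document", ("report.pdf", b"%PDF-1.4 ...", "application/pdf"))]
# Binary-format properties travel as multipart files; the rest stay in `data`.
resp = httpx.post("https://api.example.com/upload", data={"note": "hi"}, files=files)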

View File

@ -261,7 +261,7 @@ class Tool(BaseModel, ABC):
"""
parameters = self.parameters or []
parameters = parameters.copy()
user_parameters = self.get_runtime_parameters() or []
user_parameters = self.get_runtime_parameters()
user_parameters = user_parameters.copy()
# override parameters

View File

@ -55,7 +55,7 @@ class ToolEngine:
# check if this tool has only one parameter
parameters = [
parameter
for parameter in tool.get_runtime_parameters() or []
for parameter in tool.get_runtime_parameters()
if parameter.form == ToolParameter.ToolParameterForm.LLM
]
if parameters and len(parameters) == 1:

View File

@ -127,7 +127,7 @@ class ToolParameterConfigurationManager(BaseModel):
# get tool parameters
tool_parameters = self.tool_runtime.parameters or []
# get tool runtime parameters
runtime_parameters = self.tool_runtime.get_runtime_parameters() or []
runtime_parameters = self.tool_runtime.get_runtime_parameters()
# override parameters
current_parameters = tool_parameters.copy()
for runtime_parameter in runtime_parameters:

View File

@ -161,6 +161,9 @@ class ApiBasedToolSchemaParser:
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType:
parameter = parameter or {}
typ = None
if parameter.get("format") == "binary":
return ToolParameter.ToolParameterType.FILE
if "type" in parameter:
typ = parameter["type"]
elif "schema" in parameter and "type" in parameter["schema"]:

View File

@ -0,0 +1,16 @@
import re
def remove_leading_symbols(text: str) -> str:
"""
Remove leading punctuation or symbols from the given text.
Args:
text (str): The input text to process.
Returns:
str: The text with leading punctuation or symbols removed.
"""
# Match Unicode ranges for punctuation and symbols
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,\-./:;<=>?@\[\]^_`{|}~]+"
return re.sub(pattern, "", text)
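
A quick demonstration: 。 (U+3002) falls in the \u3000-\u303F block and … (U+2026) in \u2000-\u206F, so both are stripped when leading:

print(remove_leading_symbols("。测试"))   # -> "测试"
print(remove_leading_symbols("……标题"))  # -> "标题"
print(remove_leading_symbols("plain"))    # unchanged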

View File

@ -1,4 +1,6 @@
import mimetypes
from collections.abc import Sequence
from email.message import Message
from typing import Any, Literal, Optional
import httpx
@ -7,14 +9,6 @@ from pydantic import BaseModel, Field, ValidationInfo, field_validator
from configs import dify_config
from core.workflow.nodes.base import BaseNodeData
NON_FILE_CONTENT_TYPES = (
"application/json",
"application/xml",
"text/html",
"text/plain",
"application/x-www-form-urlencoded",
)
class HttpRequestNodeAuthorizationConfig(BaseModel):
type: Literal["basic", "bearer", "custom"]
@ -93,13 +87,53 @@ class Response:
@property
def is_file(self):
content_type = self.content_type
"""
Determine if the response contains a file by checking:
1. Content-Disposition header (RFC 6266)
2. Content characteristics
3. MIME type analysis
"""
content_type = self.content_type.split(";")[0].strip().lower()
content_disposition = self.response.headers.get("content-disposition", "")
return "attachment" in content_disposition or (
not any(non_file in content_type for non_file in NON_FILE_CONTENT_TYPES)
and any(file_type in content_type for file_type in ("application/", "image/", "audio/", "video/"))
)
# Check if it's explicitly marked as an attachment
if content_disposition:
msg = Message()
msg["content-disposition"] = content_disposition
disp_type = msg.get_content_disposition() # Returns 'attachment', 'inline', or None
filename = msg.get_filename() # Returns filename if present, None otherwise
if disp_type == "attachment" or filename is not None:
return True
# For application types, try to detect if it's a text-based format
if content_type.startswith("application/"):
# Common text-based application types
if any(
text_type in content_type
for text_type in ("json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql")
):
return False
# Try to detect if content is text-based by sampling first few bytes
try:
# Sample first 1024 bytes for text detection
content_sample = self.response.content[:1024]
content_sample.decode("utf-8")
# If we can decode as UTF-8 and find common text patterns, likely not a file
text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ")
if any(marker in content_sample for marker in text_markers):
return False
except UnicodeDecodeError:
# If we can't decode as UTF-8, likely a binary file
return True
# For other types, use MIME type analysis
main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or ""))
if main_type:
return main_type.split("/")[0] in ("application", "image", "audio", "video")
# For unknown types, check if it's a media type
return any(media_type in content_type for media_type in ("image/", "audio/", "video/"))
@property
def content_type(self) -> str:
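
The Content-Disposition branch above leans on the stdlib email parser rather than manual string matching; in isolation:

from email.message import Message

msg = Message()
msg["content-disposition"] = 'attachment; filename="report.pdf"'
print(msg.get_content_disposition())  # attachment
print(msg.get_filename())             # report.pdf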

View File

@ -1,11 +1,12 @@
import redis
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.sentinel import Sentinel
from configs import dify_config
class RedisClientWrapper(redis.Redis):
class RedisClientWrapper:
"""
A wrapper class for the Redis client that addresses the issue where the global
`redis_client` variable cannot be updated when a new Redis instance is returned
@ -71,6 +72,12 @@ def init_app(app):
)
master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
redis_client.initialize(master)
elif dify_config.REDIS_USE_CLUSTERS:
nodes = [
ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
for node in dify_config.REDIS_CLUSTERS.split(",")
]
redis_client.initialize(RedisCluster(startup_nodes=nodes, password=dify_config.REDIS_CLUSTERS_PASSWORD))
else:
redis_params.update(
{
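
The REDIS_CLUSTERS setting is a comma-separated host:port list; a sketch of the parsing done above, with hypothetical addresses:

from redis.cluster import ClusterNode

REDIS_CLUSTERS = "10.0.0.1:6379,10.0.0.2:6379"  # hypothetical node list
nodes = [
    ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
    for node in REDIS_CLUSTERS.split(",")
]
print([(n.host, n.port) for n in nodes])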

View File

@ -166,9 +166,9 @@ def _build_from_remote_url(
def _get_remote_file_info(url: str):
mime_type = mimetypes.guess_type(url)[0] or ""
file_size = -1
filename = url.split("/")[-1].split("?")[0] or "unknown_file"
mime_type = mimetypes.guess_type(filename)[0] or ""
resp = ssrf_proxy.head(url, follow_redirects=True)
if resp.status_code == httpx.codes.OK:
@ -233,10 +233,10 @@ def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool:
if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM:
return False
if config.allowed_extensions and file.extension not in config.allowed_extensions:
if config.allowed_file_extensions and file.extension not in config.allowed_file_extensions:
return False
if config.allowed_upload_methods and file.transfer_method not in config.allowed_upload_methods:
if config.allowed_file_upload_methods and file.transfer_method not in config.allowed_file_upload_methods:
return False
if file.type == FileType.IMAGE and config.image_config:
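
Guessing the MIME type from the stripped filename rather than the full URL avoids the query string defeating the extension lookup; a sketch with a hypothetical signed URL:

import mimetypes

url = "https://example.com/files/report.pdf?Expires=123&Signature=abc"
filename = url.split("/")[-1].split("?")[0] or "unknown_file"
print(filename, mimetypes.guess_type(filename)[0])  # report.pdf application/pdf
print(mimetypes.guess_type(url)[0])                 # usually None; the query string breaks the guess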

View File

@ -169,9 +169,9 @@ class Workflow(db.Model):
)
features["file_upload"]["enabled"] = image_enabled
features["file_upload"]["number_limits"] = image_number_limits
features["file_upload"]["allowed_upload_methods"] = image_transfer_methods
features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods
features["file_upload"]["allowed_file_types"] = ["image"]
features["file_upload"]["allowed_extensions"] = []
features["file_upload"]["allowed_file_extensions"] = []
del features["file_upload"]["image"]
self._features = json.dumps(features)
return self._features

View File

@ -341,7 +341,7 @@ class AppService:
if not app_model_config:
return meta
agent_config = app_model_config.agent_mode_dict or {}
agent_config = app_model_config.agent_mode_dict
# get all tools
tools = agent_config.get("tools", [])

View File

@ -242,7 +242,7 @@ class ToolTransformService:
# get tool parameters
parameters = tool.parameters or []
# get tool runtime parameters
runtime_parameters = tool.get_runtime_parameters() or []
runtime_parameters = tool.get_runtime_parameters()
# override parameters
current_parameters = parameters.copy()
for runtime_parameter in runtime_parameters:

View File

@ -51,8 +51,8 @@ class WebsiteService:
excludes = options.get("excludes").split(",") if options.get("excludes") else []
params = {
"crawlerOptions": {
"includes": includes or [],
"excludes": excludes or [],
"includes": includes,
"excludes": excludes,
"generateImgAltText": True,
"limit": options.get("limit", 1),
"returnOnlyUrls": False,

View File

@ -78,6 +78,7 @@ def clean_dataset_task(
"Delete image_files failed when storage deleted, \
image_upload_file_id: {}".format(upload_file_id)
)
db.session.delete(image_file)
db.session.delete(segment)
db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()

View File

@ -51,6 +51,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
"Delete image_files failed when storage deleted, \
image_upload_file_id: {}".format(upload_file_id)
)
db.session.delete(image_file)
db.session.delete(segment)
db.session.commit()

View File

@ -31,7 +31,7 @@ def test_invoke_model(setup_google_mock):
model = GoogleLargeLanguageModel()
response = model.invoke(
model="gemini-pro",
model="gemini-1.5-pro",
credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
@ -48,7 +48,7 @@ def test_invoke_model(setup_google_mock):
]
),
],
model_parameters={"temperature": 0.5, "top_p": 1.0, "max_tokens_to_sample": 2048},
model_parameters={"temperature": 0.5, "top_p": 1.0, "max_output_tokens": 2048},
stop=["How"],
stream=False,
user="abc-123",
@ -63,7 +63,7 @@ def test_invoke_stream_model(setup_google_mock):
model = GoogleLargeLanguageModel()
response = model.invoke(
model="gemini-pro",
model="gemini-1.5-pro",
credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
@ -80,7 +80,7 @@ def test_invoke_stream_model(setup_google_mock):
]
),
],
model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens_to_sample": 2048},
model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens": 2048},
stream=True,
user="abc-123",
)
@ -99,7 +99,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
model = GoogleLargeLanguageModel()
result = model.invoke(
model="gemini-pro-vision",
model="gemini-1.5-pro",
credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
@ -128,7 +128,7 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
model = GoogleLargeLanguageModel()
result = model.invoke(
model="gemini-pro-vision",
model="gemini-1.5-pro",
credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
prompt_messages=[
SystemPromptMessage(content="You are a helpful AI assistant."),
@ -164,7 +164,7 @@ def test_get_num_tokens():
model = GoogleLargeLanguageModel()
num_tokens = model.get_num_tokens(
model="gemini-pro",
model="gemini-1.5-pro",
credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
prompt_messages=[
SystemPromptMessage(

View File

@ -1,27 +1,43 @@
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbConfig, AnalyticdbVector
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis
class AnalyticdbVectorTest(AbstractVectorTest):
def __init__(self):
def __init__(self, config_type: str):
super().__init__()
# Analyticdb requires collection_name length less than 60.
# It's fine for normal usage.
self.collection_name = self.collection_name.replace("_test", "")
self.vector = AnalyticdbVector(
collection_name=self.collection_name,
config=AnalyticdbConfig(
access_key_id="test_key_id",
access_key_secret="test_key_secret",
region_id="test_region",
instance_id="test_id",
account="test_account",
account_password="test_passwd",
namespace="difytest_namespace",
collection="difytest_collection",
namespace_password="test_passwd",
),
)
if config_type == "sql":
self.vector = AnalyticdbVector(
collection_name=self.collection_name,
sql_config=AnalyticdbVectorBySqlConfig(
host="test_host",
port=5432,
account="test_account",
account_password="test_passwd",
namespace="difytest_namespace",
),
api_config=None,
)
else:
self.vector = AnalyticdbVector(
collection_name=self.collection_name,
sql_config=None,
api_config=AnalyticdbVectorOpenAPIConfig(
access_key_id="test_key_id",
access_key_secret="test_key_secret",
region_id="test_region",
instance_id="test_id",
account="test_account",
account_password="test_passwd",
namespace="difytest_namespace",
collection="difytest_collection",
namespace_password="test_passwd",
),
)
def run_all_tests(self):
self.vector.delete()
@ -29,4 +45,5 @@ class AnalyticdbVectorTest(AbstractVectorTest):
def test_analyticdb_vector(setup_mock_redis):
AnalyticdbVectorTest().run_all_tests()
AnalyticdbVectorTest("api").run_all_tests()
AnalyticdbVectorTest("sql").run_all_tests()

View File

@ -27,8 +27,8 @@ NEW_VERSION_WORKFLOW_FEATURES = {
"file_upload": {
"enabled": True,
"allowed_file_types": ["image"],
"allowed_extensions": [],
"allowed_upload_methods": ["remote_url", "local_file"],
"allowed_file_extensions": [],
"allowed_file_upload_methods": ["remote_url", "local_file"],
"number_limits": 6,
},
"opening_statement": "",

View File

@ -1,4 +1,7 @@
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
import json
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType, FileUploadConfig
from models.workflow import Workflow
def test_file_loads_and_dumps():
@ -38,3 +41,40 @@ def test_file_to_dict():
file_dict = file.to_dict()
assert "_extra_config" not in file_dict
assert "url" in file_dict
def test_workflow_features_with_image():
# Create a feature dict that mimics the old structure with image config
features = {
"file_upload": {
"image": {"enabled": True, "number_limits": 5, "transfer_methods": ["remote_url", "local_file"]}
}
}
# Create a workflow instance with the features
workflow = Workflow(
tenant_id="tenant-1",
app_id="app-1",
type="chat",
version="1.0",
graph="{}",
features=json.dumps(features),
created_by="user-1",
environment_variables=[],
conversation_variables=[],
)
# Get the converted features through the property
converted_features = json.loads(workflow.features)
# Create FileUploadConfig from the converted features
file_upload_config = FileUploadConfig.model_validate(converted_features["file_upload"])
# Validate the config
assert file_upload_config.number_limits == 5
assert list(file_upload_config.allowed_file_types) == [FileType.IMAGE]
assert list(file_upload_config.allowed_file_upload_methods) == [
FileTransferMethod.REMOTE_URL,
FileTransferMethod.LOCAL_FILE,
]
assert list(file_upload_config.allowed_file_extensions) == []

View File

@ -1,10 +1,12 @@
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import pytest
import redis
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.model_manager import LBModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_redis import redis_client
@pytest.fixture
@ -38,6 +40,9 @@ def lb_model_manager():
def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
# initialize redis client
redis_client.initialize(redis.Redis())
assert len(lb_model_manager._load_balancing_configs) == 3
config1 = lb_model_manager._load_balancing_configs[0]
@ -55,12 +60,13 @@ def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
start_index += 1
return start_index
mocker.patch("redis.Redis.incr", side_effect=incr)
mocker.patch("redis.Redis.set", return_value=None)
mocker.patch("redis.Redis.expire", return_value=None)
with (
patch.object(redis_client, "incr", side_effect=incr),
patch.object(redis_client, "set", return_value=None),
patch.object(redis_client, "expire", return_value=None),
):
config = lb_model_manager.fetch_next()
assert config == config2
config = lb_model_manager.fetch_next()
assert config == config2
config = lb_model_manager.fetch_next()
assert config == config3
config = lb_model_manager.fetch_next()
assert config == config3

View File

@ -0,0 +1,140 @@
from unittest.mock import Mock, PropertyMock, patch
import httpx
import pytest
from core.workflow.nodes.http_request.entities import Response
@pytest.fixture
def mock_response():
response = Mock(spec=httpx.Response)
response.headers = {}
return response
def test_is_file_with_attachment_disposition(mock_response):
"""Test is_file when content-disposition header contains 'attachment'"""
mock_response.headers = {"content-disposition": "attachment; filename=test.pdf", "content-type": "application/pdf"}
response = Response(mock_response)
assert response.is_file
def test_is_file_with_filename_disposition(mock_response):
"""Test is_file when content-disposition header contains filename parameter"""
mock_response.headers = {"content-disposition": "inline; filename=test.pdf", "content-type": "application/pdf"}
response = Response(mock_response)
assert response.is_file
@pytest.mark.parametrize("content_type", ["application/pdf", "image/jpeg", "audio/mp3", "video/mp4"])
def test_is_file_with_file_content_types(mock_response, content_type):
"""Test is_file with various file content types"""
mock_response.headers = {"content-type": content_type}
# Mock binary content
type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
response = Response(mock_response)
assert response.is_file, f"Content type {content_type} should be identified as a file"
@pytest.mark.parametrize(
"content_type",
[
"application/json",
"application/xml",
"application/javascript",
"application/x-www-form-urlencoded",
"application/yaml",
"application/graphql",
],
)
def test_text_based_application_types(mock_response, content_type):
"""Test common text-based application types are not identified as files"""
mock_response.headers = {"content-type": content_type}
response = Response(mock_response)
assert not response.is_file, f"Content type {content_type} should not be identified as a file"
@pytest.mark.parametrize(
("content", "content_type"),
[
(b'{"key": "value"}', "application/octet-stream"),
(b"[1, 2, 3]", "application/unknown"),
(b"function test() {}", "application/x-unknown"),
(b"<root>test</root>", "application/binary"),
(b"var x = 1;", "application/data"),
],
)
def test_content_based_detection(mock_response, content, content_type):
"""Test content-based detection for text-like content"""
mock_response.headers = {"content-type": content_type}
type(mock_response).content = PropertyMock(return_value=content)
response = Response(mock_response)
assert not response.is_file, f"Content {content} with type {content_type} should not be identified as a file"
@pytest.mark.parametrize(
("content", "content_type"),
[
(bytes([0x00, 0xFF] * 512), "application/octet-stream"),
(bytes([0x89, 0x50, 0x4E, 0x47]), "application/unknown"), # PNG magic numbers
(bytes([0xFF, 0xD8, 0xFF]), "application/binary"), # JPEG magic numbers
],
)
def test_binary_content_detection(mock_response, content, content_type):
"""Test content-based detection for binary content"""
mock_response.headers = {"content-type": content_type}
type(mock_response).content = PropertyMock(return_value=content)
response = Response(mock_response)
assert response.is_file, f"Binary content with type {content_type} should be identified as a file"
@pytest.mark.parametrize(
("content_type", "expected_main_type"),
[
("x-world/x-vrml", "model"), # VRML 3D model
("font/ttf", "application"), # TrueType font
("text/csv", "text"), # CSV text file
("unknown/xyz", None), # Unknown type
],
)
def test_mimetype_based_detection(mock_response, content_type, expected_main_type):
"""Test detection using mimetypes.guess_type for non-application content types"""
mock_response.headers = {"content-type": content_type}
type(mock_response).content = PropertyMock(return_value=bytes([0x00])) # Dummy content
with patch("core.workflow.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type:
# Mock the return value based on expected_main_type
if expected_main_type:
mock_guess_type.return_value = (f"{expected_main_type}/subtype", None)
else:
mock_guess_type.return_value = (None, None)
response = Response(mock_response)
# Check if the result matches our expectation
if expected_main_type in ("application", "image", "audio", "video"):
assert response.is_file, f"Content type {content_type} should be identified as a file"
else:
assert not response.is_file, f"Content type {content_type} should not be identified as a file"
# Verify that guess_type was called
mock_guess_type.assert_called_once()
def test_is_file_with_inline_disposition(mock_response):
"""Test is_file when content-disposition is 'inline'"""
mock_response.headers = {"content-disposition": "inline", "content-type": "application/pdf"}
# Mock binary content
type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
response = Response(mock_response)
assert response.is_file
def test_is_file_with_no_content_disposition(mock_response):
"""Test is_file when no content-disposition header is present"""
mock_response.headers = {"content-type": "application/pdf"}
# Mock binary content
type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
response = Response(mock_response)
assert response.is_file

View File

@ -0,0 +1,20 @@
from textwrap import dedent
import pytest
from core.tools.utils.text_processing_utils import remove_leading_symbols
@pytest.mark.parametrize(
("input_text", "expected_output"),
[
("...Hello, World!", "Hello, World!"),
("。测试中文标点", "测试中文标点"),
("!@#Test symbols", "Test symbols"),
("Hello, World!", "Hello, World!"),
("", ""),
(" ", " "),
],
)
def test_remove_leading_symbols(input_text, expected_output):
assert remove_leading_symbols(input_text) == expected_output

View File

@ -75,7 +75,8 @@ SECRET_KEY=sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U
# Password for admin user initialization.
# If left unset, admin user will not be prompted for a password
# when creating the initial admin account.
# when creating the initial admin account.
# The length of the password cannot exceed 30 characters.
INIT_PASSWORD=
# Deployment environment.
@ -239,6 +240,12 @@ REDIS_SENTINEL_USERNAME=
REDIS_SENTINEL_PASSWORD=
REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
# List of Redis Cluster nodes. If Cluster mode is enabled, provide at least one Cluster IP and port.
# Format: `<Cluster1_ip>:<Cluster1_port>,<Cluster2_ip>:<Cluster2_port>,<Cluster3_ip>:<Cluster3_port>`
REDIS_USE_CLUSTERS=false
REDIS_CLUSTERS=
REDIS_CLUSTERS_PASSWORD=
# ------------------------------
# Celery Configuration
# ------------------------------
@ -450,6 +457,10 @@ ANALYTICDB_ACCOUNT=testaccount
ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
ANALYTICDB_HOST=gp-test.aliyuncs.com
ANALYTICDB_PORT=5432
ANALYTICDB_MIN_CONNECTION=1
ANALYTICDB_MAX_CONNECTION=5
# TiDB vector configurations, only available when VECTOR_STORE is `tidb`
TIDB_VECTOR_HOST=tidb
@ -558,7 +569,7 @@ UPLOAD_FILE_SIZE_LIMIT=15
# The maximum number of files that can be uploaded at a time, default 5.
UPLOAD_FILE_BATCH_LIMIT=5
# ETl type, support: `dify`, `Unstructured`
# ETL type, support: `dify`, `Unstructured`
# `dify` Dify's proprietary file extraction scheme
# `Unstructured` Unstructured.io file extraction scheme
ETL_TYPE=dify

View File

@ -36,7 +36,7 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T
- Navigate to the `docker` directory.
- Ensure the `middleware.env` file is created by running `cp middleware.env.example middleware.env` (refer to the `middleware.env.example` file).
2. **Running Middleware Services**:
- Execute `docker-compose -f docker-compose.middleware.yaml up -d` to start the middleware services.
- Execute `docker-compose -f docker-compose.middleware.yaml --env-file middleware.env up -d` to start the middleware services.
### Migration for Existing Users

View File

@ -29,11 +29,13 @@ services:
redis:
image: redis:6-alpine
restart: always
environment:
REDISCLI_AUTH: ${REDIS_PASSWORD:-difyai123456}
volumes:
# Mount the redis data directory to the container.
- ${REDIS_HOST_VOLUME:-./volumes/redis/data}:/data
# Set the redis password when startup redis server.
command: redis-server --requirepass difyai123456
command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456}
ports:
- "${EXPOSE_REDIS_PORT:-6379}:6379"
healthcheck:

View File

@ -55,6 +55,9 @@ x-shared-env: &shared-api-worker-env
REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-}
REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-}
REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1}
REDIS_CLUSTERS: ${REDIS_CLUSTERS:-}
REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false}
REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-}
ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60}
CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1}
BROKER_USE_SSL: ${BROKER_USE_SSL:-false}
@ -185,6 +188,10 @@ x-shared-env: &shared-api-worker-env
ANALYTICDB_PASSWORD: ${ANALYTICDB_PASSWORD:-}
ANALYTICDB_NAMESPACE: ${ANALYTICDB_NAMESPACE:-dify}
ANALYTICDB_NAMESPACE_PASSWORD: ${ANALYTICDB_NAMESPACE_PASSWORD:-}
ANALYTICDB_HOST: ${ANALYTICDB_HOST:-}
ANALYTICDB_PORT: ${ANALYTICDB_PORT:-5432}
ANALYTICDB_MIN_CONNECTION: ${ANALYTICDB_MIN_CONNECTION:-1}
ANALYTICDB_MAX_CONNECTION: ${ANALYTICDB_MAX_CONNECTION:-5}
OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch}
OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200}
OPENSEARCH_USER: ${OPENSEARCH_USER:-admin}
@ -359,6 +366,8 @@ services:
redis:
image: redis:6-alpine
restart: always
environment:
REDISCLI_AUTH: ${REDIS_PASSWORD:-difyai123456}
volumes:
# Mount the redis data directory to the container.
- ./volumes/redis/data:/data


@@ -42,11 +42,13 @@ POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB
# -----------------------------
# Environment Variables for redis Service
REDIS_HOST_VOLUME=./volumes/redis/data
# -----------------------------
REDIS_HOST_VOLUME=./volumes/redis/data
REDIS_PASSWORD=difyai123456
# ------------------------------
# Environment Variables for sandbox Service
# ------------------------------
SANDBOX_API_KEY=dify-sandbox
SANDBOX_GIN_MODE=release
SANDBOX_WORKER_TIMEOUT=15
@@ -54,7 +56,6 @@ SANDBOX_ENABLE_NETWORK=true
SANDBOX_HTTP_PROXY=http://ssrf_proxy:3128
SANDBOX_HTTPS_PROXY=http://ssrf_proxy:3128
SANDBOX_PORT=8194
# ------------------------------
# ------------------------------
# Environment Variables for ssrf_proxy Service

web/.gitignore

@@ -50,4 +50,8 @@ package-lock.json
pnpm-lock.yaml
.favorites.json
*storybook.log
*storybook.log
# mise
mise.toml

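Note: `mise.toml` is, presumably, the per-project config of the mise dev-tool version manager; such files are machine-local, hence ignored rather than committed.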

@@ -5,7 +5,6 @@ import { useEffect, useMemo, useRef, useState } from 'react'
import { useRouter } from 'next/navigation'
import { useTranslation } from 'react-i18next'
import { useDebounceFn } from 'ahooks'
import useSWR from 'swr'
// Components
import ExternalAPIPanel from '../../components/datasets/external-api/external-api-panel'
@@ -28,6 +27,8 @@ import { useTabSearchParams } from '@/hooks/use-tab-searchparams'
import { useStore as useTagStore } from '@/app/components/base/tag-management/store'
import { useAppContext } from '@/context/app-context'
import { useExternalApiPanel } from '@/context/external-api-panel-context'
// eslint-disable-next-line import/order
import { useQuery } from '@tanstack/react-query'
const Container = () => {
const { t } = useTranslation()
@@ -47,7 +48,13 @@ const Container = () => {
defaultTab: 'dataset',
})
const containerRef = useRef<HTMLDivElement>(null)
const { data } = useSWR(activeTab === 'dataset' ? null : '/datasets/api-base-info', fetchDatasetApiBaseUrl)
const { data } = useQuery(
{
queryKey: ['datasetApiBaseInfo'],
queryFn: () => fetchDatasetApiBaseUrl('/datasets/api-base-info'),
enabled: activeTab !== 'dataset',
},
)
const [keywords, setKeywords] = useState('')
const [searchKeywords, setSearchKeywords] = useState('')

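The conditional SWR key (a null key skips the request) maps directly onto React Query's `enabled` option. A self-contained sketch of the same pattern — the fetcher below is a placeholder standing in for Dify's `fetchDatasetApiBaseUrl`, not its actual service code:

```ts
import { useQuery } from '@tanstack/react-query'

// Placeholder fetcher; the real fetchDatasetApiBaseUrl lives in Dify's service layer.
const fetchJson = (url: string) => fetch(url).then(res => res.json())

function useDatasetApiBaseInfo(activeTab: string) {
  return useQuery({
    queryKey: ['datasetApiBaseInfo'],
    queryFn: () => fetchJson('/datasets/api-base-info'),
    // SWR skipped fetching by passing a null key; React Query expresses
    // the same condition declaratively:
    enabled: activeTab !== 'dataset',
  })
}
```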

@@ -329,7 +329,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Col sticky>
<CodeGroup
title="Request"
tag="POST"
tag="GET"
label="/datasets"
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets?page=1&limit=20' \\\n--header 'Authorization: Bearer {api_key}'`}
>

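Note: the `tag` prop on the request sample now matches the `GET` verb already used in `targetCode`. The identical hunk below applies the same one-line fix to a second, presumably localized, copy of the dataset API page.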

@@ -329,7 +329,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Col sticky>
<CodeGroup
title="Request"
tag="POST"
tag="GET"
label="/datasets"
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets?page=1&limit=20' \\\n--header 'Authorization: Bearer {api_key}'`}
>


@@ -8,24 +8,27 @@ import Header from '@/app/components/header'
import { EventEmitterContextProvider } from '@/context/event-emitter'
import { ProviderContextProvider } from '@/context/provider-context'
import { ModalContextProvider } from '@/context/modal-context'
import { TanstackQueryIniter } from '@/context/query-client'
const Layout = ({ children }: { children: ReactNode }) => {
return (
<>
<GA gaType={GaType.admin} />
<SwrInitor>
<AppContextProvider>
<EventEmitterContextProvider>
<ProviderContextProvider>
<ModalContextProvider>
<HeaderWrapper>
<Header />
</HeaderWrapper>
{children}
</ModalContextProvider>
</ProviderContextProvider>
</EventEmitterContextProvider>
</AppContextProvider>
<TanstackQueryIniter>
<AppContextProvider>
<EventEmitterContextProvider>
<ProviderContextProvider>
<ModalContextProvider>
<HeaderWrapper>
<Header />
</HeaderWrapper>
{children}
</ModalContextProvider>
</ProviderContextProvider>
</EventEmitterContextProvider>
</AppContextProvider>
</TanstackQueryIniter>
</SwrInitor>
</>
)

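The SWR initializer stays in place while a React Query provider is nested inside it, so both data layers coexist during the migration. A minimal sketch of what an initializer like this usually contains — the real one lives in `@/context/query-client`; this body is an assumption, not its actual source:

```tsx
'use client'
import { useState, type ReactNode } from 'react'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'

const TanstackQueryIniter = ({ children }: { children: ReactNode }) => {
  // Create the client once per mount so the query cache survives re-renders.
  const [client] = useState(() => new QueryClient())
  return <QueryClientProvider client={client}>{children}</QueryClientProvider>
}

export default TanstackQueryIniter
```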

@@ -676,6 +676,10 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
const [showDrawer, setShowDrawer] = useState<boolean>(false) // Whether to display the chat details drawer
const [currentConversation, setCurrentConversation] = useState<ChatConversationGeneralDetail | CompletionConversationGeneralDetail | undefined>() // Currently selected conversation
const isChatMode = appDetail.mode !== 'completion' // Whether the app is a chat app
const { setShowPromptLogModal, setShowAgentLogModal } = useAppStore(useShallow(state => ({
setShowPromptLogModal: state.setShowPromptLogModal,
setShowAgentLogModal: state.setShowAgentLogModal,
})))
// Annotated data needs to be highlighted
const renderTdValue = (value: string | number | null, isEmptyStyle: boolean, isHighlight = false, annotation?: LogAnnotation) => {
@@ -699,6 +703,8 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
onRefresh()
setShowDrawer(false)
setCurrentConversation(undefined)
setShowPromptLogModal(false)
setShowAgentLogModal(false)
}
if (!logs)

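Note: closing the drawer now also resets the prompt-log and agent-log modals, so a modal opened from a previous conversation can no longer linger over a drawer whose conversation has been cleared. The `useShallow` selector pulls just the two setters out of the app store, keeping the list from re-rendering on unrelated store changes.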

@@ -1804,8 +1804,85 @@ exports[`build chat item tree and get thread messages should get thread messages
]
`;
exports[`build chat item tree and get thread messages should work with partial messages 1`] = `
exports[`build chat item tree and get thread messages should work with partial messages 1 1`] = `
[
{
"children": [
{
"agent_thoughts": [
{
"chain_id": null,
"created_at": 1726105799,
"files": [],
"id": "9730d587-9268-4683-9dd9-91a1cab9510b",
"message_id": "4c5d0841-1206-463e-95d8-71f812877658",
"observation": "",
"position": 1,
"thought": "I'll go with 112. Your turn!",
"tool": "",
"tool_input": "",
"tool_labels": {},
},
],
"children": [],
"content": "I'll go with 112. Your turn!",
"conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80",
"feedbackDisabled": false,
"id": "4c5d0841-1206-463e-95d8-71f812877658",
"input": {
"inputs": {},
"query": "99",
},
"isAnswer": true,
"log": [
{
"files": [],
"role": "user",
"text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38",
},
{
"files": [],
"role": "assistant",
"text": "Sure, I'll play! My number is 57. Your turn!",
},
{
"files": [],
"role": "user",
"text": "58",
},
{
"files": [],
"role": "assistant",
"text": "I choose 83. What's your next number?",
},
{
"files": [],
"role": "user",
"text": "99",
},
{
"files": [],
"role": "assistant",
"text": "I'll go with 112. Your turn!",
},
],
"message_files": [],
"more": {
"latency": "1.49",
"time": "09/11/2024 09:50 PM",
"tokens": 86,
},
"parentMessageId": "question-4c5d0841-1206-463e-95d8-71f812877658",
"siblingIndex": 0,
"workflow_run_id": null,
},
],
"content": "99",
"id": "question-4c5d0841-1206-463e-95d8-71f812877658",
"isAnswer": false,
"message_files": [],
"parentMessageId": "73bbad14-d915-499d-87bf-0df14d40779d",
},
{
"children": [
{
@@ -2078,6 +2155,178 @@ exports[`build chat item tree and get thread messages should work with partial m
]
`;
exports[`build chat item tree and get thread messages should work with partial messages 2 1`] = `
[
{
"children": [
{
"children": [],
"content": "237.",
"id": "ebb73fe2-15de-46dd-aab5-75416d8448eb",
"isAnswer": true,
"parentMessageId": "question-ebb73fe2-15de-46dd-aab5-75416d8448eb",
"siblingIndex": 0,
},
],
"content": "123",
"id": "question-ebb73fe2-15de-46dd-aab5-75416d8448eb",
"isAnswer": false,
"parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418",
},
{
"children": [
{
"children": [],
"content": "My number is 256.",
"id": "3553d508-3850-462e-8594-078539f940f9",
"isAnswer": true,
"parentMessageId": "question-3553d508-3850-462e-8594-078539f940f9",
"siblingIndex": 1,
},
],
"content": "123",
"id": "question-3553d508-3850-462e-8594-078539f940f9",
"isAnswer": false,
"parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418",
},
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [],
"content": "My number is 3e (approximately 8.15).",
"id": "9eac3bcc-8d3b-4e56-a12b-44c34cebc719",
"isAnswer": true,
"parentMessageId": "question-9eac3bcc-8d3b-4e56-a12b-44c34cebc719",
"siblingIndex": 0,
},
],
"content": "e",
"id": "question-9eac3bcc-8d3b-4e56-a12b-44c34cebc719",
"isAnswer": false,
"parentMessageId": "5c56a2b3-f057-42a0-9b2c-52a35713cd8c",
},
],
"content": "My number is 2π (approximately 6.28).",
"id": "5c56a2b3-f057-42a0-9b2c-52a35713cd8c",
"isAnswer": true,
"parentMessageId": "question-5c56a2b3-f057-42a0-9b2c-52a35713cd8c",
"siblingIndex": 0,
},
],
"content": "π",
"id": "question-5c56a2b3-f057-42a0-9b2c-52a35713cd8c",
"isAnswer": false,
"parentMessageId": "46a49bb9-0881-459e-8c6a-24d20ae48d2f",
},
],
"content": "My number is 145.",
"id": "46a49bb9-0881-459e-8c6a-24d20ae48d2f",
"isAnswer": true,
"parentMessageId": "question-46a49bb9-0881-459e-8c6a-24d20ae48d2f",
"siblingIndex": 0,
},
],
"content": "78",
"id": "question-46a49bb9-0881-459e-8c6a-24d20ae48d2f",
"isAnswer": false,
"parentMessageId": "3cded945-855a-4a24-aab7-43c7dd54664c",
},
],
"content": "My number is 7.89.",
"id": "3cded945-855a-4a24-aab7-43c7dd54664c",
"isAnswer": true,
"parentMessageId": "question-3cded945-855a-4a24-aab7-43c7dd54664c",
"siblingIndex": 0,
},
],
"content": "3.11",
"id": "question-3cded945-855a-4a24-aab7-43c7dd54664c",
"isAnswer": false,
"parentMessageId": "a956de3d-ef95-4d90-84fe-f7a26ef28cd7",
},
],
"content": "My number is 22.",
"id": "a956de3d-ef95-4d90-84fe-f7a26ef28cd7",
"isAnswer": true,
"parentMessageId": "question-a956de3d-ef95-4d90-84fe-f7a26ef28cd7",
"siblingIndex": 0,
},
],
"content": "-5",
"id": "question-a956de3d-ef95-4d90-84fe-f7a26ef28cd7",
"isAnswer": false,
"parentMessageId": "93bac05d-1470-4ac9-b090-fe21cd7c3d55",
},
],
"content": "My number is 4782.",
"id": "93bac05d-1470-4ac9-b090-fe21cd7c3d55",
"isAnswer": true,
"parentMessageId": "question-93bac05d-1470-4ac9-b090-fe21cd7c3d55",
"siblingIndex": 0,
},
],
"content": "3306",
"id": "question-93bac05d-1470-4ac9-b090-fe21cd7c3d55",
"isAnswer": false,
"parentMessageId": "9e51a13b-7780-4565-98dc-f2d8c3b1758f",
},
],
"content": "My number is 2048.",
"id": "9e51a13b-7780-4565-98dc-f2d8c3b1758f",
"isAnswer": true,
"parentMessageId": "question-9e51a13b-7780-4565-98dc-f2d8c3b1758f",
"siblingIndex": 0,
},
],
"content": "1024",
"id": "question-9e51a13b-7780-4565-98dc-f2d8c3b1758f",
"isAnswer": false,
"parentMessageId": "507f9df9-1f06-4a57-bb38-f00228c42c22",
},
],
"content": "My number is 259.",
"id": "507f9df9-1f06-4a57-bb38-f00228c42c22",
"isAnswer": true,
"parentMessageId": "question-507f9df9-1f06-4a57-bb38-f00228c42c22",
"siblingIndex": 2,
},
],
"content": "123",
"id": "question-507f9df9-1f06-4a57-bb38-f00228c42c22",
"isAnswer": false,
"parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418",
},
]
`;
exports[`build chat item tree and get thread messages should work with real world messages 1`] = `
[
{


@@ -0,0 +1,122 @@
[
{
"id": "question-ebb73fe2-15de-46dd-aab5-75416d8448eb",
"content": "123",
"isAnswer": false,
"parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418"
},
{
"id": "ebb73fe2-15de-46dd-aab5-75416d8448eb",
"content": "237.",
"isAnswer": true,
"parentMessageId": "question-ebb73fe2-15de-46dd-aab5-75416d8448eb"
},
{
"id": "question-3553d508-3850-462e-8594-078539f940f9",
"content": "123",
"isAnswer": false,
"parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418"
},
{
"id": "3553d508-3850-462e-8594-078539f940f9",
"content": "My number is 256.",
"isAnswer": true,
"parentMessageId": "question-3553d508-3850-462e-8594-078539f940f9"
},
{
"id": "question-507f9df9-1f06-4a57-bb38-f00228c42c22",
"content": "123",
"isAnswer": false,
"parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418"
},
{
"id": "507f9df9-1f06-4a57-bb38-f00228c42c22",
"content": "My number is 259.",
"isAnswer": true,
"parentMessageId": "question-507f9df9-1f06-4a57-bb38-f00228c42c22"
},
{
"id": "question-9e51a13b-7780-4565-98dc-f2d8c3b1758f",
"content": "1024",
"isAnswer": false,
"parentMessageId": "507f9df9-1f06-4a57-bb38-f00228c42c22"
},
{
"id": "9e51a13b-7780-4565-98dc-f2d8c3b1758f",
"content": "My number is 2048.",
"isAnswer": true,
"parentMessageId": "question-9e51a13b-7780-4565-98dc-f2d8c3b1758f"
},
{
"id": "question-93bac05d-1470-4ac9-b090-fe21cd7c3d55",
"content": "3306",
"isAnswer": false,
"parentMessageId": "9e51a13b-7780-4565-98dc-f2d8c3b1758f"
},
{
"id": "93bac05d-1470-4ac9-b090-fe21cd7c3d55",
"content": "My number is 4782.",
"isAnswer": true,
"parentMessageId": "question-93bac05d-1470-4ac9-b090-fe21cd7c3d55"
},
{
"id": "question-a956de3d-ef95-4d90-84fe-f7a26ef28cd7",
"content": "-5",
"isAnswer": false,
"parentMessageId": "93bac05d-1470-4ac9-b090-fe21cd7c3d55"
},
{
"id": "a956de3d-ef95-4d90-84fe-f7a26ef28cd7",
"content": "My number is 22.",
"isAnswer": true,
"parentMessageId": "question-a956de3d-ef95-4d90-84fe-f7a26ef28cd7"
},
{
"id": "question-3cded945-855a-4a24-aab7-43c7dd54664c",
"content": "3.11",
"isAnswer": false,
"parentMessageId": "a956de3d-ef95-4d90-84fe-f7a26ef28cd7"
},
{
"id": "3cded945-855a-4a24-aab7-43c7dd54664c",
"content": "My number is 7.89.",
"isAnswer": true,
"parentMessageId": "question-3cded945-855a-4a24-aab7-43c7dd54664c"
},
{
"id": "question-46a49bb9-0881-459e-8c6a-24d20ae48d2f",
"content": "78",
"isAnswer": false,
"parentMessageId": "3cded945-855a-4a24-aab7-43c7dd54664c"
},
{
"id": "46a49bb9-0881-459e-8c6a-24d20ae48d2f",
"content": "My number is 145.",
"isAnswer": true,
"parentMessageId": "question-46a49bb9-0881-459e-8c6a-24d20ae48d2f"
},
{
"id": "question-5c56a2b3-f057-42a0-9b2c-52a35713cd8c",
"content": "π",
"isAnswer": false,
"parentMessageId": "46a49bb9-0881-459e-8c6a-24d20ae48d2f"
},
{
"id": "5c56a2b3-f057-42a0-9b2c-52a35713cd8c",
"content": "My number is 2π (approximately 6.28).",
"isAnswer": true,
"parentMessageId": "question-5c56a2b3-f057-42a0-9b2c-52a35713cd8c"
},
{
"id": "question-9eac3bcc-8d3b-4e56-a12b-44c34cebc719",
"content": "e",
"isAnswer": false,
"parentMessageId": "5c56a2b3-f057-42a0-9b2c-52a35713cd8c"
},
{
"id": "9eac3bcc-8d3b-4e56-a12b-44c34cebc719",
"content": "My number is 3e (approximately 8.15).",
"isAnswer": true,
"parentMessageId": "question-9eac3bcc-8d3b-4e56-a12b-44c34cebc719"
}
]


@@ -7,6 +7,7 @@ import mixedTestMessages from './mixedTestMessages.json'
import multiRootNodesMessages from './multiRootNodesMessages.json'
import multiRootNodesWithLegacyTestMessages from './multiRootNodesWithLegacyTestMessages.json'
import realWorldMessages from './realWorldMessages.json'
import partialMessages from './partialMessages.json'
function visitNode(tree: ChatItemInTree | ChatItemInTree[], path: string): ChatItemInTree {
return get(tree, path)
@@ -256,9 +257,15 @@ describe('build chat item tree and get thread messages', () => {
expect(threadMessages6_2).toMatchSnapshot()
})
const partialMessages = (realWorldMessages as ChatItemInTree[]).slice(-10)
const tree7 = buildChatItemTree(partialMessages)
it('should work with partial messages', () => {
const partialMessages1 = (realWorldMessages as ChatItemInTree[]).slice(-10)
const tree7 = buildChatItemTree(partialMessages1)
it('should work with partial messages 1', () => {
expect(tree7).toMatchSnapshot()
})
const partialMessages2 = (partialMessages as ChatItemInTree[])
const tree8 = buildChatItemTree(partialMessages2)
it('should work with partial messages 2', () => {
expect(tree8).toMatchSnapshot()
})
})

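Note: the new `partial messages 2` case feeds `partialMessages.json` (added above) to `buildChatItemTree`. The oldest questions in that fixture all reference a parent id that never appears in the fixture, so the case exercises exactly the missing-parent root rule introduced in the next hunk.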

@@ -127,19 +127,16 @@ function buildChatItemTree(allMessages: IChatItem[]): ChatItemInTree[] {
lastAppendedLegacyAnswer = answerNode
}
else {
if (!parentMessageId)
if (
!parentMessageId
|| !allMessages.some(item => item.id === parentMessageId) // parent message might not be fetched yet, in this case we will append the question to the root nodes
)
rootNodes.push(questionNode)
else
map[parentMessageId]?.children!.push(questionNode)
}
}
// If no messages have parentMessageId=null (indicating a root node),
// then we likely have a partial chat history. In this case,
// use the first available message as the root node.
if (rootNodes.length === 0 && allMessages.length > 0)
rootNodes.push(map[allMessages[0]!.id]!)
return rootNodes
}

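The root-node rule now covers partial histories directly: a question whose parent was never fetched becomes a root itself, replacing the old fallback that promoted the first message when no roots were found. A reduced sketch of the new rule in isolation (types trimmed; the real `ChatItemInTree` carries many more fields):

```ts
type ItemLike = { id: string; parentMessageId?: string | null }

// A node is a root when it has no parent, or when its parent is not among
// the fetched messages (i.e. the history is partial).
const isRoot = (item: ItemLike, allMessages: ItemLike[]): boolean =>
  !item.parentMessageId
  || !allMessages.some(m => m.id === item.parentMessageId)
```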

@@ -13,10 +13,19 @@ const FileInput = ({
const files = useStore(s => s.files)
const { handleLocalFileUpload } = useFile(fileConfig)
const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {
const file = e.target.files?.[0]
const targetFiles = e.target.files
if (file)
handleLocalFileUpload(file)
if (targetFiles) {
if (fileConfig.number_limits) {
for (let i = 0; i < targetFiles.length; i++) {
if (i + 1 + files.length <= fileConfig.number_limits)
handleLocalFileUpload(targetFiles[i])
}
}
else {
handleLocalFileUpload(targetFiles[0])
}
}
}
const allowedFileTypes = fileConfig.allowed_file_types
@@ -32,6 +41,7 @@
onChange={handleChange}
accept={accept}
disabled={!!(fileConfig.number_limits && files.length >= fileConfig?.number_limits)}
multiple={!!fileConfig.number_limits && fileConfig.number_limits > 1}
/>
)
}

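With `number_limits` set, the handler uploads only as many of the newly selected files as still fit under the cap (`files.length` are already attached); without it, it falls back to the original single-file path. The cap arithmetic in isolation — helper names here are illustrative, not part of the component:

```ts
// How many of the newly selected files may still be uploaded?
const remainingSlots = (attached: number, limit?: number): number =>
  limit ? Math.max(limit - attached, 0) : 1 // no limit configured: legacy single-file path

// e.g. limit 5 with 3 already attached and 4 selected -> upload the first 2
const toUpload = (selected: File[], attached: number, limit?: number): File[] =>
  selected.slice(0, remainingSlots(attached, limit))
```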
Some files were not shown because too many files have changed in this diff.