Mirror of https://github.com/langgenius/dify.git (synced 2026-01-25 14:25:57 +08:00)

Compare commits: 46 commits, 0.11.2 ... test/disab
| SHA1 |
|---|
| 83b6abf4ad |
| ea0ebc020c |
| f358db9f02 |
| 94c9cadbd8 |
| 2ae6460f46 |
| 0067b16d1e |
| ec9f6220c9 |
| af53e2b6b0 |
| b42b333a72 |
| 99b0369f1b |
| d6ea1e2f12 |
| 4d6b45427c |
| 1be8365684 |
| c3d11c8ff6 |
| 8ff65abbc6 |
| bf4b6e5f80 |
| 25fda7adc5 |
| f3af7b5f35 |
| 33cfc56ad0 |
| 464cc26ccf |
| d18754afdd |
| beb7953d38 |
| fbfc811a44 |
| 7e66e5a713 |
| 07b5bbae06 |
| 3087913b74 |
| 904ea05bf6 |
| 6f4885d86d |
| 2dc29cfee3 |
| bd05df5cc5 |
| ee1f14621a |
| 58a9d9eb9a |
| bc1013dacf |
| 9f195df103 |
| 1cc7dc6360 |
| 328965ed7c |
| 133de9a087 |
| 7261384655 |
| 4718071cbb |
| 22be0816aa |
| 49e88322de |
| 14f3d44c37 |
| 0ba17ec116 |
| 79d59c004b |
| 873e9720e9 |
| de6d3e493c |
@ -42,6 +42,11 @@ REDIS_SENTINEL_USERNAME=
REDIS_SENTINEL_PASSWORD=
REDIS_SENTINEL_SOCKET_TIMEOUT=0.1

# redis Cluster configuration.
REDIS_USE_CLUSTERS=false
REDIS_CLUSTERS=
REDIS_CLUSTERS_PASSWORD=

# PostgreSQL database configuration
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
@ -234,6 +239,10 @@ ANALYTICDB_ACCOUNT=testaccount
ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
ANALYTICDB_HOST=gp-test.aliyuncs.com
ANALYTICDB_PORT=5432
ANALYTICDB_MIN_CONNECTION=1
ANALYTICDB_MAX_CONNECTION=5

# OpenSearch configuration
OPENSEARCH_HOST=127.0.0.1
@ -18,12 +18,17 @@
```

2. Copy `.env.example` to `.env`

```cli
cp .env.example .env
```

3. Generate a `SECRET_KEY` in the `.env` file.

bash for Linux
```bash for Linux
sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
```

bash for Mac
```bash for Mac
secret_key=$(openssl rand -base64 42)
sed -i '' "/^SECRET_KEY=/c\\
@ -41,14 +46,6 @@
poetry install
```

In case of contributors missing to update dependencies for `pyproject.toml`, you can perform the following shell instead.

```bash
poetry shell # activate current environment
poetry add $(cat requirements.txt) # install dependencies of production and update pyproject.toml
poetry add $(cat requirements-dev.txt) --group dev # install dependencies of development and update pyproject.toml
```

6. Run migrate

Before the first launch, migrate the database to the latest version.
api/configs/middleware/cache/redis_config.py (vendored), 15 changed lines
@ -68,3 +68,18 @@ class RedisConfig(BaseSettings):
        description="Socket timeout in seconds for Redis Sentinel connections",
        default=0.1,
    )

    REDIS_USE_CLUSTERS: bool = Field(
        description="Enable Redis Clusters mode for high availability",
        default=False,
    )

    REDIS_CLUSTERS: Optional[str] = Field(
        description="Comma-separated list of Redis Clusters nodes (host:port)",
        default=None,
    )

    REDIS_CLUSTERS_PASSWORD: Optional[str] = Field(
        description="Password for Redis Clusters authentication (if required)",
        default=None,
    )
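For orientation, a minimal sketch of how a comma-separated `REDIS_CLUSTERS` value and `REDIS_CLUSTERS_PASSWORD` could be turned into a redis-py cluster client. The parsing helper below is an illustration, not code from this compare.

```python
# Hypothetical helper, assuming redis-py's cluster support is installed.
from redis.cluster import ClusterNode, RedisCluster


def build_cluster_client(clusters: str, password: str | None) -> RedisCluster:
    # REDIS_CLUSTERS is documented above as a comma-separated host:port list.
    nodes = []
    for entry in clusters.split(","):
        host, _, port = entry.strip().partition(":")
        nodes.append(ClusterNode(host, int(port)))
    return RedisCluster(startup_nodes=nodes, password=password)


client = build_cluster_client("10.0.0.1:6379,10.0.0.2:6379", None)  # placeholder hosts
```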
@ -1,6 +1,6 @@
from typing import Optional

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PositiveInt


class AnalyticdbConfig(BaseModel):
@ -40,3 +40,11 @@ class AnalyticdbConfig(BaseModel):
        description="The password for accessing the specified namespace within the AnalyticDB instance"
        " (if namespace feature is enabled).",
    )
    ANALYTICDB_HOST: Optional[str] = Field(
        default=None, description="The host of the AnalyticDB instance you want to connect to."
    )
    ANALYTICDB_PORT: PositiveInt = Field(
        default=5432, description="The port of the AnalyticDB instance you want to connect to."
    )
    ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1, description="Min connection of the AnalyticDB database.")
    ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5, description="Max connection of the AnalyticDB database.")
@ -45,7 +45,7 @@ class RemoteFileUploadApi(Resource):

        resp = ssrf_proxy.head(url=url)
        if resp.status_code != httpx.codes.OK:
            resp = ssrf_proxy.get(url=url, timeout=3)
            resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
        resp.raise_for_status()

        file_info = helpers.guess_file_info_from_response(resp)
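The fallback GET now passes `follow_redirects=True`, so a remote file behind a redirect resolves instead of surfacing the 3xx. A hedged standalone illustration in plain httpx (the URL is a placeholder):

```python
import httpx

url = "https://example.com/file"  # assume this responds with a 302

resp = httpx.get(url, timeout=3)  # httpx does not follow redirects by default
print(resp.status_code)           # e.g. 302

resp = httpx.get(url, timeout=3, follow_redirects=True)
resp.raise_for_status()           # sees the final response after the redirect
```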
@ -1,3 +1,5 @@
from urllib import parse

from flask_login import current_user
from flask_restful import Resource, abort, marshal_with, reqparse

@ -57,11 +59,12 @@ class MemberInviteEmailApi(Resource):
                token = RegisterService.invite_new_member(
                    inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
                )
                encoded_invitee_email = parse.quote(invitee_email)
                invitation_results.append(
                    {
                        "status": "success",
                        "email": invitee_email,
                        "url": f"{console_web_url}/activate?email={invitee_email}&token={token}",
                        "url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
                    }
                )
            except AccountAlreadyInTenantError:
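The `parse.quote` call matters because characters that are legal in an email address change meaning inside a query string; `+`, for example, decodes to a space. A small sketch:

```python
from urllib import parse

invitee_email = "alice+dev@example.com"  # placeholder address
token = "abc123"                          # placeholder token

# Unencoded: a server decoding the query turns "+" into a space.
print(f"/activate?email={invitee_email}&token={token}")

# Encoded: "+" becomes %2B and round-trips intact.
encoded = parse.quote(invitee_email)
print(f"/activate?email={encoded}&token={token}")
print(parse.unquote_plus(encoded))  # alice+dev@example.com
```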
@ -114,16 +114,9 @@ class BaseAgentRunner(AppRunner):
        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
            self.stream_tool_call = True
        else:
            self.stream_tool_call = False

        # check if model supports vision
        if model_schema and ModelFeature.VISION in (model_schema.features or []):
            self.files = application_generate_entity.files
        else:
            self.files = []
        features = model_schema.features if model_schema and model_schema.features else []
        self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
        self.files = application_generate_entity.files if ModelFeature.VISION in features else []
        self.query = None
        self._current_thoughts: list[PromptMessage] = []

@ -250,7 +243,7 @@ class BaseAgentRunner(AppRunner):
        update prompt message tool
        """
        # try to get tool runtime parameters
        tool_runtime_parameters = tool.get_runtime_parameters() or []
        tool_runtime_parameters = tool.get_runtime_parameters()

        for parameter in tool_runtime_parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
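The refactor folds the two feature checks into one guarded `features` list. A self-contained sketch (stand-in types, not the real entities) showing the old and new forms agree:

```python
from enum import Enum


class ModelFeature(Enum):  # stand-in for the real enum
    STREAM_TOOL_CALL = "stream-tool-call"
    VISION = "vision"


class Schema:  # stand-in for the real model schema
    features = [ModelFeature.VISION]


model_schema = Schema()

# Old form: two separate guarded membership tests.
old = bool(model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []))

# New form: compute the guarded list once, then test membership.
features = model_schema.features if model_schema and model_schema.features else []
new = ModelFeature.STREAM_TOOL_CALL in features

assert old == new  # both False here, since only VISION is present
```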
@ -16,9 +16,7 @@ class FileUploadConfigManager:
        file_upload_dict = config.get("file_upload")
        if file_upload_dict:
            if file_upload_dict.get("enabled"):
                transform_methods = file_upload_dict.get("allowed_file_upload_methods") or file_upload_dict.get(
                    "allowed_upload_methods", []
                )
                transform_methods = file_upload_dict.get("allowed_file_upload_methods", [])
                data = {
                    "image_config": {
                        "number_limits": file_upload_dict["number_limits"],
@ -33,8 +33,8 @@ class BaseAppGenerator:
                    tenant_id=app_config.tenant_id,
                    config=FileUploadConfig(
                        allowed_file_types=entity_dictionary[k].allowed_file_types,
                        allowed_extensions=entity_dictionary[k].allowed_file_extensions,
                        allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
                        allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
                        allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
                    ),
                )
                for k, v in user_inputs.items()
@ -47,8 +47,8 @@ class BaseAppGenerator:
                    tenant_id=app_config.tenant_id,
                    config=FileUploadConfig(
                        allowed_file_types=entity_dictionary[k].allowed_file_types,
                        allowed_extensions=entity_dictionary[k].allowed_file_extensions,
                        allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
                        allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
                        allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
                    ),
                )
                for k, v in user_inputs.items()
@ -381,7 +381,7 @@ class WorkflowCycleManage:
                id=workflow_run.id,
                workflow_id=workflow_run.workflow_id,
                sequence_number=workflow_run.sequence_number,
                inputs=workflow_run.inputs_dict or {},
                inputs=workflow_run.inputs_dict,
                created_at=int(workflow_run.created_at.timestamp()),
            ),
        )
@ -428,7 +428,7 @@ class WorkflowCycleManage:
                created_by=created_by,
                created_at=int(workflow_run.created_at.timestamp()),
                finished_at=int(workflow_run.finished_at.timestamp()),
                files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}),
                files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
            ),
        )
@ -28,8 +28,8 @@ class FileUploadConfig(BaseModel):

    image_config: Optional[ImageConfig] = None
    allowed_file_types: Sequence[FileType] = Field(default_factory=list)
    allowed_extensions: Sequence[str] = Field(default_factory=list)
    allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
    allowed_file_extensions: Sequence[str] = Field(default_factory=list)
    allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
    number_limits: int = 0
@ -39,6 +39,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
        )

    retries = 0
    stream = kwargs.pop("stream", False)
    while retries <= max_retries:
        try:
            if dify_config.SSRF_PROXY_ALL_URL:
@ -52,6 +53,8 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
                response = client.request(method=method, url=url, **kwargs)

            if response.status_code not in STATUS_FORCELIST:
                if stream:
                    return response.iter_bytes()
                return response
            else:
                logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list")
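A hedged usage sketch of the new `stream` keyword (URLs and the downstream handler are placeholders; `make_request` is the function shown above):

```python
# Streaming: make_request pops stream=True and returns response.iter_bytes(),
# a byte-chunk iterator, instead of the response object.
for chunk in make_request("GET", "https://example.com/big-file", stream=True):
    handle_chunk(chunk)  # handle_chunk is a stand-in for real processing

# Non-streaming (default): the full httpx response comes back as before.
resp = make_request("GET", "https://example.com/small.json")
data = resp.json()
```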
@ -29,6 +29,8 @@ from core.rag.splitter.fixed_text_splitter import (
    FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@ -278,6 +280,19 @@ class IndexingRunner:
                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)

                # delete image files and related db records
                image_upload_file_ids = get_image_upload_file_ids(document.page_content)
                for upload_file_id in image_upload_file_ids:
                    image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
                    try:
                        storage.delete(image_file.key)
                    except Exception:
                        logging.exception(
                            "Delete image_files failed while indexing_estimate, \
                            image_upload_file_is: {}".format(upload_file_id)
                        )
                    db.session.delete(image_file)

        if doc_form and doc_form == "qa_model":
            if len(preview_texts) > 0:
                # qa model document
@ -500,11 +515,7 @@ class IndexingRunner:
                document_node.metadata["doc_hash"] = hash
                # delete Splitter character
                page_content = document_node.page_content
                if page_content.startswith(".") or page_content.startswith("。"):
                    page_content = page_content[1:]
                else:
                    page_content = page_content
                document_node.page_content = page_content
                document_node.page_content = remove_leading_symbols(page_content)

                if document_node.page_content:
                    split_documents.append(document_node)
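`remove_leading_symbols` generalizes the old check, which dropped only a single leading `.` or `。`. A rough sketch of the kind of behavior it replaces it with; this regex is an assumption for illustration, not the helper's actual implementation:

```python
import re


def remove_leading_symbols_sketch(text: str) -> str:
    # Assumed behavior: strip a run of leading punctuation/whitespace,
    # generalizing the old single-character "." / "。" check.
    return re.sub(r"^[\s.。,，!！?？;；:：]+", "", text)


print(remove_leading_symbols_sketch("。Hello"))     # "Hello"
print(remove_leading_symbols_sketch("...，Intro"))  # "Intro"
```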
@ -325,14 +325,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
            assistant_prompt_message.tool_calls.append(tool_call)

        # calculate num tokens
        if response.usage:
            # transform usage
            prompt_tokens = response.usage.input_tokens
            completion_tokens = response.usage.output_tokens
        else:
            # calculate num tokens
            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
        prompt_tokens = (response.usage and response.usage.input_tokens) or self.get_num_tokens(
            model, credentials, prompt_messages
        )

        completion_tokens = (response.usage and response.usage.output_tokens) or self.get_num_tokens(
            model, credentials, [assistant_prompt_message]
        )

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
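One subtlety of the `(a and b) or fallback` rewrite, shown as a standalone sketch: a reported count of 0 is falsy in Python, so it falls through to the recomputation path just like a missing usage object would.

```python
class Usage:  # stand-in for the provider's usage object
    input_tokens = 0  # provider reported zero tokens


usage = Usage()
fallback = 42  # stand-in for self.get_num_tokens(...)

prompt_tokens = (usage and usage.input_tokens) or fallback
print(prompt_tokens)  # 42, not 0: the falsy zero triggers the fallback
```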
@ -2,13 +2,11 @@
import base64
import json
import logging
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast

# 3rd import
import boto3
import requests
from botocore.config import Config
from botocore.exceptions import (
    ClientError,
@ -439,22 +437,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                    sub_messages.append(sub_message_dict)
                elif message_content.type == PromptMessageContentType.IMAGE:
                    message_content = cast(ImagePromptMessageContent, message_content)
                    if not message_content.data.startswith("data:"):
                        # fetch image data from url
                        try:
                            url = message_content.data
                            image_content = requests.get(url).content
                            if "?" in url:
                                url = url.split("?")[0]
                            mime_type, _ = mimetypes.guess_type(url)
                            base64_data = base64.b64encode(image_content).decode("utf-8")
                        except Exception as ex:
                            raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
                    else:
                        data_split = message_content.data.split(";base64,")
                        mime_type = data_split[0].replace("data:", "")
                        base64_data = data_split[1]
                        image_content = base64.b64decode(base64_data)
                    data_split = message_content.data.split(";base64,")
                    mime_type = data_split[0].replace("data:", "")
                    base64_data = data_split[1]
                    image_content = base64.b64decode(base64_data)

                    if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                        raise ValueError(
@ -691,8 +691,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
        base_model_schema = cast(AIModelEntity, base_model_schema)

        base_model_schema_features = base_model_schema.features or []
        base_model_schema_model_properties = base_model_schema.model_properties or {}
        base_model_schema_parameters_rules = base_model_schema.parameter_rules or []
        base_model_schema_model_properties = base_model_schema.model_properties
        base_model_schema_parameters_rules = base_model_schema.parameter_rules

        entity = AIModelEntity(
            model=model,
@ -24,14 +24,13 @@ parameter_rules:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens_to_sample
  - name: max_output_tokens
    use_template: max_tokens
    required: true
    default: 8192
    min: 1
    max: 8192
  - name: response_format
    use_template: response_format
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
The identical hunk is repeated for 13 more model configuration files in this compare: in each, the max_tokens_to_sample rule is renamed to max_output_tokens and a json_schema parameter rule is added.
@ -32,3 +32,4 @@ pricing:
  output: '0.00'
  unit: '0.000001'
  currency: USD
deprecated: true

@ -36,3 +36,4 @@ pricing:
  output: '0.00'
  unit: '0.000001'
  currency: USD
deprecated: true
@ -1,7 +1,6 @@
import base64
import io
import json
import logging
from collections.abc import Generator
from typing import Optional, Union, cast

@ -36,17 +35,6 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)

GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
"""  # noqa: E501


class GoogleLargeLanguageModel(LargeLanguageModel):
    def _invoke(
@ -155,7 +143,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):

        try:
            ping_message = SystemPromptMessage(content="ping")
            self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
            self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})

        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))
@ -184,7 +172,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
        :return: full response or stream response chunk generator result
        """
        config_kwargs = model_parameters.copy()
        config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
        if schema := config_kwargs.pop("json_schema", None):
            try:
                schema = json.loads(schema)
            except:
                raise exceptions.InvalidArgument("Invalid JSON Schema")
            if tools:
                raise exceptions.InvalidArgument("gemini not support use Tools and JSON Schema at same time")
            config_kwargs["response_schema"] = schema
            config_kwargs["response_mime_type"] = "application/json"

        if stop:
            config_kwargs["stop_sequences"] = stop
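A hedged sketch of the `model_parameters` a caller would pass down the new JSON-schema path (the schema itself is invented for illustration; per the diff, `json_schema` cannot be combined with tools and forces a JSON response MIME type):

```python
import json

model_parameters = {
    "max_output_tokens": 256,  # renamed from max_tokens_to_sample
    "json_schema": json.dumps(
        {
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"],
        }
    ),
}
# _generate() json.loads() this string, sets response_schema on the request,
# and pins response_mime_type to "application/json".
```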
@ -22,6 +22,7 @@ from core.model_runtime.entities.message_entities import (
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
@ -86,6 +87,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
@ -153,6 +155,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
@ -196,6 +199,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        if completion_type is LLMMode.CHAT:
            endpoint_url = urljoin(endpoint_url, "api/chat")
            data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
            if tools:
                data["tools"] = [self._convert_prompt_message_tool_to_dict(tool) for tool in tools]
        else:
            endpoint_url = urljoin(endpoint_url, "api/generate")
            first_prompt_message = prompt_messages[0]
@ -232,7 +237,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        if stream:
            return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)

        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages, tools)

    def _handle_generate_response(
        self,
@ -241,6 +246,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        completion_type: LLMMode,
        response: requests.Response,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]],
    ) -> LLMResult:
        """
        Handle llm completion response
@ -253,14 +259,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        :return: llm result
        """
        response_json = response.json()

        tool_calls = []
        if completion_type is LLMMode.CHAT:
            message = response_json.get("message", {})
            response_content = message.get("content", "")
            response_tool_calls = message.get("tool_calls", [])
            tool_calls = [self._extract_response_tool_call(tool_call) for tool_call in response_tool_calls]
        else:
            response_content = response_json["response"]

        assistant_message = AssistantPromptMessage(content=response_content)
        assistant_message = AssistantPromptMessage(content=response_content, tool_calls=tool_calls)

        if "prompt_eval_count" in response_json and "eval_count" in response_json:
            # transform usage
@ -405,9 +413,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel):

            chunk_index += 1

    def _convert_prompt_message_tool_to_dict(self, tool: PromptMessageTool) -> dict:
        """
        Convert PromptMessageTool to dict for Ollama API

        :param tool: tool
        :return: tool dict
        """
        return {
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.parameters,
            },
        }

    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
        """
        Convert PromptMessage to dict for Ollama API

        :param message: prompt message
        :return: message dict
        """
        if isinstance(message, UserPromptMessage):
            message = cast(UserPromptMessage, message)
@ -432,6 +459,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, ToolPromptMessage):
            message = cast(ToolPromptMessage, message)
            message_dict = {"role": "tool", "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")

@ -452,6 +482,29 @@ class OllamaLargeLanguageModel(LargeLanguageModel):

        return num_tokens

    def _extract_response_tool_call(self, response_tool_call: dict) -> AssistantPromptMessage.ToolCall:
        """
        Extract response tool call
        """
        tool_call = None
        if response_tool_call and "function" in response_tool_call:
            # Convert arguments to JSON string if it's a dict
            arguments = response_tool_call.get("function").get("arguments")
            if isinstance(arguments, dict):
                arguments = json.dumps(arguments)

            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                name=response_tool_call.get("function").get("name"),
                arguments=arguments,
            )
            tool_call = AssistantPromptMessage.ToolCall(
                id=response_tool_call.get("function").get("name"),
                type="function",
                function=function,
            )

        return tool_call

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        Get customizable model schema.
@ -461,10 +514,15 @@ class OllamaLargeLanguageModel(LargeLanguageModel):

        :return: model schema
        """
        extras = {}
        extras = {
            "features": [],
        }

        if "vision_support" in credentials and credentials["vision_support"] == "true":
            extras["features"] = [ModelFeature.VISION]
            extras["features"].append(ModelFeature.VISION)
        if "function_call_support" in credentials and credentials["function_call_support"] == "true":
            extras["features"].append(ModelFeature.TOOL_CALL)
            extras["features"].append(ModelFeature.MULTI_TOOL_CALL)

        entity = AIModelEntity(
            model=model,

@ -96,3 +96,22 @@ model_credential_schema:
          label:
            en_US: 'No'
            zh_Hans: 否
    - variable: function_call_support
      label:
        zh_Hans: 是否支持函数调用
        en_US: Function call support
      show_on:
        - variable: __model_type
          value: llm
      default: 'false'
      type: radio
      required: false
      options:
        - value: 'true'
          label:
            en_US: 'Yes'
            zh_Hans: 是
        - value: 'false'
          label:
            en_US: 'No'
            zh_Hans: 否
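Putting the new pieces together, a hedged sketch of the /api/chat request body this change produces (field shapes follow the conversion helpers above; the weather tool is invented for illustration):

```python
# Hypothetical payload for Ollama's /api/chat once tools are passed through.
payload = {
    "model": "llama3.1",  # placeholder model name
    "messages": [
        {"role": "user", "content": "What's the weather in Berlin?"},
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
}
# On the way back, message.tool_calls entries are mapped into
# AssistantPromptMessage.ToolCall objects by _extract_response_tool_call.
```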
@ -615,19 +615,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
        prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)

        # o1 compatibility
        block_as_stream = False
        if model.startswith("o1"):
            if "max_tokens" in model_parameters:
                model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
                del model_parameters["max_tokens"]

            if stream:
                block_as_stream = True
                stream = False

                if "stream_options" in extra_model_kwargs:
                    del extra_model_kwargs["stream_options"]

                if "stop" in extra_model_kwargs:
                    del extra_model_kwargs["stop"]

@ -644,47 +636,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
        if stream:
            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)

        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)

        if block_as_stream:
            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)

        return block_result

    def _handle_chat_block_as_stream_response(
        self,
        block_result: LLMResult,
        prompt_messages: list[PromptMessage],
        stop: Optional[list[str]] = None,
    ) -> Generator[LLMResultChunk, None, None]:
        """
        Handle llm chat response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :param stop: stop words
        :return: llm response chunk generator
        """
        text = block_result.message.content
        text = cast(str, text)

        if stop:
            text = self.enforce_stop_tokens(text, stop)

        yield LLMResultChunk(
            model=block_result.model,
            prompt_messages=prompt_messages,
            system_fingerprint=block_result.system_fingerprint,
            delta=LLMResultChunkDelta(
                index=0,
                message=AssistantPromptMessage(content=text),
                finish_reason="stop",
                usage=block_result.usage,
            ),
        )
        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)

    def _handle_chat_generate_response(
        self,
@ -1178,8 +1130,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
        base_model_schema = model_map[base_model]

        base_model_schema_features = base_model_schema.features or []
        base_model_schema_model_properties = base_model_schema.model_properties or {}
        base_model_schema_parameters_rules = base_model_schema.parameter_rules or []
        base_model_schema_model_properties = base_model_schema.model_properties
        base_model_schema_parameters_rules = base_model_schema.parameter_rules

        entity = AIModelEntity(
            model=model,
@ -37,13 +37,14 @@ class OpenLLMGenerateMessage:
class OpenLLMGenerate:
    def generate(
        self,
        *,
        server_url: str,
        model_name: str,
        stream: bool,
        model_parameters: dict[str, Any],
        stop: list[str],
        stop: list[str] | None = None,
        prompt_messages: list[OpenLLMGenerateMessage],
        user: str,
        user: str | None = None,
    ) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]:
        if not server_url:
            raise InvalidAuthenticationError("Invalid server URL")
@ -45,19 +45,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        self._update_credential(model, credentials)

        block_as_stream = False
        if model.startswith("openai/o1"):
            block_as_stream = True
            stop = None

        # invoke block as stream
        if stream and block_as_stream:
            return self._generate_block_as_stream(
                model, credentials, prompt_messages, model_parameters, tools, stop, user
            )
        else:
            return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
        return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def _generate_block_as_stream(
        self,
@ -69,9 +57,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
        stop: Optional[list[str]] = None,
        user: Optional[str] = None,
    ) -> Generator:
        resp: LLMResult = super()._generate(
            model, credentials, prompt_messages, model_parameters, tools, stop, False, user
        )
        resp = super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, False, user)

        yield LLMResultChunk(
            model=model,
@ -65,6 +65,8 @@ class GTERerankModel(RerankModel):
        )

        rerank_documents = []
        if not response.output:
            return RerankResult(model=model, docs=rerank_documents)
        for _, result in enumerate(response.output.results):
            # format document
            rerank_document = RerankDocument(
@ -1,3 +1,6 @@
from collections.abc import Sequence
from typing import Any

from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult


@ -62,5 +65,5 @@ class KeywordsModeration(Moderation):
    def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
        return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())

    def _check_keywords_in_value(self, keywords_list, value) -> bool:
        return any(keyword.lower() in value.lower() for keyword in keywords_list)
    def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:
        return any(keyword.lower() in str(value).lower() for keyword in keywords_list)
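The `str(value)` coercion is the functional part of this change: moderation inputs are not guaranteed to be strings. A standalone sketch:

```python
keywords = ["spam"]
value = 12345  # e.g. a numeric form field

# Old behavior: value.lower() raises AttributeError for non-strings.
# New behavior: coerce first, then compare.
print(any(k.lower() in str(value).lower() for k in keywords))  # False, no crash
```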
@ -49,6 +49,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
    reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run")
    input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
    output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
    dotted_order: Optional[str] = Field(None, description="Dotted order of the run")

    @field_validator("inputs", "outputs")
    @classmethod

@ -25,7 +25,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
    LangSmithRunType,
    LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values
from core.ops.utils import filter_none_values, generate_dotted_order
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
@ -62,6 +62,16 @@ class LangSmithDataTrace(BaseTraceInstance):
            self.generate_name_trace(trace_info)

    def workflow_trace(self, trace_info: WorkflowTraceInfo):
        trace_id = trace_info.message_id or trace_info.workflow_app_log_id or trace_info.workflow_run_id
        message_dotted_order = (
            generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
        )
        workflow_dotted_order = generate_dotted_order(
            trace_info.workflow_app_log_id or trace_info.workflow_run_id,
            trace_info.workflow_data.created_at,
            message_dotted_order,
        )

        if trace_info.message_id:
            message_run = LangSmithRunModel(
                id=trace_info.message_id,
@ -76,6 +86,8 @@ class LangSmithDataTrace(BaseTraceInstance):
                },
                tags=["message", "workflow"],
                error=trace_info.error,
                trace_id=trace_id,
                dotted_order=message_dotted_order,
            )
            self.add_run(message_run)

@ -95,6 +107,8 @@ class LangSmithDataTrace(BaseTraceInstance):
            error=trace_info.error,
            tags=["workflow"],
            parent_run_id=trace_info.message_id or None,
            trace_id=trace_id,
            dotted_order=workflow_dotted_order,
        )

        self.add_run(langsmith_run)
@ -177,6 +191,7 @@ class LangSmithDataTrace(BaseTraceInstance):
            else:
                run_type = LangSmithRunType.tool

            node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order)
            langsmith_run = LangSmithRunModel(
                total_tokens=node_total_tokens,
                name=node_type,
@ -191,6 +206,9 @@ class LangSmithDataTrace(BaseTraceInstance):
                },
                parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
                tags=["node_execution"],
                id=node_execution_id,
                trace_id=trace_id,
                dotted_order=node_dotted_order,
            )

            self.add_run(langsmith_run)
@ -1,5 +1,6 @@
from contextlib import contextmanager
from datetime import datetime
from typing import Optional, Union

from extensions.ext_database import db
from models.model import Message
@ -43,3 +44,19 @@ def replace_text_with_content(data):
        return [replace_text_with_content(item) for item in data]
    else:
        return data


def generate_dotted_order(
    run_id: str, start_time: Union[str, datetime], parent_dotted_order: Optional[str] = None
) -> str:
    """
    generate dotted_order for langsmith
    """
    start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
    timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"
    current_segment = f"{timestamp}{run_id}"

    if parent_dotted_order is None:
        return current_segment

    return f"{parent_dotted_order}.{current_segment}"
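A quick worked example of the helper above (run IDs are made-up UUIDs):

```python
from datetime import datetime

parent = generate_dotted_order(
    "11111111-1111-1111-1111-111111111111",
    datetime(2024, 11, 18, 9, 30, 0, 123456),
)
child = generate_dotted_order(
    "22222222-2222-2222-2222-222222222222",
    datetime(2024, 11, 18, 9, 30, 1),
    parent_dotted_order=parent,
)
print(parent)
# 20241118T093000123Z11111111-1111-1111-1111-111111111111
print(child)
# 20241118T093000123Z11111111-1111-1111-1111-111111111111.20241118T093001000Z22222222-2222-2222-2222-222222222222
```

Each segment is a millisecond-precision timestamp followed by the run ID, and a child appends its segment after the parent's, so lexicographic order matches execution order.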
@ -1,310 +1,62 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
_import_err_msg = (
|
||||
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
|
||||
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
|
||||
)
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
|
||||
AnalyticdbVectorOpenAPI,
|
||||
AnalyticdbVectorOpenAPIConfig,
|
||||
)
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class AnalyticdbConfig(BaseModel):
|
||||
access_key_id: str
|
||||
access_key_secret: str
|
||||
region_id: str
|
||||
instance_id: str
|
||||
account: str
|
||||
account_password: str
|
||||
namespace: str = ("dify",)
|
||||
namespace_password: str = (None,)
|
||||
metrics: str = ("cosine",)
|
||||
read_timeout: int = 60000
|
||||
|
||||
def to_analyticdb_client_params(self):
|
||||
return {
|
||||
"access_key_id": self.access_key_id,
|
||||
"access_key_secret": self.access_key_secret,
|
||||
"region_id": self.region_id,
|
||||
"read_timeout": self.read_timeout,
|
||||
}
|
||||
|
||||
|
||||
class AnalyticdbVector(BaseVector):
|
||||
def __init__(self, collection_name: str, config: AnalyticdbConfig):
|
||||
self._collection_name = collection_name.lower()
|
||||
try:
|
||||
from alibabacloud_gpdb20160503.client import Client
|
||||
from alibabacloud_tea_openapi import models as open_api_models
|
||||
except:
|
||||
raise ImportError(_import_err_msg)
|
||||
self.config = config
|
||||
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
|
||||
self._client = Client(self._client_config)
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
cache_key = f"vector_indexing_{self.config.instance_id}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self.config.instance_id}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
self._initialize_vector_database()
|
||||
self._create_namespace_if_not_exists()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def _initialize_vector_database(self) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.InitVectorDatabaseRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
)
|
||||
self._client.init_vector_database(request)
|
||||
|
||||
def _create_namespace_if_not_exists(self) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
|
||||
try:
|
||||
request = gpdb_20160503_models.DescribeNamespaceRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
)
|
||||
self._client.describe_namespace(request)
|
||||
except TeaException as e:
|
||||
if e.statusCode == 404:
|
||||
request = gpdb_20160503_models.CreateNamespaceRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
)
|
||||
self._client.create_namespace(request)
|
||||
else:
|
||||
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
|
||||
|
||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
try:
|
||||
request = gpdb_20160503_models.DescribeCollectionRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
)
|
||||
self._client.describe_collection(request)
|
||||
except TeaException as e:
|
||||
if e.statusCode == 404:
|
||||
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
|
||||
full_text_retrieval_fields = "page_content"
|
||||
request = gpdb_20160503_models.CreateCollectionRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
namespace=self.config.namespace,
|
||||
collection=self._collection_name,
|
||||
dimension=embedding_dimension,
|
||||
metrics=self.config.metrics,
|
||||
metadata=metadata,
|
||||
full_text_retrieval_fields=full_text_retrieval_fields,
|
||||
)
|
||||
self._client.create_collection(request)
|
||||
else:
|
||||
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
def __init__(
|
||||
self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
|
||||
):
|
||||
super().__init__(collection_name)
|
||||
if api_config is not None:
|
||||
self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
|
||||
else:
|
||||
self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.ANALYTICDB
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection_if_not_exists(dimension)
|
||||
self.add_texts(texts, embeddings)
|
||||
self.analyticdb_vector._create_collection_if_not_exists(dimension)
|
||||
self.analyticdb_vector.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
|
||||
for doc, embedding in zip(documents, embeddings, strict=True):
|
||||
metadata = {
|
||||
"ref_doc_id": doc.metadata["doc_id"],
|
||||
"page_content": doc.page_content,
|
||||
"metadata_": json.dumps(doc.metadata),
|
||||
}
|
||||
rows.append(
|
||||
gpdb_20160503_models.UpsertCollectionDataRequestRows(
|
||||
vector=embedding,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
request = gpdb_20160503_models.UpsertCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
rows=rows,
|
||||
)
|
||||
self._client.upsert_collection_data(request)
|
||||
def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
self.analyticdb_vector.add_texts(texts, embeddings)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
metrics=self.config.metrics,
|
||||
include_values=True,
|
||||
vector=None,
|
||||
content=None,
|
||||
top_k=1,
|
||||
filter=f"ref_doc_id='{id}'",
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
return len(response.body.matches.match) > 0
|
||||
return self.analyticdb_vector.text_exists(id)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
ids_str = ",".join(f"'{id}'" for id in ids)
|
||||
ids_str = f"({ids_str})"
|
||||
request = gpdb_20160503_models.DeleteCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
collection_data=None,
|
||||
collection_data_filter=f"ref_doc_id IN {ids_str}",
|
||||
)
|
||||
self._client.delete_collection_data(request)
|
||||
self.analyticdb_vector.delete_by_ids(ids)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.DeleteCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
collection_data=None,
|
||||
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
|
||||
)
|
||||
self._client.delete_collection_data(request)
|
||||
self.analyticdb_vector.delete_by_metadata_field(key, value)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
score_threshold = kwargs.get("score_threshold") or 0.0
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
include_values=kwargs.pop("include_values", True),
|
||||
metrics=self.config.metrics,
|
||||
vector=query_vector,
|
||||
content=None,
|
||||
top_k=kwargs.get("top_k", 4),
|
||||
filter=None,
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
documents = []
|
||||
for match in response.body.matches.match:
|
||||
if match.score > score_threshold:
|
||||
metadata = json.loads(match.metadata.get("metadata_"))
|
||||
metadata["score"] = match.score
|
||||
doc = Document(
|
||||
page_content=match.metadata.get("page_content"),
|
||||
metadata=metadata,
|
||||
)
|
||||
documents.append(doc)
|
||||
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
|
||||
return documents
|
||||
return self.analyticdb_vector.search_by_vector(query_vector)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
include_values=kwargs.pop("include_values", True),
|
||||
metrics=self.config.metrics,
|
||||
vector=None,
|
||||
content=query,
|
||||
top_k=kwargs.get("top_k", 4),
|
||||
filter=None,
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
documents = []
|
||||
for match in response.body.matches.match:
|
||||
if match.score > score_threshold:
|
||||
metadata = json.loads(match.metadata.get("metadata_"))
|
||||
metadata["score"] = match.score
|
||||
doc = Document(
|
||||
page_content=match.metadata.get("page_content"),
|
||||
vector=match.metadata.get("vector"),
|
||||
metadata=metadata,
|
||||
)
|
||||
documents.append(doc)
|
||||
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
|
||||
return documents
|
||||
return self.analyticdb_vector.search_by_full_text(query, **kwargs)
|
||||
|
||||
def delete(self) -> None:
|
||||
try:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.DeleteCollectionRequest(
|
||||
collection=self._collection_name,
|
||||
dbinstance_id=self.config.instance_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
region_id=self.config.region_id,
|
||||
)
|
||||
self._client.delete_collection(request)
|
||||
except Exception as e:
|
||||
raise e
|
||||
self.analyticdb_vector.delete()
|
||||
|
||||
|
||||
 class AnalyticdbVectorFactory(AbstractVectorFactory):
-    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector:
         if dataset.index_struct_dict:
             class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
             collection_name = class_prefix.lower()
@@ -313,26 +65,9 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
             collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
             dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))

-        # handle optional params
-        if dify_config.ANALYTICDB_KEY_ID is None:
-            raise ValueError("ANALYTICDB_KEY_ID should not be None")
-        if dify_config.ANALYTICDB_KEY_SECRET is None:
-            raise ValueError("ANALYTICDB_KEY_SECRET should not be None")
-        if dify_config.ANALYTICDB_REGION_ID is None:
-            raise ValueError("ANALYTICDB_REGION_ID should not be None")
-        if dify_config.ANALYTICDB_INSTANCE_ID is None:
-            raise ValueError("ANALYTICDB_INSTANCE_ID should not be None")
-        if dify_config.ANALYTICDB_ACCOUNT is None:
-            raise ValueError("ANALYTICDB_ACCOUNT should not be None")
-        if dify_config.ANALYTICDB_PASSWORD is None:
-            raise ValueError("ANALYTICDB_PASSWORD should not be None")
-        if dify_config.ANALYTICDB_NAMESPACE is None:
-            raise ValueError("ANALYTICDB_NAMESPACE should not be None")
-        if dify_config.ANALYTICDB_NAMESPACE_PASSWORD is None:
-            raise ValueError("ANALYTICDB_NAMESPACE_PASSWORD should not be None")
-        return AnalyticdbVector(
-            collection_name,
-            AnalyticdbConfig(
+        if dify_config.ANALYTICDB_HOST is None:
+            # implemented through OpenAPI
+            apiConfig = AnalyticdbVectorOpenAPIConfig(
                 access_key_id=dify_config.ANALYTICDB_KEY_ID,
                 access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
                 region_id=dify_config.ANALYTICDB_REGION_ID,
@@ -341,5 +76,22 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
                 account_password=dify_config.ANALYTICDB_PASSWORD,
                 namespace=dify_config.ANALYTICDB_NAMESPACE,
                 namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
-            ),
-        )
+            )
+            sqlConfig = None
+        else:
+            # implemented through sql
+            sqlConfig = AnalyticdbVectorBySqlConfig(
+                host=dify_config.ANALYTICDB_HOST,
+                port=dify_config.ANALYTICDB_PORT,
+                account=dify_config.ANALYTICDB_ACCOUNT,
+                account_password=dify_config.ANALYTICDB_PASSWORD,
+                min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
+                max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
+                namespace=dify_config.ANALYTICDB_NAMESPACE,
+            )
+            apiConfig = None
+        return AnalyticdbVector(
+            collection_name,
+            apiConfig,
+            sqlConfig,
+        )

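The factory now switches backends on a single setting. A minimal sketch of that selection rule, assuming the `dify_config` fields shown above (the host value below is made up for illustration):

```python
def pick_analyticdb_backend(analyticdb_host: str | None) -> str:
    # Mirrors the branch in AnalyticdbVectorFactory.init_vector:
    # no host configured -> Alibaba Cloud OpenAPI client; host set -> direct SQL.
    return "openapi" if analyticdb_host is None else "sql"

assert pick_analyticdb_backend(None) == "openapi"
assert pick_analyticdb_backend("gp-example.aliyuncs.com") == "sql"  # hypothetical host
```
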
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py (new file, 309 lines)
@@ -0,0 +1,309 @@
import json
from typing import Any, Optional

from pydantic import BaseModel, model_validator

_import_err_msg = (
    "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
    "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)

from core.rag.models.document import Document
from extensions.ext_redis import redis_client


class AnalyticdbVectorOpenAPIConfig(BaseModel):
    access_key_id: str
    access_key_secret: str
    region_id: str
    instance_id: str
    account: str
    account_password: str
    namespace: str = "dify"
    namespace_password: Optional[str] = None
    metrics: str = "cosine"
    read_timeout: int = 60000

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        if not values["access_key_id"]:
            raise ValueError("config ANALYTICDB_KEY_ID is required")
        if not values["access_key_secret"]:
            raise ValueError("config ANALYTICDB_KEY_SECRET is required")
        if not values["region_id"]:
            raise ValueError("config ANALYTICDB_REGION_ID is required")
        if not values["instance_id"]:
            raise ValueError("config ANALYTICDB_INSTANCE_ID is required")
        if not values["account"]:
            raise ValueError("config ANALYTICDB_ACCOUNT is required")
        if not values["account_password"]:
            raise ValueError("config ANALYTICDB_PASSWORD is required")
        if not values["namespace_password"]:
            raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
        return values

    def to_analyticdb_client_params(self):
        return {
            "access_key_id": self.access_key_id,
            "access_key_secret": self.access_key_secret,
            "region_id": self.region_id,
            "read_timeout": self.read_timeout,
        }


class AnalyticdbVectorOpenAPI:
    def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
        try:
            from alibabacloud_gpdb20160503.client import Client
            from alibabacloud_tea_openapi import models as open_api_models
        except ImportError:
            raise ImportError(_import_err_msg)
        self._collection_name = collection_name.lower()
        self.config = config
        self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
        self._client = Client(self._client_config)
        self._initialize()

    def _initialize(self) -> None:
        cache_key = f"vector_initialize_{self.config.instance_id}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            database_exist_cache_key = f"vector_initialize_{self.config.instance_id}"
            if redis_client.get(database_exist_cache_key):
                return
            self._initialize_vector_database()
            self._create_namespace_if_not_exists()
            redis_client.set(database_exist_cache_key, 1, ex=3600)

    def _initialize_vector_database(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.InitVectorDatabaseRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            manager_account=self.config.account,
            manager_account_password=self.config.account_password,
        )
        self._client.init_vector_database(request)

    def _create_namespace_if_not_exists(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException

        try:
            request = gpdb_20160503_models.DescribeNamespaceRequest(
                dbinstance_id=self.config.instance_id,
                region_id=self.config.region_id,
                namespace=self.config.namespace,
                manager_account=self.config.account,
                manager_account_password=self.config.account_password,
            )
            self._client.describe_namespace(request)
        except TeaException as e:
            if e.statusCode == 404:
                request = gpdb_20160503_models.CreateNamespaceRequest(
                    dbinstance_id=self.config.instance_id,
                    region_id=self.config.region_id,
                    manager_account=self.config.account,
                    manager_account_password=self.config.account_password,
                    namespace=self.config.namespace,
                    namespace_password=self.config.namespace_password,
                )
                self._client.create_namespace(request)
            else:
                raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")

    def _create_collection_if_not_exists(self, embedding_dimension: int):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException

        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return
            try:
                request = gpdb_20160503_models.DescribeCollectionRequest(
                    dbinstance_id=self.config.instance_id,
                    region_id=self.config.region_id,
                    namespace=self.config.namespace,
                    namespace_password=self.config.namespace_password,
                    collection=self._collection_name,
                )
                self._client.describe_collection(request)
            except TeaException as e:
                if e.statusCode == 404:
                    metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
                    full_text_retrieval_fields = "page_content"
                    request = gpdb_20160503_models.CreateCollectionRequest(
                        dbinstance_id=self.config.instance_id,
                        region_id=self.config.region_id,
                        manager_account=self.config.account,
                        manager_account_password=self.config.account_password,
                        namespace=self.config.namespace,
                        collection=self._collection_name,
                        dimension=embedding_dimension,
                        metrics=self.config.metrics,
                        metadata=metadata,
                        full_text_retrieval_fields=full_text_retrieval_fields,
                    )
                    self._client.create_collection(request)
                else:
                    raise ValueError(f"failed to create collection {self._collection_name}: {e}")
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
        for doc, embedding in zip(documents, embeddings, strict=True):
            metadata = {
                "ref_doc_id": doc.metadata["doc_id"],
                "page_content": doc.page_content,
                "metadata_": json.dumps(doc.metadata),
            }
            rows.append(
                gpdb_20160503_models.UpsertCollectionDataRequestRows(
                    vector=embedding,
                    metadata=metadata,
                )
            )
        request = gpdb_20160503_models.UpsertCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            rows=rows,
        )
        self._client.upsert_collection_data(request)

    def text_exists(self, id: str) -> bool:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            metrics=self.config.metrics,
            include_values=True,
            vector=None,
            content=None,
            top_k=1,
            filter=f"ref_doc_id='{id}'",
        )
        response = self._client.query_collection_data(request)
        return len(response.body.matches.match) > 0

    def delete_by_ids(self, ids: list[str]) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        ids_str = ",".join(f"'{id}'" for id in ids)
        ids_str = f"({ids_str})"
        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"ref_doc_id IN {ids_str}",
        )
        self._client.delete_collection_data(request)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
        )
        self._client.delete_collection_data(request)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        score_threshold = kwargs.get("score_threshold") or 0.0
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=query_vector,
            content=None,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    vector=match.values.value,
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=None,
            content=query,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    vector=match.values.value,
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        return documents

    def delete(self) -> None:
        try:
            from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

            request = gpdb_20160503_models.DeleteCollectionRequest(
                collection=self._collection_name,
                dbinstance_id=self.config.instance_id,
                namespace=self.config.namespace,
                namespace_password=self.config.namespace_password,
                region_id=self.config.region_id,
            )
            self._client.delete_collection(request)
        except Exception as e:
            raise e

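Both backend classes guard their one-time setup with the same Redis pattern: take a lock, re-check a cached flag, then mark the work done. A minimal sketch of that pattern, assuming a redis-py-compatible client:

```python
def initialize_once(redis_client, cache_key: str, do_init) -> None:
    # Double-checked initialization: the lock serializes concurrent workers,
    # and the cached flag lets later callers skip the setup for an hour.
    with redis_client.lock(f"{cache_key}_lock", timeout=20):
        if redis_client.get(cache_key):
            return
        do_init()
        redis_client.set(cache_key, 1, ex=3600)
```
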
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py (new file, 245 lines)
@@ -0,0 +1,245 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any

import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator

from core.rag.models.document import Document
from extensions.ext_redis import redis_client


class AnalyticdbVectorBySqlConfig(BaseModel):
    host: str
    port: int
    account: str
    account_password: str
    min_connection: int
    max_connection: int
    namespace: str = "dify"
    metrics: str = "cosine"

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        if not values["host"]:
            raise ValueError("config ANALYTICDB_HOST is required")
        if not values["port"]:
            raise ValueError("config ANALYTICDB_PORT is required")
        if not values["account"]:
            raise ValueError("config ANALYTICDB_ACCOUNT is required")
        if not values["account_password"]:
            raise ValueError("config ANALYTICDB_PASSWORD is required")
        if not values["min_connection"]:
            raise ValueError("config ANALYTICDB_MIN_CONNECTION is required")
        if not values["max_connection"]:
            raise ValueError("config ANALYTICDB_MAX_CONNECTION is required")
        if values["min_connection"] > values["max_connection"]:
            raise ValueError("config ANALYTICDB_MIN_CONNECTION should be less than ANALYTICDB_MAX_CONNECTION")
        return values


class AnalyticdbVectorBySql:
    def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig):
        self._collection_name = collection_name.lower()
        self.databaseName = "knowledgebase"
        self.config = config
        self.table_name = f"{self.config.namespace}.{self._collection_name}"
        self.pool = None
        self._initialize()
        if not self.pool:
            self.pool = self._create_connection_pool()

    def _initialize(self) -> None:
        cache_key = f"vector_initialize_{self.config.host}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            database_exist_cache_key = f"vector_initialize_{self.config.host}"
            if redis_client.get(database_exist_cache_key):
                return
            self._initialize_vector_database()
            redis_client.set(database_exist_cache_key, 1, ex=3600)

    def _create_connection_pool(self):
        return psycopg2.pool.SimpleConnectionPool(
            self.config.min_connection,
            self.config.max_connection,
            host=self.config.host,
            port=self.config.port,
            user=self.config.account,
            password=self.config.account_password,
            database=self.databaseName,
        )

    @contextmanager
    def _get_cursor(self):
        conn = self.pool.getconn()
        cur = conn.cursor()
        try:
            yield cur
        finally:
            cur.close()
            conn.commit()
            self.pool.putconn(conn)

    def _initialize_vector_database(self) -> None:
        conn = psycopg2.connect(
            host=self.config.host,
            port=self.config.port,
            user=self.config.account,
            password=self.config.account_password,
            database="postgres",
        )
        conn.autocommit = True
        cur = conn.cursor()
        try:
            cur.execute(f"CREATE DATABASE {self.databaseName}")
        except Exception as e:
            if "already exists" in str(e):
                return
            raise e
        finally:
            cur.close()
            conn.close()
        self.pool = self._create_connection_pool()
        with self._get_cursor() as cur:
            try:
                cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
                cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
            except Exception as e:
                if "already exists" not in str(e):
                    raise e
            cur.execute(
                "CREATE OR REPLACE FUNCTION "
                "public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) "
                "RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ "
                "SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) "
                "FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) "
                "AS words_only;$function$"
            )
            cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")

    def _create_collection_if_not_exists(self, embedding_dimension: int):
        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return
            with self._get_cursor() as cur:
                cur.execute(
                    f"CREATE TABLE IF NOT EXISTS {self.table_name}("
                    f"id text PRIMARY KEY,"
                    f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
                    f"to_tsvector TSVECTOR"
                    f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
                )
                if embedding_dimension is not None:
                    index_name = f"{self._collection_name}_embedding_idx"
                    cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
                    cur.execute(
                        f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
                        f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
                        f"pq_enable=0, external_storage=0)"
                    )
                    cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        values = []
        id_prefix = str(uuid.uuid4()) + "_"
        sql = f"""
            INSERT INTO {self.table_name}
            (id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
            VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
        """
        for i, doc in enumerate(documents):
            values.append(
                (
                    id_prefix + str(i),
                    doc.metadata.get("doc_id", str(uuid.uuid4())),
                    embeddings[i],
                    doc.page_content,
                    json.dumps(doc.metadata),
                    doc.page_content,
                )
            )
        with self._get_cursor() as cur:
            psycopg2.extras.execute_batch(cur, sql, values)

    def text_exists(self, id: str) -> bool:
        with self._get_cursor() as cur:
            cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,))
            return cur.fetchone() is not None

    def delete_by_ids(self, ids: list[str]) -> None:
        with self._get_cursor() as cur:
            try:
                cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),))
            except Exception as e:
                if "does not exist" not in str(e):
                    raise e

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        with self._get_cursor() as cur:
            try:
                cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value))
            except Exception as e:
                if "does not exist" not in str(e):
                    raise e

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        with self._get_cursor() as cur:
            query_vector_str = json.dumps(query_vector)
            query_vector_str = "{" + query_vector_str[1:-1] + "}"
            cur.execute(
                f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
                f"t.page_content as page_content, t.metadata_ AS metadata_ "
                f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
                f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
                (query_vector_str,),
            )
            documents = []
            for record in cur:
                id, vector, score, page_content, metadata = record
                if score > score_threshold:
                    metadata["score"] = score
                    doc = Document(
                        page_content=page_content,
                        vector=vector,
                        metadata=metadata,
                    )
                    documents.append(doc)
        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
        with self._get_cursor() as cur:
            cur.execute(
                f"""SELECT id, vector, page_content, metadata_,
                ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
                FROM {self.table_name}
                WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
                ORDER BY score DESC
                LIMIT {top_k}""",
                (f"'{query}'", f"'{query}'"),
            )
            documents = []
            for record in cur:
                id, vector, page_content, metadata, score = record
                metadata["score"] = score
                doc = Document(
                    page_content=page_content,
                    vector=vector,
                    metadata=metadata,
                )
                documents.append(doc)
        return documents

    def delete(self) -> None:
        with self._get_cursor() as cur:
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

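`search_by_vector` above binds the query embedding as a Postgres `real[]` literal by rewriting the JSON list syntax. A small illustration of that conversion:

```python
import json

def to_pg_real_array(vec: list[float]) -> str:
    # json.dumps gives "[0.1, 0.2]"; Postgres array literals use braces instead.
    s = json.dumps(vec)
    return "{" + s[1:-1] + "}"

assert to_pg_real_array([0.1, 0.2]) == "{0.1, 0.2}"
```
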
@@ -11,6 +11,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document
+from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
 from models.dataset import Dataset

@@ -43,11 +44,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
                     document_node.metadata["doc_id"] = doc_id
                     document_node.metadata["doc_hash"] = hash
                     # delete Splitter character
-                    page_content = document_node.page_content
-                    if page_content.startswith(".") or page_content.startswith("。"):
-                        page_content = page_content[1:].strip()
-                    else:
-                        page_content = page_content
+                    page_content = remove_leading_symbols(document_node.page_content).strip()
                     if len(page_content) > 0:
                         document_node.page_content = page_content
                         split_documents.append(document_node)

@@ -18,6 +18,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document
+from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
 from models.dataset import Dataset

@@ -53,11 +54,7 @@ class QAIndexProcessor(BaseIndexProcessor):
                     document_node.metadata["doc_hash"] = hash
                     # delete Splitter character
                     page_content = document_node.page_content
-                    if page_content.startswith(".") or page_content.startswith("。"):
-                        page_content = page_content[1:]
-                    else:
-                        page_content = page_content
-                    document_node.page_content = page_content
+                    document_node.page_content = remove_leading_symbols(page_content)
                     split_documents.append(document_node)
                 all_documents.extend(split_documents)
             for i in range(0, len(all_documents), 10):

@@ -36,23 +36,21 @@ class WeightRerankRunner(BaseRerankRunner):

         :return:
         """
-        docs = []
-        doc_id = []
+        unique_documents = []
+        doc_ids = set()
         for document in documents:
-            if document.metadata["doc_id"] not in doc_id:
-                doc_id.append(document.metadata["doc_id"])
-                docs.append(document.page_content)
+            doc_id = document.metadata.get("doc_id")
+            if doc_id not in doc_ids:
+                doc_ids.add(doc_id)
+                unique_documents.append(document)

+        documents = unique_documents

-        rerank_documents = []
         query_scores = self._calculate_keyword_score(query, documents)

         query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)

+        rerank_documents = []
         for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
-            # format document
             score = (
                 self.weights.vector_setting.vector_weight * query_vector_score
                 + self.weights.keyword_setting.keyword_weight * query_score
@@ -61,7 +59,8 @@ class WeightRerankRunner(BaseRerankRunner):
                 continue
             document.metadata["score"] = score
             rerank_documents.append(document)
-        rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True)

+        rerank_documents.sort(key=lambda x: x.metadata["score"], reverse=True)
         return rerank_documents[:top_n] if top_n else rerank_documents

     def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:

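For reference, the rerank score combined above is a plain weighted sum of the two retrieval signals; the weights below are illustrative placeholders, not Dify defaults:

```python
def weighted_score(vector_score: float, keyword_score: float,
                   vector_weight: float = 0.7, keyword_weight: float = 0.3) -> float:
    # Same shape as the combination in WeightRerankRunner: one weight per signal.
    return vector_weight * vector_score + keyword_weight * keyword_score

assert abs(weighted_score(0.9, 0.5) - 0.78) < 1e-9
```
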
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, ClassVar

 from duckduckgo_search import DDGS

@@ -11,6 +11,17 @@ class DuckDuckGoVideoSearchTool(BuiltinTool):
     Tool for performing a video search using DuckDuckGo search engine.
     """

+    IFRAME_TEMPLATE: ClassVar[str] = """
+    <div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden; \
+    max-width: 100%; border-radius: 8px;">
+        <iframe
+            style="position: absolute; top: 0; left: 0; width: 100%; height: 100%;"
+            src="{src}"
+            frameborder="0"
+            allowfullscreen>
+        </iframe>
+    </div>"""
+
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
         query_dict = {
             "keywords": tool_parameters.get("query"),
@@ -26,6 +37,9 @@ class DuckDuckGoVideoSearchTool(BuiltinTool):
         # Remove None values to use API defaults
         query_dict = {k: v for k, v in query_dict.items() if v is not None}

+        # Get proxy URL from parameters
+        proxy_url = tool_parameters.get("proxy_url", "").strip()
+
         response = DDGS().videos(**query_dict)

         # Create HTML result with embedded iframes
@@ -36,20 +50,21 @@ class DuckDuckGoVideoSearchTool(BuiltinTool):
             title = res.get("title", "")
             embed_html = res.get("embed_html", "")
             description = res.get("description", "")
+            content_url = res.get("content", "")

-            # Modify iframe to be responsive
-            if embed_html:
-                # Replace fixed dimensions with responsive wrapper and iframe
-                embed_html = """
-                <div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden; \
-                max-width: 100%; border-radius: 8px;">
-                    <iframe
-                        style="position: absolute; top: 0; left: 0; width: 100%; height: 100%;"
-                        src="{src}"
-                        frameborder="0"
-                        allowfullscreen>
-                    </iframe>
-                </div>""".format(src=res.get("embed_url", ""))
+            # Handle TED.com videos
+            if not embed_html and "ted.com/talks" in content_url:
+                embed_url = content_url.replace("www.ted.com", "embed.ted.com")
+                if proxy_url:
+                    embed_url = f"{proxy_url}{embed_url}"
+                embed_html = self.IFRAME_TEMPLATE.format(src=embed_url)

+            # Original YouTube/other platform handling
+            elif embed_html:
+                embed_url = res.get("embed_url", "")
+                if proxy_url and embed_url:
+                    embed_url = f"{proxy_url}{embed_url}"
+                embed_html = self.IFRAME_TEMPLATE.format(src=embed_url)

             markdown_result += f"{title}\n\n"
             markdown_result += f"{embed_html}\n\n"

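The TED branch above builds its embed URL by a host swap plus an optional proxy prefix. A quick sketch of that transformation, with made-up URLs:

```python
content_url = "https://www.ted.com/talks/example_talk"  # illustrative
proxy_url = "https://video-proxy.example.com/"          # illustrative

embed_url = content_url.replace("www.ted.com", "embed.ted.com")
if proxy_url:
    embed_url = f"{proxy_url}{embed_url}"

assert embed_url == "https://video-proxy.example.com/https://embed.ted.com/talks/example_talk"
```
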
@@ -1,40 +1,43 @@
 identity:
   name: ddgo_video
-  author: Assistant
+  author: Tao Wang
   label:
     en_US: DuckDuckGo Video Search
     zh_Hans: DuckDuckGo 视频搜索
 description:
   human:
-    en_US: Perform video searches on DuckDuckGo and get results with embedded videos.
-    zh_Hans: 在 DuckDuckGo 上进行视频搜索并获取可嵌入的视频结果。
-  llm: Perform video searches on DuckDuckGo and get results with embedded videos.
+    en_US: Search and embed videos.
+    zh_Hans: 搜索并嵌入视频
+  llm: Search videos on duckduckgo and embed videos in iframe
 parameters:
   - name: query
+    type: string
+    required: true
     label:
       en_US: Query String
       zh_Hans: 查询语句
-    type: string
-    required: true
     human_description:
       en_US: Search Query
-      zh_Hans: 搜索查询语句。
+      zh_Hans: 搜索查询语句
     llm_description: Key words for searching
     form: llm
   - name: max_results
-    label:
-      en_US: Max Results
-      zh_Hans: 最大结果数量
     type: number
     required: true
    default: 3
     minimum: 1
     maximum: 10
+    label:
+      en_US: Max Results
+      zh_Hans: 最大结果数量
     human_description:
-      en_US: The max results (1-10).
-      zh_Hans: 最大结果数量(1-10)。
+      en_US: The max results (1-10)
+      zh_Hans: 最大结果数量(1-10)
     form: form
   - name: timelimit
-    label:
-      en_US: Result Time Limit
-      zh_Hans: 结果时间限制
     type: select
     required: false
     options:
@@ -54,14 +57,14 @@ parameters:
       label:
         en_US: Current Year
         zh_Hans: 今年
+    label:
+      en_US: Result Time Limit
+      zh_Hans: 结果时间限制
     human_description:
-      en_US: Use when querying results within a specific time range only.
+      en_US: Query results within a specific time range only
       zh_Hans: 只查询一定时间范围内的结果时使用
     form: form
   - name: duration
-    label:
-      en_US: Video Duration
-      zh_Hans: 视频时长
     type: select
     required: false
     options:
@@ -77,10 +80,18 @@ parameters:
       label:
         en_US: Long (>20 minutes)
         zh_Hans: 长视频(>20分钟)
+    label:
+      en_US: Video Duration
+      zh_Hans: 视频时长
     human_description:
       en_US: Filter videos by duration
       zh_Hans: 按时长筛选视频
     form: form
+  - name: proxy_url
+    label:
+      en_US: Proxy URL
+      zh_Hans: 视频代理地址
+    type: string
+    required: false
+    default: ""
+    human_description:
+      en_US: Proxy URL
+      zh_Hans: 视频代理地址
+    form: form

@@ -17,7 +17,7 @@ class SendMailTool(BuiltinTool):
         invoke tools
         """
         sender = self.runtime.credentials.get("email_account", "")
-        email_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
+        email_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
         password = self.runtime.credentials.get("email_password", "")
         smtp_server = self.runtime.credentials.get("smtp_server", "")
         if not smtp_server:

@@ -18,7 +18,7 @@ class SendMailTool(BuiltinTool):
         invoke tools
         """
         sender = self.runtime.credentials.get("email_account", "")
-        email_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
+        email_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
         password = self.runtime.credentials.get("email_password", "")
         smtp_server = self.runtime.credentials.get("smtp_server", "")
         if not smtp_server:

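Both send-mail fixes widen only the local part of the address pattern to allow dots. A quick check of the behavior change (addresses are made up):

```python
import re

old_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
new_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")

assert new_rgx.match("first.last@example.com")      # dotted local part now accepted
assert not old_rgx.match("first.last@example.com")  # the old pattern rejected it
```
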
api/core/tools/provider/builtin/gitee_ai/tools/embedding.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class GiteeAIToolEmbedding(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {self.runtime.credentials['api_key']}",
        }

        payload = {"inputs": tool_parameters.get("inputs")}
        model = tool_parameters.get("model", "bge-m3")
        url = f"https://ai.gitee.com/api/serverless/{model}/embeddings"
        response = requests.post(url, json=payload, headers=headers)
        if response.status_code != 200:
            return self.create_text_message(f"Got Error Response:{response.text}")

        return [self.create_text_message(response.content.decode("utf-8"))]

@@ -0,0 +1,37 @@
identity:
  name: embedding
  author: gitee_ai
  label:
    en_US: embedding
  icon: icon.svg
description:
  human:
    en_US: Generate word embeddings using Serverless-supported models (compatible with OpenAI)
  llm: This tool is used to generate word embeddings from text input.
parameters:
  - name: model
    type: string
    required: true
    in: path
    description:
      en_US: Supported Embedding (compatible with OpenAI) interface models
    enum:
      - bge-m3
      - bge-large-zh-v1.5
      - bge-small-zh-v1.5
    label:
      en_US: Service Model
      zh_Hans: 服务模型
    default: bge-m3
    form: form
  - name: inputs
    type: string
    required: true
    label:
      en_US: Input Text
      zh_Hans: 输入文本
    human_description:
      en_US: The text input used to generate embeddings.
      zh_Hans: 用于生成词向量的输入文本。
    llm_description: This text input will be used to generate embeddings.
    form: llm

@@ -6,7 +6,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool


-class GiteeAITool(BuiltinTool):
+class GiteeAIToolText2Image(BuiltinTool):
     def _invoke(
         self, user_id: str, tool_parameters: dict[str, Any]
     ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:

@@ -40,6 +40,9 @@ class JSONParseTool(BuiltinTool):
             expr = parse(json_filter)
             result = [match.value for match in expr.find(input_data)]

+            if not result:
+                return ""
+
             if len(result) == 1:
                 result = result[0]

api/core/tools/provider/builtin/rapidapi/_assets/rapidapi.png (new binary file, 62 KiB)

api/core/tools/provider/builtin/rapidapi/rapidapi.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from typing import Any

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.rapidapi.tools.google_news import GooglenewsTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class RapidapiProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        try:
            GooglenewsTool().fork_tool_runtime(
                meta={
                    "credentials": credentials,
                }
            ).invoke(
                user_id="",
                tool_parameters={
                    "language_region": "en-US",
                },
            )
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))

api/core/tools/provider/builtin/rapidapi/rapidapi.yaml (new file, 39 lines)
@@ -0,0 +1,39 @@
identity:
  name: rapidapi
  author: Steven Sun
  label:
    en_US: RapidAPI
    zh_Hans: RapidAPI
  description:
    en_US: RapidAPI is the world's largest API marketplace with over 1,000,000 developers and 10,000 APIs.
    zh_Hans: RapidAPI是全球最大的API市场,拥有超过100万开发人员和10000个API。
  icon: rapidapi.png
  tags:
    - news
credentials_for_provider:
  x-rapidapi-host:
    type: text-input
    required: true
    label:
      en_US: x-rapidapi-host
      zh_Hans: x-rapidapi-host
    placeholder:
      en_US: Please input your x-rapidapi-host
      zh_Hans: 请输入你的 x-rapidapi-host
    help:
      en_US: Get your x-rapidapi-host from RapidAPI.
      zh_Hans: 从 RapidAPI 获取您的 x-rapidapi-host。
    url: https://rapidapi.com/
  x-rapidapi-key:
    type: secret-input
    required: true
    label:
      en_US: x-rapidapi-key
      zh_Hans: x-rapidapi-key
    placeholder:
      en_US: Please input your x-rapidapi-key
      zh_Hans: 请输入你的 x-rapidapi-key
    help:
      en_US: Get your x-rapidapi-key from RapidAPI.
      zh_Hans: 从 RapidAPI 获取您的 x-rapidapi-key。
    url: https://rapidapi.com/

api/core/tools/provider/builtin/rapidapi/tools/google_news.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolInvokeError, ToolProviderCredentialValidationError
from core.tools.tool.builtin_tool import BuiltinTool


class GooglenewsTool(BuiltinTool):
    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools
        """
        key = self.runtime.credentials.get("x-rapidapi-key", "")
        host = self.runtime.credentials.get("x-rapidapi-host", "")
        if not all([key, host]):
            raise ToolProviderCredentialValidationError("Please input correct x-rapidapi-key and x-rapidapi-host")
        headers = {"x-rapidapi-key": key, "x-rapidapi-host": host}
        lr = tool_parameters.get("language_region", "")
        url = f"https://{host}/latest?lr={lr}"
        response = requests.get(url, headers=headers)
        if response.status_code != 200:
            raise ToolInvokeError(f"Error {response.status_code}: {response.text}")
        return self.create_text_message(response.text)

    def validate_credentials(self, parameters: dict[str, Any]) -> None:
        parameters["validate"] = True
        self._invoke(user_id="", tool_parameters=parameters)

@@ -0,0 +1,24 @@
identity:
  name: google_news
  author: Steven Sun
  label:
    en_US: GoogleNews
    zh_Hans: 谷歌新闻
description:
  human:
    en_US: google news is a news aggregator service developed by Google. It presents a continuous, customizable flow of articles organized from thousands of publishers and magazines.
    zh_Hans: 谷歌新闻是由谷歌开发的新闻聚合服务。它提供了一个持续的、可定制的文章流,这些文章是从成千上万的出版商和杂志中整理出来的。
  llm: A tool to get the latest news from Google News.
parameters:
  - name: language_region
    type: string
    required: true
    label:
      en_US: Language and Region
      zh_Hans: 语言和地区
    human_description:
      en_US: The language and region determine the language and region of the search results, and its value is assigned according to the "National Language Code Comparison Table", such as en-US, which stands for English (United States); zh-CN, stands for Chinese (Simplified).
      zh_Hans: 语言和地区决定了搜索结果的语言和地区,其赋值按照《国家语言代码对照表》,形如en-US,代表英语(美国);zh-CN,代表中文(简体)。
    llm_description: The language and region determine the language and region of the search results, and its value is assigned according to the "National Language Code Comparison Table", such as en-US, which stands for English (United States); zh-CN, stands for Chinese (Simplified).
    default: en-US
    form: llm

@@ -5,6 +5,7 @@ from urllib.parse import urlencode

 import httpx

+from core.file.file_manager import download
 from core.helper import ssrf_proxy
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
@@ -138,6 +139,7 @@ class ApiTool(Tool):
         path_params = {}
         body = {}
         cookies = {}
+        files = []

         # check parameters
         for parameter in self.api_bundle.openapi.get("parameters", []):
@@ -166,8 +168,12 @@ class ApiTool(Tool):
             properties = body_schema.get("properties", {})
             for name, property in properties.items():
                 if name in parameters:
-                    # convert type
-                    body[name] = self._convert_body_property_type(property, parameters[name])
+                    if property.get("format") == "binary":
+                        f = parameters[name]
+                        files.append((name, (f.filename, download(f), f.mime_type)))
+                    else:
+                        # convert type
+                        body[name] = self._convert_body_property_type(property, parameters[name])
                 elif name in required:
                     raise ToolParameterValidationError(
                         f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
@@ -182,7 +188,7 @@ class ApiTool(Tool):
         for name, value in path_params.items():
             url = url.replace(f"{{{name}}}", f"{value}")

-        # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored
+        # parse http body data if needed
         if "Content-Type" in headers:
             if headers["Content-Type"] == "application/json":
                 body = json.dumps(body)
@@ -198,6 +204,7 @@ class ApiTool(Tool):
             headers=headers,
             cookies=cookies,
             data=body,
+            files=files,
             timeout=API_TOOL_DEFAULT_TIMEOUT,
             follow_redirects=True,
         )

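The binary branch above feeds httpx's multipart `files` argument, whose entries are `(field_name, (filename, content, mime_type))` tuples. A hedged sketch with made-up field name, filename, bytes, and URL:

```python
import httpx

# Illustrative data only; mirrors the tuple shape built in ApiTool above.
files = [("document", ("report.pdf", b"%PDF-1.4 ...", "application/pdf"))]
request = httpx.Request("POST", "https://api.example.com/upload", files=files)
assert request.headers["content-type"].startswith("multipart/form-data")
```
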
@@ -261,7 +261,7 @@ class Tool(BaseModel, ABC):
         """
         parameters = self.parameters or []
         parameters = parameters.copy()
-        user_parameters = self.get_runtime_parameters() or []
+        user_parameters = self.get_runtime_parameters()
         user_parameters = user_parameters.copy()

         # override parameters

@@ -55,7 +55,7 @@ class ToolEngine:
         # check if this tool has only one parameter
         parameters = [
             parameter
-            for parameter in tool.get_runtime_parameters() or []
+            for parameter in tool.get_runtime_parameters()
             if parameter.form == ToolParameter.ToolParameterForm.LLM
         ]
         if parameters and len(parameters) == 1:

@@ -127,7 +127,7 @@ class ToolParameterConfigurationManager(BaseModel):
         # get tool parameters
         tool_parameters = self.tool_runtime.parameters or []
         # get tool runtime parameters
-        runtime_parameters = self.tool_runtime.get_runtime_parameters() or []
+        runtime_parameters = self.tool_runtime.get_runtime_parameters()
         # override parameters
         current_parameters = tool_parameters.copy()
         for runtime_parameter in runtime_parameters:

@@ -161,6 +161,9 @@ class ApiBasedToolSchemaParser:
     def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType:
+        parameter = parameter or {}
         typ = None
+        if parameter.get("format") == "binary":
+            return ToolParameter.ToolParameterType.FILE

         if "type" in parameter:
             typ = parameter["type"]
         elif "schema" in parameter and "type" in parameter["schema"]:

api/core/tools/utils/text_processing_utils.py (new file, 16 lines)
@@ -0,0 +1,16 @@
import re


def remove_leading_symbols(text: str) -> str:
    """
    Remove leading punctuation or symbols from the given text.

    Args:
        text (str): The input text to process.

    Returns:
        str: The text with leading punctuation or symbols removed.
    """
    # Match Unicode ranges for punctuation and symbols
    pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,\-./:;<=>?@\[\]^_`{|}~]+"
    return re.sub(pattern, "", text)

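A few examples of what the pattern strips (and leaves alone), assuming `remove_leading_symbols` from the file above is imported:

```python
assert remove_leading_symbols("。Hello") == "Hello"            # CJK punctuation (U+3002)
assert remove_leading_symbols("...world") == "world"           # ASCII punctuation run
assert remove_leading_symbols("normal text") == "normal text"  # untouched
```
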
@@ -1,4 +1,6 @@
+import mimetypes
 from collections.abc import Sequence
+from email.message import Message
 from typing import Any, Literal, Optional

 import httpx
@@ -7,14 +9,6 @@ from pydantic import BaseModel, Field, ValidationInfo, field_validator
 from configs import dify_config
 from core.workflow.nodes.base import BaseNodeData

-NON_FILE_CONTENT_TYPES = (
-    "application/json",
-    "application/xml",
-    "text/html",
-    "text/plain",
-    "application/x-www-form-urlencoded",
-)

 class HttpRequestNodeAuthorizationConfig(BaseModel):
     type: Literal["basic", "bearer", "custom"]
@@ -93,13 +87,53 @@ class Response:

     @property
     def is_file(self):
-        content_type = self.content_type
+        """
+        Determine if the response contains a file by checking:
+        1. Content-Disposition header (RFC 6266)
+        2. Content characteristics
+        3. MIME type analysis
+        """
+        content_type = self.content_type.split(";")[0].strip().lower()
         content_disposition = self.response.headers.get("content-disposition", "")

-        return "attachment" in content_disposition or (
-            not any(non_file in content_type for non_file in NON_FILE_CONTENT_TYPES)
-            and any(file_type in content_type for file_type in ("application/", "image/", "audio/", "video/"))
-        )
+        # Check if it's explicitly marked as an attachment
+        if content_disposition:
+            msg = Message()
+            msg["content-disposition"] = content_disposition
+            disp_type = msg.get_content_disposition()  # Returns 'attachment', 'inline', or None
+            filename = msg.get_filename()  # Returns filename if present, None otherwise
+            if disp_type == "attachment" or filename is not None:
+                return True
+
+        # For application types, try to detect if it's a text-based format
+        if content_type.startswith("application/"):
+            # Common text-based application types
+            if any(
+                text_type in content_type
+                for text_type in ("json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql")
+            ):
+                return False
+
+            # Try to detect if content is text-based by sampling first few bytes
+            try:
+                # Sample first 1024 bytes for text detection
+                content_sample = self.response.content[:1024]
+                content_sample.decode("utf-8")
+                # If we can decode as UTF-8 and find common text patterns, likely not a file
+                text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ")
+                if any(marker in content_sample for marker in text_markers):
+                    return False
+            except UnicodeDecodeError:
+                # If we can't decode as UTF-8, likely a binary file
+                return True
+
+        # For other types, use MIME type analysis
+        main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or ""))
+        if main_type:
+            return main_type.split("/")[0] in ("application", "image", "audio", "video")
+
+        # For unknown types, check if it's a media type
+        return any(media_type in content_type for media_type in ("image/", "audio/", "video/"))

     @property
     def content_type(self) -> str:

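The attachment check above leans on the stdlib's RFC 6266-aware header parsing instead of substring matching. A small demonstration of the two calls it relies on:

```python
from email.message import Message

msg = Message()
msg["content-disposition"] = 'attachment; filename="report.pdf"'
assert msg.get_content_disposition() == "attachment"
assert msg.get_filename() == "report.pdf"
```
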
@@ -1,11 +1,12 @@
 import redis
+from redis.cluster import ClusterNode, RedisCluster
 from redis.connection import Connection, SSLConnection
 from redis.sentinel import Sentinel

 from configs import dify_config


-class RedisClientWrapper(redis.Redis):
+class RedisClientWrapper:
     """
     A wrapper class for the Redis client that addresses the issue where the global
     `redis_client` variable cannot be updated when a new Redis instance is returned
@@ -71,6 +72,12 @@ def init_app(app):
         )
         master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
         redis_client.initialize(master)
+    elif dify_config.REDIS_USE_CLUSTERS:
+        nodes = [
+            ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
+            for node in dify_config.REDIS_CLUSTERS.split(",")
+        ]
+        redis_client.initialize(RedisCluster(startup_nodes=nodes, password=dify_config.REDIS_CLUSTERS_PASSWORD))
     else:
         redis_params.update(
             {

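The cluster branch expects `REDIS_CLUSTERS` to be a comma-separated `host:port` list. A sketch of the parsing, with made-up addresses and plain tuples standing in for `ClusterNode`:

```python
clusters_env = "10.0.0.1:6379,10.0.0.2:6379"  # illustrative REDIS_CLUSTERS value

nodes = [
    (node.split(":")[0], int(node.split(":")[1]))
    for node in clusters_env.split(",")
]
assert nodes == [("10.0.0.1", 6379), ("10.0.0.2", 6379)]
```
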
@@ -166,9 +166,9 @@ def _build_from_remote_url(


 def _get_remote_file_info(url: str):
-    mime_type = mimetypes.guess_type(url)[0] or ""
     file_size = -1
     filename = url.split("/")[-1].split("?")[0] or "unknown_file"
+    mime_type = mimetypes.guess_type(filename)[0] or ""

     resp = ssrf_proxy.head(url, follow_redirects=True)
     if resp.status_code == httpx.codes.OK:

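The fix matters because guessing from the raw URL trips over query strings, while guessing from the stripped filename does not. For example (URL is made up):

```python
import mimetypes

url = "https://example.com/files/report.pdf?sig=abc123"
filename = url.split("/")[-1].split("?")[0] or "unknown_file"

assert mimetypes.guess_type(url)[0] is None           # ".pdf?sig=abc123" is unrecognized
assert mimetypes.guess_type(filename)[0] == "application/pdf"
```
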
@@ -233,10 +233,10 @@ def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool:
     if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM:
         return False

-    if config.allowed_extensions and file.extension not in config.allowed_extensions:
+    if config.allowed_file_extensions and file.extension not in config.allowed_file_extensions:
         return False

-    if config.allowed_upload_methods and file.transfer_method not in config.allowed_upload_methods:
+    if config.allowed_file_upload_methods and file.transfer_method not in config.allowed_file_upload_methods:
         return False

     if file.type == FileType.IMAGE and config.image_config:

@@ -169,9 +169,9 @@ class Workflow(db.Model):
             )
             features["file_upload"]["enabled"] = image_enabled
             features["file_upload"]["number_limits"] = image_number_limits
-            features["file_upload"]["allowed_upload_methods"] = image_transfer_methods
+            features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods
             features["file_upload"]["allowed_file_types"] = ["image"]
-            features["file_upload"]["allowed_extensions"] = []
+            features["file_upload"]["allowed_file_extensions"] = []
             del features["file_upload"]["image"]
             self._features = json.dumps(features)
         return self._features

@@ -341,7 +341,7 @@ class AppService:
         if not app_model_config:
             return meta

-        agent_config = app_model_config.agent_mode_dict or {}
+        agent_config = app_model_config.agent_mode_dict

         # get all tools
         tools = agent_config.get("tools", [])

@@ -242,7 +242,7 @@ class ToolTransformService:
         # get tool parameters
         parameters = tool.parameters or []
         # get tool runtime parameters
-        runtime_parameters = tool.get_runtime_parameters() or []
+        runtime_parameters = tool.get_runtime_parameters()
         # override parameters
         current_parameters = parameters.copy()
         for runtime_parameter in runtime_parameters:

@@ -51,8 +51,8 @@ class WebsiteService:
         excludes = options.get("excludes").split(",") if options.get("excludes") else []
         params = {
             "crawlerOptions": {
-                "includes": includes or [],
-                "excludes": excludes or [],
+                "includes": includes,
+                "excludes": excludes,
                 "generateImgAltText": True,
                 "limit": options.get("limit", 1),
                 "returnOnlyUrls": False,

@@ -78,6 +78,7 @@ def clean_dataset_task(
                         "Delete image_files failed when storage deleted, \
                             image_upload_file_is: {}".format(upload_file_id)
                     )
+                db.session.delete(image_file)
             db.session.delete(segment)

     db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()

@@ -51,6 +51,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
                         "Delete image_files failed when storage deleted, \
                             image_upload_file_is: {}".format(upload_file_id)
                     )
+                db.session.delete(image_file)
             db.session.delete(segment)

         db.session.commit()

@@ -31,7 +31,7 @@ def test_invoke_model(setup_google_mock):
     model = GoogleLargeLanguageModel()

     response = model.invoke(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -48,7 +48,7 @@ def test_invoke_model(setup_google_mock):
                 ]
             ),
         ],
-        model_parameters={"temperature": 0.5, "top_p": 1.0, "max_tokens_to_sample": 2048},
+        model_parameters={"temperature": 0.5, "top_p": 1.0, "max_output_tokens": 2048},
         stop=["How"],
         stream=False,
         user="abc-123",
@@ -63,7 +63,7 @@ def test_invoke_stream_model(setup_google_mock):
     model = GoogleLargeLanguageModel()

     response = model.invoke(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -80,7 +80,7 @@ def test_invoke_stream_model(setup_google_mock):
                 ]
             ),
         ],
-        model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens_to_sample": 2048},
+        model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens": 2048},
         stream=True,
         user="abc-123",
     )
@@ -99,7 +99,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
     model = GoogleLargeLanguageModel()

     result = model.invoke(
-        model="gemini-pro-vision",
+        model="gemini-1.5-pro",
        credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -128,7 +128,7 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
     model = GoogleLargeLanguageModel()

     result = model.invoke(
-        model="gemini-pro-vision",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(content="You are a helpful AI assistant."),
@@ -164,7 +164,7 @@ def test_get_num_tokens():
     model = GoogleLargeLanguageModel()

     num_tokens = model.get_num_tokens(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(

@ -1,27 +1,43 @@
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbConfig, AnalyticdbVector
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis


class AnalyticdbVectorTest(AbstractVectorTest):
def __init__(self):
def __init__(self, config_type: str):
super().__init__()
# Analyticdb requires collection_name length less than 60.
# it's ok for normal usage.
self.collection_name = self.collection_name.replace("_test", "")
self.vector = AnalyticdbVector(
collection_name=self.collection_name,
config=AnalyticdbConfig(
access_key_id="test_key_id",
access_key_secret="test_key_secret",
region_id="test_region",
instance_id="test_id",
account="test_account",
account_password="test_passwd",
namespace="difytest_namespace",
collection="difytest_collection",
namespace_password="test_passwd",
),
)
if config_type == "sql":
self.vector = AnalyticdbVector(
collection_name=self.collection_name,
sql_config=AnalyticdbVectorBySqlConfig(
host="test_host",
port=5432,
account="test_account",
account_password="test_passwd",
namespace="difytest_namespace",
),
api_config=None,
)
else:
self.vector = AnalyticdbVector(
collection_name=self.collection_name,
sql_config=None,
api_config=AnalyticdbVectorOpenAPIConfig(
access_key_id="test_key_id",
access_key_secret="test_key_secret",
region_id="test_region",
instance_id="test_id",
account="test_account",
account_password="test_passwd",
namespace="difytest_namespace",
collection="difytest_collection",
namespace_password="test_passwd",
),
)

def run_all_tests(self):
self.vector.delete()
@ -29,4 +45,5 @@ class AnalyticdbVectorTest(AbstractVectorTest):


def test_chroma_vector(setup_mock_redis):
AnalyticdbVectorTest().run_all_tests()
AnalyticdbVectorTest("api").run_all_tests()
AnalyticdbVectorTest("sql").run_all_tests()

@ -27,8 +27,8 @@ NEW_VERSION_WORKFLOW_FEATURES = {
"file_upload": {
"enabled": True,
"allowed_file_types": ["image"],
"allowed_extensions": [],
"allowed_upload_methods": ["remote_url", "local_file"],
"allowed_file_extensions": [],
"allowed_file_upload_methods": ["remote_url", "local_file"],
"number_limits": 6,
},
"opening_statement": "",

@ -1,4 +1,7 @@
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
import json

from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType, FileUploadConfig
from models.workflow import Workflow


def test_file_loads_and_dumps():
@ -38,3 +41,40 @@ def test_file_to_dict():
file_dict = file.to_dict()
assert "_extra_config" not in file_dict
assert "url" in file_dict


def test_workflow_features_with_image():
# Create a feature dict that mimics the old structure with image config
features = {
"file_upload": {
"image": {"enabled": True, "number_limits": 5, "transfer_methods": ["remote_url", "local_file"]}
}
}

# Create a workflow instance with the features
workflow = Workflow(
tenant_id="tenant-1",
app_id="app-1",
type="chat",
version="1.0",
graph="{}",
features=json.dumps(features),
created_by="user-1",
environment_variables=[],
conversation_variables=[],
)

# Get the converted features through the property
converted_features = json.loads(workflow.features)

# Create FileUploadConfig from the converted features
file_upload_config = FileUploadConfig.model_validate(converted_features["file_upload"])

# Validate the config
assert file_upload_config.number_limits == 5
assert list(file_upload_config.allowed_file_types) == [FileType.IMAGE]
assert list(file_upload_config.allowed_file_upload_methods) == [
FileTransferMethod.REMOTE_URL,
FileTransferMethod.LOCAL_FILE,
]
assert list(file_upload_config.allowed_file_extensions) == []

@ -1,10 +1,12 @@
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
import redis

from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.model_manager import LBModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_redis import redis_client


@pytest.fixture
@ -38,6 +40,9 @@ def lb_model_manager():


def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
# initialize redis client
redis_client.initialize(redis.Redis())

assert len(lb_model_manager._load_balancing_configs) == 3

config1 = lb_model_manager._load_balancing_configs[0]
@ -55,12 +60,13 @@ def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
start_index += 1
return start_index

mocker.patch("redis.Redis.incr", side_effect=incr)
mocker.patch("redis.Redis.set", return_value=None)
mocker.patch("redis.Redis.expire", return_value=None)
with (
patch.object(redis_client, "incr", side_effect=incr),
patch.object(redis_client, "set", return_value=None),
patch.object(redis_client, "expire", return_value=None),
):
config = lb_model_manager.fetch_next()
assert config == config2

config = lb_model_manager.fetch_next()
assert config == config2

config = lb_model_manager.fetch_next()
assert config == config3
config = lb_model_manager.fetch_next()
assert config == config3

@ -0,0 +1,140 @@
from unittest.mock import Mock, PropertyMock, patch

import httpx
import pytest

from core.workflow.nodes.http_request.entities import Response


@pytest.fixture
def mock_response():
response = Mock(spec=httpx.Response)
response.headers = {}
return response


def test_is_file_with_attachment_disposition(mock_response):
"""Test is_file when content-disposition header contains 'attachment'"""
mock_response.headers = {"content-disposition": "attachment; filename=test.pdf", "content-type": "application/pdf"}
response = Response(mock_response)
assert response.is_file


def test_is_file_with_filename_disposition(mock_response):
"""Test is_file when content-disposition header contains filename parameter"""
mock_response.headers = {"content-disposition": "inline; filename=test.pdf", "content-type": "application/pdf"}
response = Response(mock_response)
assert response.is_file


@pytest.mark.parametrize("content_type", ["application/pdf", "image/jpeg", "audio/mp3", "video/mp4"])
def test_is_file_with_file_content_types(mock_response, content_type):
"""Test is_file with various file content types"""
mock_response.headers = {"content-type": content_type}
# Mock binary content
type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
response = Response(mock_response)
assert response.is_file, f"Content type {content_type} should be identified as a file"


@pytest.mark.parametrize(
"content_type",
[
"application/json",
"application/xml",
"application/javascript",
"application/x-www-form-urlencoded",
"application/yaml",
"application/graphql",
],
)
def test_text_based_application_types(mock_response, content_type):
"""Test common text-based application types are not identified as files"""
mock_response.headers = {"content-type": content_type}
response = Response(mock_response)
assert not response.is_file, f"Content type {content_type} should not be identified as a file"


@pytest.mark.parametrize(
("content", "content_type"),
[
(b'{"key": "value"}', "application/octet-stream"),
(b"[1, 2, 3]", "application/unknown"),
(b"function test() {}", "application/x-unknown"),
(b"<root>test</root>", "application/binary"),
(b"var x = 1;", "application/data"),
],
)
def test_content_based_detection(mock_response, content, content_type):
"""Test content-based detection for text-like content"""
mock_response.headers = {"content-type": content_type}
type(mock_response).content = PropertyMock(return_value=content)
response = Response(mock_response)
assert not response.is_file, f"Content {content} with type {content_type} should not be identified as a file"


@pytest.mark.parametrize(
("content", "content_type"),
[
(bytes([0x00, 0xFF] * 512), "application/octet-stream"),
(bytes([0x89, 0x50, 0x4E, 0x47]), "application/unknown"),  # PNG magic numbers
(bytes([0xFF, 0xD8, 0xFF]), "application/binary"),  # JPEG magic numbers
],
)
def test_binary_content_detection(mock_response, content, content_type):
"""Test content-based detection for binary content"""
mock_response.headers = {"content-type": content_type}
type(mock_response).content = PropertyMock(return_value=content)
response = Response(mock_response)
assert response.is_file, f"Binary content with type {content_type} should be identified as a file"


@pytest.mark.parametrize(
("content_type", "expected_main_type"),
[
("x-world/x-vrml", "model"),  # VRML 3D model
("font/ttf", "application"),  # TrueType font
("text/csv", "text"),  # CSV text file
("unknown/xyz", None),  # Unknown type
],
)
def test_mimetype_based_detection(mock_response, content_type, expected_main_type):
"""Test detection using mimetypes.guess_type for non-application content types"""
mock_response.headers = {"content-type": content_type}
type(mock_response).content = PropertyMock(return_value=bytes([0x00]))  # Dummy content

with patch("core.workflow.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type:
# Mock the return value based on expected_main_type
if expected_main_type:
mock_guess_type.return_value = (f"{expected_main_type}/subtype", None)
else:
mock_guess_type.return_value = (None, None)

response = Response(mock_response)

# Check if the result matches our expectation
if expected_main_type in ("application", "image", "audio", "video"):
assert response.is_file, f"Content type {content_type} should be identified as a file"
else:
assert not response.is_file, f"Content type {content_type} should not be identified as a file"

# Verify that guess_type was called
mock_guess_type.assert_called_once()


def test_is_file_with_inline_disposition(mock_response):
"""Test is_file when content-disposition is 'inline'"""
mock_response.headers = {"content-disposition": "inline", "content-type": "application/pdf"}
# Mock binary content
type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
response = Response(mock_response)
assert response.is_file


def test_is_file_with_no_content_disposition(mock_response):
"""Test is_file when no content-disposition header is present"""
mock_response.headers = {"content-type": "application/pdf"}
# Mock binary content
type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
response = Response(mock_response)
assert response.is_file
20
api/tests/unit_tests/utils/test_text_processing.py
Normal file
@ -0,0 +1,20 @@
from textwrap import dedent

import pytest

from core.tools.utils.text_processing_utils import remove_leading_symbols


@pytest.mark.parametrize(
("input_text", "expected_output"),
[
("...Hello, World!", "Hello, World!"),
("。测试中文标点", "测试中文标点"),
("!@#Test symbols", "Test symbols"),
("Hello, World!", "Hello, World!"),
("", ""),
(" ", " "),
],
)
def test_remove_leading_symbols(input_text, expected_output):
assert remove_leading_symbols(input_text) == expected_output
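For context, the helper under test strips leading punctuation and symbol characters while leaving whitespace-only input untouched. A minimal sketch of such a helper, assuming a Unicode-aware regex approach (the actual implementation lives in `core/tools/utils/text_processing_utils.py` and may differ):

```python
import re


def remove_leading_symbols(text: str) -> str:
    # Drop leading characters that are neither word characters nor whitespace:
    # "...Hello, World!" -> "Hello, World!", while a whitespace-only string
    # is returned unchanged (the leading space is \s, so the class never matches).
    return re.sub(r"^[^\w\s]+", "", text)
```

This satisfies the parametrized cases above, including the Chinese punctuation case, since `\w` matches Unicode letters by default in Python 3.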
@ -75,7 +75,8 @@ SECRET_KEY=sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U

# Password for admin user initialization.
# If left unset, admin user will not be prompted for a password
# when creating the initial admin account.
# when creating the initial admin account.
# The length of the password cannot exceed 30 charactors.
INIT_PASSWORD=

# Deployment environment.
@ -239,6 +240,12 @@ REDIS_SENTINEL_USERNAME=
REDIS_SENTINEL_PASSWORD=
REDIS_SENTINEL_SOCKET_TIMEOUT=0.1

# List of Redis Cluster nodes. If Cluster mode is enabled, provide at least one Cluster IP and port.
# Format: `<Cluster1_ip>:<Cluster1_port>,<Cluster2_ip>:<Cluster2_port>,<Cluster3_ip>:<Cluster3_port>`
REDIS_USE_CLUSTERS=false
REDIS_CLUSTERS=
REDIS_CLUSTERS_PASSWORD=

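As an illustration of the format described above — node addresses and password here are hypothetical placeholders, not defaults — a three-node cluster could be configured like this:

```
REDIS_USE_CLUSTERS=true
REDIS_CLUSTERS=192.168.1.10:6379,192.168.1.11:6379,192.168.1.12:6379
REDIS_CLUSTERS_PASSWORD=your_cluster_password
```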
# ------------------------------
# Celery Configuration
# ------------------------------
@ -450,6 +457,10 @@ ANALYTICDB_ACCOUNT=testaccount
ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
ANALYTICDB_HOST=gp-test.aliyuncs.com
ANALYTICDB_PORT=5432
ANALYTICDB_MIN_CONNECTION=1
ANALYTICDB_MAX_CONNECTION=5

# TiDB vector configurations, only available when VECTOR_STORE is `tidb`
TIDB_VECTOR_HOST=tidb
@ -558,7 +569,7 @@ UPLOAD_FILE_SIZE_LIMIT=15
# The maximum number of files that can be uploaded at a time, default 5.
UPLOAD_FILE_BATCH_LIMIT=5

# ETl type, support: `dify`, `Unstructured`
# ETL type, support: `dify`, `Unstructured`
# `dify` Dify's proprietary file extraction scheme
# `Unstructured` Unstructured.io file extraction scheme
ETL_TYPE=dify

@ -36,7 +36,7 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T
- Navigate to the `docker` directory.
- Ensure the `middleware.env` file is created by running `cp middleware.env.example middleware.env` (refer to the `middleware.env.example` file).
2. **Running Middleware Services**:
- Execute `docker-compose -f docker-compose.middleware.yaml up -d` to start the middleware services.
- Execute `docker-compose -f docker-compose.middleware.yaml up --env-file middleware.env -d` to start the middleware services.

### Migration for Existing Users


@ -29,11 +29,13 @@ services:
redis:
image: redis:6-alpine
restart: always
environment:
REDISCLI_AUTH: ${REDIS_PASSWORD:-difyai123456}
volumes:
# Mount the redis data directory to the container.
- ${REDIS_HOST_VOLUME:-./volumes/redis/data}:/data
# Set the redis password when startup redis server.
command: redis-server --requirepass difyai123456
command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456}
ports:
- "${EXPOSE_REDIS_PORT:-6379}:6379"
healthcheck:

@ -55,6 +55,9 @@ x-shared-env: &shared-api-worker-env
REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-}
REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-}
REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1}
REDIS_CLUSTERS: ${REDIS_CLUSTERS:-}
REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false}
REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-}
ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60}
CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1}
BROKER_USE_SSL: ${BROKER_USE_SSL:-false}
@ -185,6 +188,10 @@ x-shared-env: &shared-api-worker-env
ANALYTICDB_PASSWORD: ${ANALYTICDB_PASSWORD:-}
ANALYTICDB_NAMESPACE: ${ANALYTICDB_NAMESPACE:-dify}
ANALYTICDB_NAMESPACE_PASSWORD: ${ANALYTICDB_NAMESPACE_PASSWORD:-}
ANALYTICDB_HOST: ${ANALYTICDB_HOST:-}
ANALYTICDB_PORT: ${ANALYTICDB_PORT:-5432}
ANALYTICDB_MIN_CONNECTION: ${ANALYTICDB_MIN_CONNECTION:-1}
ANALYTICDB_MAX_CONNECTION: ${ANALYTICDB_MAX_CONNECTION:-5}
OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch}
OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200}
OPENSEARCH_USER: ${OPENSEARCH_USER:-admin}
@ -359,6 +366,8 @@ services:
redis:
image: redis:6-alpine
restart: always
environment:
REDISCLI_AUTH: ${REDIS_PASSWORD:-difyai123456}
volumes:
# Mount the redis data directory to the container.
- ./volumes/redis/data:/data

@ -42,11 +42,13 @@ POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB

# -----------------------------
# Environment Variables for redis Service
REDIS_HOST_VOLUME=./volumes/redis/data
# -----------------------------
REDIS_HOST_VOLUME=./volumes/redis/data
REDIS_PASSWORD=difyai123456

# ------------------------------
# Environment Variables for sandbox Service
# ------------------------------
SANDBOX_API_KEY=dify-sandbox
SANDBOX_GIN_MODE=release
SANDBOX_WORKER_TIMEOUT=15
@ -54,7 +56,6 @@ SANDBOX_ENABLE_NETWORK=true
SANDBOX_HTTP_PROXY=http://ssrf_proxy:3128
SANDBOX_HTTPS_PROXY=http://ssrf_proxy:3128
SANDBOX_PORT=8194
# ------------------------------

# ------------------------------
# Environment Variables for ssrf_proxy Service

6
web/.gitignore
vendored
@ -50,4 +50,8 @@ package-lock.json
pnpm-lock.yaml

.favorites.json
*storybook.log
*storybook.log

# mise
mise.toml


@ -5,7 +5,6 @@ import { useEffect, useMemo, useRef, useState } from 'react'
import { useRouter } from 'next/navigation'
import { useTranslation } from 'react-i18next'
import { useDebounceFn } from 'ahooks'
import useSWR from 'swr'

// Components
import ExternalAPIPanel from '../../components/datasets/external-api/external-api-panel'
@ -28,6 +27,8 @@ import { useTabSearchParams } from '@/hooks/use-tab-searchparams'
import { useStore as useTagStore } from '@/app/components/base/tag-management/store'
import { useAppContext } from '@/context/app-context'
import { useExternalApiPanel } from '@/context/external-api-panel-context'
// eslint-disable-next-line import/order
import { useQuery } from '@tanstack/react-query'

const Container = () => {
const { t } = useTranslation()
@ -47,7 +48,13 @@ const Container = () => {
defaultTab: 'dataset',
})
const containerRef = useRef<HTMLDivElement>(null)
const { data } = useSWR(activeTab === 'dataset' ? null : '/datasets/api-base-info', fetchDatasetApiBaseUrl)
const { data } = useQuery(
{
queryKey: ['datasetApiBaseInfo'],
queryFn: () => fetchDatasetApiBaseUrl('/datasets/api-base-info'),
enabled: activeTab !== 'dataset',
},
)

const [keywords, setKeywords] = useState('')
const [searchKeywords, setSearchKeywords] = useState('')

@ -329,7 +329,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Col sticky>
<CodeGroup
title="Request"
tag="POST"
tag="GET"
label="/datasets"
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets?page=1&limit=20' \\\n--header 'Authorization: Bearer {api_key}'`}
>

@ -329,7 +329,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Col sticky>
<CodeGroup
title="Request"
tag="POST"
tag="GET"
label="/datasets"
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets?page=1&limit=20' \\\n--header 'Authorization: Bearer {api_key}'`}
>

@ -8,24 +8,27 @@ import Header from '@/app/components/header'
import { EventEmitterContextProvider } from '@/context/event-emitter'
import { ProviderContextProvider } from '@/context/provider-context'
import { ModalContextProvider } from '@/context/modal-context'
import { TanstackQueryIniter } from '@/context/query-client'

const Layout = ({ children }: { children: ReactNode }) => {
return (
<>
<GA gaType={GaType.admin} />
<SwrInitor>
<AppContextProvider>
<EventEmitterContextProvider>
<ProviderContextProvider>
<ModalContextProvider>
<HeaderWrapper>
<Header />
</HeaderWrapper>
{children}
</ModalContextProvider>
</ProviderContextProvider>
</EventEmitterContextProvider>
</AppContextProvider>
<TanstackQueryIniter>
<AppContextProvider>
<EventEmitterContextProvider>
<ProviderContextProvider>
<ModalContextProvider>
<HeaderWrapper>
<Header />
</HeaderWrapper>
{children}
</ModalContextProvider>
</ProviderContextProvider>
</EventEmitterContextProvider>
</AppContextProvider>
</TanstackQueryIniter>
</SwrInitor>
</>
)

@ -676,6 +676,10 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
const [showDrawer, setShowDrawer] = useState<boolean>(false) // Whether to display the chat details drawer
const [currentConversation, setCurrentConversation] = useState<ChatConversationGeneralDetail | CompletionConversationGeneralDetail | undefined>() // Currently selected conversation
const isChatMode = appDetail.mode !== 'completion' // Whether the app is a chat app
const { setShowPromptLogModal, setShowAgentLogModal } = useAppStore(useShallow(state => ({
setShowPromptLogModal: state.setShowPromptLogModal,
setShowAgentLogModal: state.setShowAgentLogModal,
})))

// Annotated data needs to be highlighted
const renderTdValue = (value: string | number | null, isEmptyStyle: boolean, isHighlight = false, annotation?: LogAnnotation) => {
@ -699,6 +703,8 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
onRefresh()
setShowDrawer(false)
setCurrentConversation(undefined)
setShowPromptLogModal(false)
setShowAgentLogModal(false)
}

if (!logs)

@ -1804,8 +1804,85 @@ exports[`build chat item tree and get thread messages should get thread messages
]
`;

exports[`build chat item tree and get thread messages should work with partial messages 1`] = `
exports[`build chat item tree and get thread messages should work with partial messages 1 1`] = `
[
{
"children": [
{
"agent_thoughts": [
{
"chain_id": null,
"created_at": 1726105799,
"files": [],
"id": "9730d587-9268-4683-9dd9-91a1cab9510b",
"message_id": "4c5d0841-1206-463e-95d8-71f812877658",
"observation": "",
"position": 1,
"thought": "I'll go with 112. Your turn!",
"tool": "",
"tool_input": "",
"tool_labels": {},
},
],
"children": [],
"content": "I'll go with 112. Your turn!",
"conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80",
"feedbackDisabled": false,
"id": "4c5d0841-1206-463e-95d8-71f812877658",
"input": {
"inputs": {},
"query": "99",
},
"isAnswer": true,
"log": [
{
"files": [],
"role": "user",
"text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38",
},
{
"files": [],
"role": "assistant",
"text": "Sure, I'll play! My number is 57. Your turn!",
},
{
"files": [],
"role": "user",
"text": "58",
},
{
"files": [],
"role": "assistant",
"text": "I choose 83. What's your next number?",
},
{
"files": [],
"role": "user",
"text": "99",
},
{
"files": [],
"role": "assistant",
"text": "I'll go with 112. Your turn!",
},
],
"message_files": [],
"more": {
"latency": "1.49",
"time": "09/11/2024 09:50 PM",
"tokens": 86,
},
"parentMessageId": "question-4c5d0841-1206-463e-95d8-71f812877658",
"siblingIndex": 0,
"workflow_run_id": null,
},
],
"content": "99",
"id": "question-4c5d0841-1206-463e-95d8-71f812877658",
"isAnswer": false,
"message_files": [],
"parentMessageId": "73bbad14-d915-499d-87bf-0df14d40779d",
},
{
"children": [
{
@ -2078,6 +2155,178 @@ exports[`build chat item tree and get thread messages should work with partial m
]
`;

exports[`build chat item tree and get thread messages should work with partial messages 2 1`] = `
[
{
"children": [
{
"children": [],
"content": "237.",
"id": "ebb73fe2-15de-46dd-aab5-75416d8448eb",
"isAnswer": true,
"parentMessageId": "question-ebb73fe2-15de-46dd-aab5-75416d8448eb",
"siblingIndex": 0,
},
],
"content": "123",
"id": "question-ebb73fe2-15de-46dd-aab5-75416d8448eb",
"isAnswer": false,
"parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418",
},
{
"children": [
{
"children": [],
"content": "My number is 256.",
"id": "3553d508-3850-462e-8594-078539f940f9",
"isAnswer": true,
"parentMessageId": "question-3553d508-3850-462e-8594-078539f940f9",
"siblingIndex": 1,
},
],
"content": "123",
"id": "question-3553d508-3850-462e-8594-078539f940f9",
"isAnswer": false,
"parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418",
},
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"children": [],
"content": "My number is 3e (approximately 8.15).",
"id": "9eac3bcc-8d3b-4e56-a12b-44c34cebc719",
"isAnswer": true,
"parentMessageId": "question-9eac3bcc-8d3b-4e56-a12b-44c34cebc719",
"siblingIndex": 0,
},
],
"content": "e",
"id": "question-9eac3bcc-8d3b-4e56-a12b-44c34cebc719",
"isAnswer": false,
"parentMessageId": "5c56a2b3-f057-42a0-9b2c-52a35713cd8c",
},
],
"content": "My number is 2π (approximately 6.28).",
"id": "5c56a2b3-f057-42a0-9b2c-52a35713cd8c",
"isAnswer": true,
"parentMessageId": "question-5c56a2b3-f057-42a0-9b2c-52a35713cd8c",
"siblingIndex": 0,
},
],
"content": "π",
"id": "question-5c56a2b3-f057-42a0-9b2c-52a35713cd8c",
"isAnswer": false,
"parentMessageId": "46a49bb9-0881-459e-8c6a-24d20ae48d2f",
},
],
"content": "My number is 145.",
"id": "46a49bb9-0881-459e-8c6a-24d20ae48d2f",
"isAnswer": true,
"parentMessageId": "question-46a49bb9-0881-459e-8c6a-24d20ae48d2f",
"siblingIndex": 0,
},
],
"content": "78",
"id": "question-46a49bb9-0881-459e-8c6a-24d20ae48d2f",
"isAnswer": false,
"parentMessageId": "3cded945-855a-4a24-aab7-43c7dd54664c",
},
],
"content": "My number is 7.89.",
"id": "3cded945-855a-4a24-aab7-43c7dd54664c",
"isAnswer": true,
"parentMessageId": "question-3cded945-855a-4a24-aab7-43c7dd54664c",
"siblingIndex": 0,
},
],
"content": "3.11",
"id": "question-3cded945-855a-4a24-aab7-43c7dd54664c",
"isAnswer": false,
"parentMessageId": "a956de3d-ef95-4d90-84fe-f7a26ef28cd7",
},
],
"content": "My number is 22.",
"id": "a956de3d-ef95-4d90-84fe-f7a26ef28cd7",
"isAnswer": true,
"parentMessageId": "question-a956de3d-ef95-4d90-84fe-f7a26ef28cd7",
"siblingIndex": 0,
},
],
"content": "-5",
"id": "question-a956de3d-ef95-4d90-84fe-f7a26ef28cd7",
"isAnswer": false,
"parentMessageId": "93bac05d-1470-4ac9-b090-fe21cd7c3d55",
},
],
"content": "My number is 4782.",
"id": "93bac05d-1470-4ac9-b090-fe21cd7c3d55",
"isAnswer": true,
"parentMessageId": "question-93bac05d-1470-4ac9-b090-fe21cd7c3d55",
"siblingIndex": 0,
},
],
"content": "3306",
"id": "question-93bac05d-1470-4ac9-b090-fe21cd7c3d55",
"isAnswer": false,
"parentMessageId": "9e51a13b-7780-4565-98dc-f2d8c3b1758f",
},
],
"content": "My number is 2048.",
"id": "9e51a13b-7780-4565-98dc-f2d8c3b1758f",
"isAnswer": true,
"parentMessageId": "question-9e51a13b-7780-4565-98dc-f2d8c3b1758f",
"siblingIndex": 0,
},
],
"content": "1024",
"id": "question-9e51a13b-7780-4565-98dc-f2d8c3b1758f",
"isAnswer": false,
"parentMessageId": "507f9df9-1f06-4a57-bb38-f00228c42c22",
},
],
"content": "My number is 259.",
"id": "507f9df9-1f06-4a57-bb38-f00228c42c22",
"isAnswer": true,
"parentMessageId": "question-507f9df9-1f06-4a57-bb38-f00228c42c22",
"siblingIndex": 2,
},
],
"content": "123",
"id": "question-507f9df9-1f06-4a57-bb38-f00228c42c22",
"isAnswer": false,
"parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418",
},
]
`;

exports[`build chat item tree and get thread messages should work with real world messages 1`] = `
[
{

122
web/app/components/base/chat/__tests__/partialMessages.json
Normal file
@ -0,0 +1,122 @@
[
{
"id": "question-ebb73fe2-15de-46dd-aab5-75416d8448eb",
"content": "123",
"isAnswer": false,
"parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418"
},
{
"id": "ebb73fe2-15de-46dd-aab5-75416d8448eb",
"content": "237.",
"isAnswer": true,
"parentMessageId": "question-ebb73fe2-15de-46dd-aab5-75416d8448eb"
},
{
"id": "question-3553d508-3850-462e-8594-078539f940f9",
"content": "123",
"isAnswer": false,
"parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418"
},
{
"id": "3553d508-3850-462e-8594-078539f940f9",
"content": "My number is 256.",
"isAnswer": true,
"parentMessageId": "question-3553d508-3850-462e-8594-078539f940f9"
},
{
"id": "question-507f9df9-1f06-4a57-bb38-f00228c42c22",
"content": "123",
"isAnswer": false,
"parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418"
},
{
"id": "507f9df9-1f06-4a57-bb38-f00228c42c22",
"content": "My number is 259.",
"isAnswer": true,
"parentMessageId": "question-507f9df9-1f06-4a57-bb38-f00228c42c22"
},
{
"id": "question-9e51a13b-7780-4565-98dc-f2d8c3b1758f",
"content": "1024",
"isAnswer": false,
"parentMessageId": "507f9df9-1f06-4a57-bb38-f00228c42c22"
},
{
"id": "9e51a13b-7780-4565-98dc-f2d8c3b1758f",
"content": "My number is 2048.",
"isAnswer": true,
"parentMessageId": "question-9e51a13b-7780-4565-98dc-f2d8c3b1758f"
},
{
"id": "question-93bac05d-1470-4ac9-b090-fe21cd7c3d55",
"content": "3306",
"isAnswer": false,
"parentMessageId": "9e51a13b-7780-4565-98dc-f2d8c3b1758f"
},
{
"id": "93bac05d-1470-4ac9-b090-fe21cd7c3d55",
"content": "My number is 4782.",
"isAnswer": true,
"parentMessageId": "question-93bac05d-1470-4ac9-b090-fe21cd7c3d55"
},
{
"id": "question-a956de3d-ef95-4d90-84fe-f7a26ef28cd7",
"content": "-5",
"isAnswer": false,
"parentMessageId": "93bac05d-1470-4ac9-b090-fe21cd7c3d55"
},
{
"id": "a956de3d-ef95-4d90-84fe-f7a26ef28cd7",
"content": "My number is 22.",
"isAnswer": true,
"parentMessageId": "question-a956de3d-ef95-4d90-84fe-f7a26ef28cd7"
},
{
"id": "question-3cded945-855a-4a24-aab7-43c7dd54664c",
"content": "3.11",
"isAnswer": false,
"parentMessageId": "a956de3d-ef95-4d90-84fe-f7a26ef28cd7"
},
{
"id": "3cded945-855a-4a24-aab7-43c7dd54664c",
"content": "My number is 7.89.",
"isAnswer": true,
"parentMessageId": "question-3cded945-855a-4a24-aab7-43c7dd54664c"
},
{
"id": "question-46a49bb9-0881-459e-8c6a-24d20ae48d2f",
"content": "78",
"isAnswer": false,
"parentMessageId": "3cded945-855a-4a24-aab7-43c7dd54664c"
},
{
"id": "46a49bb9-0881-459e-8c6a-24d20ae48d2f",
"content": "My number is 145.",
"isAnswer": true,
"parentMessageId": "question-46a49bb9-0881-459e-8c6a-24d20ae48d2f"
},
{
"id": "question-5c56a2b3-f057-42a0-9b2c-52a35713cd8c",
"content": "π",
"isAnswer": false,
"parentMessageId": "46a49bb9-0881-459e-8c6a-24d20ae48d2f"
},
{
"id": "5c56a2b3-f057-42a0-9b2c-52a35713cd8c",
"content": "My number is 2π (approximately 6.28).",
"isAnswer": true,
"parentMessageId": "question-5c56a2b3-f057-42a0-9b2c-52a35713cd8c"
},
{
"id": "question-9eac3bcc-8d3b-4e56-a12b-44c34cebc719",
"content": "e",
"isAnswer": false,
"parentMessageId": "5c56a2b3-f057-42a0-9b2c-52a35713cd8c"
},
{
"id": "9eac3bcc-8d3b-4e56-a12b-44c34cebc719",
"content": "My number is 3e (approximately 8.15).",
"isAnswer": true,
"parentMessageId": "question-9eac3bcc-8d3b-4e56-a12b-44c34cebc719"
}
]
@ -7,6 +7,7 @@ import mixedTestMessages from './mixedTestMessages.json'
import multiRootNodesMessages from './multiRootNodesMessages.json'
import multiRootNodesWithLegacyTestMessages from './multiRootNodesWithLegacyTestMessages.json'
import realWorldMessages from './realWorldMessages.json'
import partialMessages from './partialMessages.json'

function visitNode(tree: ChatItemInTree | ChatItemInTree[], path: string): ChatItemInTree {
return get(tree, path)
@ -256,9 +257,15 @@ describe('build chat item tree and get thread messages', () => {
expect(threadMessages6_2).toMatchSnapshot()
})

const partialMessages = (realWorldMessages as ChatItemInTree[]).slice(-10)
const tree7 = buildChatItemTree(partialMessages)
it('should work with partial messages', () => {
const partialMessages1 = (realWorldMessages as ChatItemInTree[]).slice(-10)
const tree7 = buildChatItemTree(partialMessages1)
it('should work with partial messages 1', () => {
expect(tree7).toMatchSnapshot()
})

const partialMessages2 = (partialMessages as ChatItemInTree[])
const tree8 = buildChatItemTree(partialMessages2)
it('should work with partial messages 2', () => {
expect(tree8).toMatchSnapshot()
})
})

@ -127,19 +127,16 @@ function buildChatItemTree(allMessages: IChatItem[]): ChatItemInTree[] {
lastAppendedLegacyAnswer = answerNode
}
else {
if (!parentMessageId)
if (
!parentMessageId
|| !allMessages.some(item => item.id === parentMessageId) // parent message might not be fetched yet, in this case we will append the question to the root nodes
)
rootNodes.push(questionNode)
else
map[parentMessageId]?.children!.push(questionNode)
}
}

// If no messages have parentMessageId=null (indicating a root node),
// then we likely have a partial chat history. In this case,
// use the first available message as the root node.
if (rootNodes.length === 0 && allMessages.length > 0)
rootNodes.push(map[allMessages[0]!.id]!)

return rootNodes
}


@ -13,10 +13,19 @@ const FileInput = ({
const files = useStore(s => s.files)
const { handleLocalFileUpload } = useFile(fileConfig)
const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {
const file = e.target.files?.[0]
const targetFiles = e.target.files

if (file)
handleLocalFileUpload(file)
if (targetFiles) {
if (fileConfig.number_limits) {
for (let i = 0; i < targetFiles.length; i++) {
if (i + 1 + files.length <= fileConfig.number_limits)
handleLocalFileUpload(targetFiles[i])
}
}
else {
handleLocalFileUpload(targetFiles[0])
}
}
}

const allowedFileTypes = fileConfig.allowed_file_types
@ -32,6 +41,7 @@ const FileInput = ({
onChange={handleChange}
accept={accept}
disabled={!!(fileConfig.number_limits && files.length >= fileConfig?.number_limits)}
multiple={!!fileConfig.number_limits && fileConfig.number_limits > 1}
/>
)
}

Some files were not shown because too many files have changed in this diff.