Compare commits

..

2 Commits

320 changed files with 1419 additions and 5731 deletions


@@ -8,7 +8,7 @@ inputs:
poetry-version:
description: Poetry version to set up
required: true
default: '2.0.1'
default: '1.8.4'
poetry-lockfile:
description: Path to the Poetry lockfile to restore cache from
required: true


@@ -42,23 +42,25 @@ jobs:
run: poetry install -C api --with dev
- name: Check dependencies in pyproject.toml
run: poetry run -P api bash dev/pytest/pytest_artifacts.sh
run: poetry run -C api bash dev/pytest/pytest_artifacts.sh
- name: Run Unit tests
run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh
run: poetry run -C api bash dev/pytest/pytest_unit_tests.sh
- name: Run ModelRuntime
run: poetry run -P api bash dev/pytest/pytest_model_runtime.sh
run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh
- name: Run dify config tests
run: poetry run -P api python dev/pytest/pytest_config_tests.py
run: poetry run -C api python dev/pytest/pytest_config_tests.py
- name: Run Tool
run: poetry run -P api bash dev/pytest/pytest_tools.sh
run: poetry run -C api bash dev/pytest/pytest_tools.sh
- name: Run mypy
run: |
poetry run -C api python -m mypy --install-types --non-interactive .
pushd api
poetry run python -m mypy --install-types --non-interactive .
popd
- name: Set up dotenvs
run: |
@@ -78,4 +80,4 @@ jobs:
ssrf_proxy
- name: Run Workflow
run: poetry run -P api bash dev/pytest/pytest_workflow.sh
run: poetry run -C api bash dev/pytest/pytest_workflow.sh
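The `-P` → `-C` flips above follow the Poetry downgrade in this compare (2.0.1 back to 1.8.4). A minimal sketch of the flag difference, assuming standard Poetry CLI semantics (`--directory/-C` in 1.x; `--project/-P` introduced in 2.x):

```bash
# Poetry 1.8.x: -C/--directory points the command at the project directory
poetry run -C api bash dev/pytest/pytest_unit_tests.sh

# Poetry 2.x: the same intent is spelled -P/--project, while -C was
# repurposed to only change the working directory
poetry run -P api bash dev/pytest/pytest_unit_tests.sh
```

The pushd/popd change for the mypy step serves the same purpose: run the tool from inside `api/` rather than relying on a flag whose meaning differs across Poetry major versions.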


@@ -38,12 +38,12 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: |
poetry run -C api ruff --version
poetry run -C api ruff check ./
poetry run -C api ruff format --check ./
poetry run -C api ruff check ./api
poetry run -C api ruff format --check ./api
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -P api dotenv-linter ./api/.env.example ./web/.env.example
run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
- name: Lint hints
if: failure()
@@ -82,33 +82,6 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: yarn run lint
docker-compose-template:
name: Docker Compose Template
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v45
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- name: Generate Docker Compose
if: steps.changed-files.outputs.any_changed == 'true'
run: |
cd docker
./generate_docker_compose
- name: Check for changes
if: steps.changed-files.outputs.any_changed == 'true'
run: git diff --exit-code
superlinter:
name: SuperLinter


@@ -70,4 +70,4 @@ jobs:
tidb
- name: Test Vector Stores
run: poetry run -P api bash dev/pytest/pytest_vdb.sh
run: poetry run -C api bash dev/pytest/pytest_vdb.sh


@@ -53,12 +53,10 @@ ignore = [
"FURB152", # math-constant
"UP007", # non-pep604-annotation
"UP032", # f-string
"UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
"B903", # class-as-data-structure
"B904", # raise-without-from-inside-except
"B905", # zip-without-explicit-strict
"N806", # non-lowercase-variable-in-function


@@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base
WORKDIR /app/api
# Install Poetry
ENV POETRY_VERSION=2.0.1
ENV POETRY_VERSION=1.8.4
# if you located in China, you can use aliyun mirror to speed up
# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/


@@ -79,5 +79,5 @@
2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
```bash
poetry run -P api bash dev/pytest/pytest_all_tests.sh
poetry run -C api bash dev/pytest/pytest_all_tests.sh
```


@@ -146,7 +146,7 @@ class EndpointConfig(BaseSettings):
)
CONSOLE_WEB_URL: str = Field(
description="Base URL for the console web interface,used for frontend references and CORS configuration",
description="Base URL for the console web interface," "used for frontend references and CORS configuration",
default="",
)


@@ -181,7 +181,7 @@ class HostedFetchAppTemplateConfig(BaseSettings):
"""
HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
description="Mode for fetching app templates: remote, db, or builtin default to remote,",
description="Mode for fetching app templates: remote, db, or builtin" " default to remote,",
default="remote",
)


@@ -33,9 +33,3 @@ class MilvusConfig(BaseSettings):
description="Name of the Milvus database to connect to (default is 'default')",
default="default",
)
MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with "
"older versions",
default=True,
)


@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.15.2",
default="0.14.2",
)
COMMIT_SHA: str = Field(


@@ -56,7 +56,7 @@ class InsertExploreAppListApi(Resource):
app = App.query.filter(App.id == args["app_id"]).first()
if not app:
raise NotFound(f"App '{args['app_id']}' is not found")
raise NotFound(f'App \'{args["app_id"]}\' is not found')
site = app.site
if not site:


@@ -22,7 +22,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from models import App, AppMode
from models.model import AppMode
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -79,7 +79,7 @@ class ChatMessageTextApi(Resource):
@login_required
@account_initialization_required
@get_app_model
def post(self, app_model: App):
def post(self, app_model):
from werkzeug.exceptions import InternalServerError
try:
@@ -98,13 +98,9 @@ class ChatMessageTextApi(Resource):
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
if text_to_speech is None:
raise ValueError("TTS is not enabled")
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
if app_model.app_model_config is None:
raise ValueError("AppModelConfig not found")
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None


@@ -52,12 +52,12 @@ class DatasetListApi(Resource):
# provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true"
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(
page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
page, limit, current_user.current_tenant_id, current_user, search, tag_ids
)
# check embedding setting
@@ -457,7 +457,7 @@ class DatasetIndexingEstimateApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -619,7 +619,8 @@ class DatasetRetrievalSettingApi(Resource):
vector_type = dify_config.VECTOR_STORE
match vector_type:
case (
VectorType.RELYT
VectorType.MILVUS
| VectorType.RELYT
| VectorType.PGVECTOR
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
@@ -639,12 +640,10 @@
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.PGVECTOR
| VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM
| VectorType.COUCHBASE
| VectorType.MILVUS
):
return {
"retrieval_method": [
@@ -684,7 +683,6 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.LINDORM


@@ -257,8 +257,7 @@ class DatasetDocumentListApi(Resource):
parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
@@ -350,7 +349,8 @@ class DatasetInitApi(Resource):
)
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -525,7 +525,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
return response.model_dump(), 200
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)


@@ -168,7 +168,8 @@ class DatasetDocumentSegmentApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -216,7 +217,8 @@ class DatasetDocumentSegmentAddApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -265,7 +267,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -365,9 +368,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
result = []
for index, row in df.iterrows():
if document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
data = {"content": row[0], "answer": row[1]}
else:
data = {"content": row.iloc[0]}
data = {"content": row[0]}
result.append(data)
if len(result) == 0:
raise ValueError("The CSV file is empty.")
@@ -434,7 +437,8 @@ class ChildChunkAddApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)


@@ -32,7 +32,7 @@ class ConversationListApi(InstalledAppResource):
pinned = None
if "pinned" in args and args["pinned"] is not None:
pinned = args["pinned"] == "true"
pinned = True if args["pinned"] == "true" else False
try:
with Session(db.engine) as session:


@@ -7,4 +7,4 @@ api = ExternalApi(bp)
from . import index
from .app import app, audio, completion, conversation, file, message, workflow
from .dataset import dataset, document, hit_testing, segment, upload_file
from .dataset import dataset, document, hit_testing, segment


@@ -31,11 +31,8 @@ class DatasetListApi(DatasetApiResource):
# provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true"
datasets, total = DatasetService.get_datasets(
page, limit, tenant_id, current_user, search, tag_ids, include_all
)
datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids)
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)


@@ -53,7 +53,8 @@ class SegmentApi(DatasetApiResource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -94,7 +95,8 @@ class SegmentApi(DatasetApiResource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -173,7 +175,8 @@ class DatasetSegmentApi(DatasetApiResource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)


@@ -1,54 +0,0 @@
from werkzeug.exceptions import NotFound
from controllers.service_api import api
from controllers.service_api.wraps import (
DatasetApiResource,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import UploadFile
from services.dataset_service import DocumentService
class UploadFileApi(DatasetApiResource):
def get(self, tenant_id, dataset_id, document_id):
"""Get upload file."""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
# check upload file
if document.data_source_type != "upload_file":
raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("UploadFile not found.")
else:
raise ValueError("Upload file id not found in document data source info.")
url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"url": url,
"download_url": f"{url}&as_attachment=true",
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at.timestamp(),
}, 200
api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")


@@ -1,5 +1,5 @@
from collections.abc import Callable
from datetime import UTC, datetime, timedelta
from datetime import UTC, datetime
from enum import Enum
from functools import wraps
from typing import Optional
@@ -8,8 +8,6 @@ from flask import current_app, request
from flask_login import user_logged_in # type: ignore
from flask_restful import Resource # type: ignore
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized
from extensions.ext_database import db
@@ -176,7 +174,7 @@ def validate_dataset_token(view=None):
return decorator
def validate_and_get_api_token(scope: str | None = None):
def validate_and_get_api_token(scope=None):
"""
Validate and get API token.
"""
@@ -190,29 +188,20 @@ def validate_and_get_api_token(scope: str | None = None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")
current_time = datetime.now(UTC).replace(tzinfo=None)
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
update(ApiToken)
.where(
ApiToken.token == auth_token,
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
ApiToken.type == scope,
)
.values(last_used_at=current_time)
.returning(ApiToken)
api_token = (
db.session.query(ApiToken)
.filter(
ApiToken.token == auth_token,
ApiToken.type == scope,
)
result = session.execute(update_stmt)
api_token = result.scalar_one_or_none()
.first()
)
if not api_token:
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
api_token = session.scalar(stmt)
if not api_token:
raise Unauthorized("Access token is invalid")
else:
session.commit()
if not api_token:
raise Unauthorized("Access token is invalid")
api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
return api_token
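A condensed sketch of the two strategies in the hunk above (SQLAlchemy 2.x style; `ApiToken` as imported in this module; names otherwise hypothetical). The removed variant validates and touches the token in one statement, throttling `last_used_at` writes to once per minute; the restored variant is a plain read-then-write:

```python
from datetime import UTC, datetime, timedelta

from sqlalchemy import update
from sqlalchemy.orm import Session


def touch_token_atomic(session: Session, auth_token: str, scope: str):
    # Removed variant: a single UPDATE ... RETURNING round trip, so no
    # read-modify-write race; last_used_at is refreshed at most once a minute.
    # (The real code falls back to a SELECT when the UPDATE matches nothing,
    # i.e. the token is valid but was already touched within the last minute.)
    now = datetime.now(UTC).replace(tzinfo=None)
    cutoff = now - timedelta(minutes=1)
    stmt = (
        update(ApiToken)
        .where(
            ApiToken.token == auth_token,
            ApiToken.type == scope,
            ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff),
        )
        .values(last_used_at=now)
        .returning(ApiToken)
    )
    return session.execute(stmt).scalar_one_or_none()


def touch_token_read_then_write(session: Session, auth_token: str, scope: str):
    # Restored variant: SELECT then mutate. Concurrent requests may each
    # issue a write, but the control flow is simpler.
    token = session.query(ApiToken).filter(ApiToken.token == auth_token, ApiToken.type == scope).first()
    if token:
        token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
        session.commit()
    return token
```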
@@ -240,7 +229,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == "DEFAULT-USER",
is_anonymous=True if user_id == "DEFAULT-USER" else False,
session_id=user_id,
)
db.session.add(end_user)


@@ -39,7 +39,7 @@ class ConversationListApi(WebApiResource):
pinned = None
if "pinned" in args and args["pinned"] is not None:
pinned = args["pinned"] == "true"
pinned = True if args["pinned"] == "true" else False
try:
with Session(db.engine) as session:


@@ -172,7 +172,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
tool_name=scratchpad.action.action_name if scratchpad.action else "",
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
tool_invoke_meta={},
thought=scratchpad.thought or "",


@@ -167,7 +167,8 @@ class AppQueueManager:
else:
if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
"Critical Error: Passing SQLAlchemy Model instances "
"that cause thread safety issues is not allowed."
)


@@ -89,7 +89,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
Conversation.status == "normal",
Conversation.is_deleted.is_(False),
]
if isinstance(user, Account):


@@ -145,7 +145,7 @@ class MessageCycleManage:
# get extension
if "." in message_file.url:
extension = f".{message_file.url.split('.')[-1]}"
extension = f'.{message_file.url.split(".")[-1]}'
if len(extension) > 10:
extension = ".bin"
else:


@@ -62,9 +62,8 @@ class ApiExternalDataTool(ExternalDataTool):
if not api_based_extension:
raise ValueError(
"[External data tool] API query failed, variable: {}, error: api_based_extension_id is invalid".format(
self.variable
)
"[External data tool] API query failed, variable: {}, "
"error: api_based_extension_id is invalid".format(self.variable)
)
# decrypt api_key


@@ -90,7 +90,7 @@ class File(BaseModel):
def markdown(self) -> str:
url = self.generate_url()
if self.type == FileType.IMAGE:
text = f"![{self.filename or ''}]({url})"
text = f'![{self.filename or ""}]({url})'
else:
text = f"[{self.filename or url}]({url})"


@@ -530,6 +530,7 @@ class IndexingRunner:
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
tokens = 0
chunk_size = 10
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
# create keyword index
create_keyword_thread = threading.Thread(
@@ -538,22 +539,11 @@
)
create_keyword_thread.start()
max_workers = 10
if dataset.indexing_technique == "high_quality":
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = []
# Distribute documents into multiple groups based on the hash values of page_content
# This is done to prevent multiple threads from processing the same document,
# Thereby avoiding potential database insertion deadlocks
document_groups: list[list[Document]] = [[] for _ in range(max_workers)]
for document in documents:
hash = helper.generate_text_hash(document.page_content)
group_index = int(hash, 16) % max_workers
document_groups[group_index].append(document)
for chunk_documents in document_groups:
if len(chunk_documents) == 0:
continue
for i in range(0, len(documents), chunk_size):
chunk_documents = documents[i : i + chunk_size]
futures.append(
executor.submit(
self._process_chunk,
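The left side of this hunk bucketed documents by content hash, so documents with identical `page_content` always land with the same worker and concurrent threads never insert the same chunk — avoiding the database insertion deadlocks the removed comment mentions; the restored code simply slices `documents` into fixed-size chunks. A standalone sketch of the removed grouping, substituting hashlib for the project's `helper.generate_text_hash`:

```python
import hashlib


def group_by_content_hash(documents: list, max_workers: int = 10) -> list[list]:
    # One bucket per worker; the same page_content always hashes into the
    # same bucket, so no two threads process identical chunks concurrently.
    groups: list[list] = [[] for _ in range(max_workers)]
    for doc in documents:
        digest = hashlib.sha256(doc.page_content.encode("utf-8")).hexdigest()
        groups[int(digest, 16) % max_workers].append(doc)
    return [group for group in groups if group]  # drop empty buckets
```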


@@ -131,7 +131,7 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, "
"and keeping each question under 20 characters.\n"
"MAKE SURE your output is the SAME language as the Assistant's latest response. "
"MAKE SURE your output is the SAME language as the Assistant's latest response"
"The output must be an array in JSON format following the specified schema:\n"
'["question1","question2","question3"]\n'
)


@@ -1,11 +1,13 @@
import logging
from concurrent.futures import ProcessPoolExecutor
from os.path import abspath, dirname, join
from threading import Lock
from typing import Any
from typing import Any, cast
logger = logging.getLogger(__name__)
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
_tokenizer: Any = None
_lock = Lock()
_executor = ProcessPoolExecutor(max_workers=1)
class GPT2Tokenizer:
@@ -15,37 +17,22 @@ class GPT2Tokenizer:
use gpt2 tokenizer to get num tokens
"""
_tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text)
tokens = _tokenizer.encode(text, verbose=False)
return len(tokens)
@staticmethod
def get_num_tokens(text: str) -> int:
# Because this process needs more cpu resource, we turn this back before we find a better way to handle it.
#
# future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
# result = future.result()
# return cast(int, result)
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
result = future.result()
return cast(int, result)
@staticmethod
def get_encoder() -> Any:
global _tokenizer, _lock
with _lock:
if _tokenizer is None:
# Try to use tiktoken to get the tokenizer because it is faster
#
try:
import tiktoken
_tokenizer = tiktoken.get_encoding("gpt2")
except Exception:
from os.path import abspath, dirname, join
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
logger.info("Fallback to Transformers' GPT-2 tokenizer from tiktoken")
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
return _tokenizer
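Two things change in this hunk: the removed branch preferred tiktoken's fast BPE with a transformers fallback and called the encoder on the current thread, while the restored branch always loads the bundled transformers tokenizer and submits encoding to a single-worker ProcessPoolExecutor to keep the CPU-bound work off the calling thread. A minimal sketch of the removed loader, assuming tiktoken is an optional dependency:

```python
def load_gpt2_encoder():
    try:
        import tiktoken

        return tiktoken.get_encoding("gpt2")  # fast path
    except Exception:
        from os.path import abspath, dirname, join

        from transformers import GPT2Tokenizer  # type: ignore

        # Fall back to the GPT-2 vocabulary files shipped next to this module
        return GPT2Tokenizer.from_pretrained(join(dirname(abspath(__file__)), "gpt2"))
```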


@@ -108,7 +108,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not ai_model_entity:
raise CredentialsValidateFailedError(f"Base Model Name {credentials['base_model_name']} is invalid")
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
try:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))


@@ -130,7 +130,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
raise CredentialsValidateFailedError("Base Model Name is required")
if not self._get_ai_model_entity(credentials["base_model_name"], model):
raise CredentialsValidateFailedError(f"Base Model Name {credentials['base_model_name']} is invalid")
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
try:
credentials_kwargs = self._to_credential_kwargs(credentials)


@@ -70,7 +70,7 @@ class BedrockRerankModel(RerankModel):
rerankingConfiguration = {
"type": "BEDROCK_RERANKING_MODEL",
"bedrockRerankingConfiguration": {
"numberOfResults": min(top_n, len(text_sources)),
"numberOfResults": top_n,
"modelConfiguration": {
"modelArn": model_package_arn,
},


@@ -1,3 +1,2 @@
- deepseek-chat
- deepseek-coder
- deepseek-reasoner


@@ -10,7 +10,7 @@ features:
- stream-tool-call
model_properties:
mode: chat
context_size: 64000
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature


@@ -10,7 +10,7 @@ features:
- stream-tool-call
model_properties:
mode: chat
context_size: 64000
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature


@@ -1,21 +0,0 @@
model: deepseek-reasoner
label:
zh_Hans: deepseek-reasoner
en_US: deepseek-reasoner
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 64000
parameter_rules:
- name: max_tokens
use_template: max_tokens
min: 1
max: 8192
default: 4096
pricing:
input: "4"
output: "16"
unit: "0.000001"
currency: RMB


@@ -24,6 +24,9 @@ class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
# {"response_format": "xx"} need convert to {"response_format": {"type": "xx"}}
if "response_format" in model_parameters:
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
def validate_credentials(self, model: str, credentials: dict) -> None:
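The added conversion is a small adapter for the OpenAI-compatible base class, which expects `response_format` as a nested object rather than the flat string stored in the model parameters (the Moonshot and Siliconflow hunks below apply the same rewrite). A worked example of the transformation:

```python
model_parameters = {"response_format": "json_object", "temperature": 0.7}
if "response_format" in model_parameters:
    model_parameters["response_format"] = {"type": model_parameters["response_format"]}
assert model_parameters == {"response_format": {"type": "json_object"}, "temperature": 0.7}
```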


@@ -1,6 +1,5 @@
- gemini-2.0-flash-exp
- gemini-2.0-flash-thinking-exp-1219
- gemini-2.0-flash-thinking-exp-01-21
- gemini-1.5-pro
- gemini-1.5-pro-latest
- gemini-1.5-pro-001


@@ -1,39 +0,0 @@
model: gemini-2.0-flash-thinking-exp-01-21
label:
en_US: Gemini 2.0 Flash Thinking Exp 01-21
model_type: llm
features:
- agent-thought
- vision
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD


@@ -162,9 +162,9 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
@staticmethod
def _check_endpoint_url_model_repository_name(credentials: dict, model_name: str):
try:
url = f"{HUGGINGFACE_ENDPOINT_API}{credentials['huggingface_namespace']}"
url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
headers = {
"Authorization": f"Bearer {credentials['huggingfacehub_api_token']}",
"Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}',
"Content-Type": "application/json",
}


@@ -34,7 +34,6 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage
class MinimaxLargeLanguageModel(LargeLanguageModel):
model_apis = {
"minimax-text-01": MinimaxChatCompletionPro,
"abab7-chat-preview": MinimaxChatCompletionPro,
"abab6.5t-chat": MinimaxChatCompletionPro,
"abab6.5s-chat": MinimaxChatCompletionPro,


@@ -1,46 +0,0 @@
model: minimax-text-01
label:
en_US: Minimax-Text-01
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 1000192
parameter_rules:
- name: temperature
use_template: temperature
min: 0.01
max: 1
default: 0.1
- name: top_p
use_template: top_p
min: 0.01
max: 1
default: 0.95
- name: max_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 1000192
- name: mask_sensitive_info
type: boolean
default: true
label:
zh_Hans: 隐私保护
en_US: Moderate
help:
zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码目前包括但不限于邮箱、域名、链接、证件号、家庭住址等默认true即开启打码
en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/id..
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
pricing:
input: '0.001'
output: '0.008'
unit: '0.001'
currency: RMB


@@ -44,6 +44,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
self._add_custom_parameters(credentials)
self._add_function_call(model, credentials)
user = user[:32] if user else None
# {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
if "response_format" in model_parameters:
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:


@@ -1,6 +1,5 @@
import json
import logging
import re
from collections.abc import Generator
from typing import Any, Optional, Union, cast
@@ -622,19 +621,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
# o1 compatibility
block_as_stream = False
if model.startswith("o1"):
if "max_tokens" in model_parameters:
model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
del model_parameters["max_tokens"]
if re.match(r"^o1(-\d{4}-\d{2}-\d{2})?$", model):
if stream:
block_as_stream = True
stream = False
if "stream_options" in extra_model_kwargs:
del extra_model_kwargs["stream_options"]
if "stop" in extra_model_kwargs:
del extra_model_kwargs["stop"]
@@ -651,45 +642,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
if block_as_stream:
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
return block_result
def _handle_chat_block_as_stream_response(
self,
block_result: LLMResult,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
) -> Generator[LLMResultChunk, None, None]:
"""
Handle llm chat response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return: llm response chunk generator
"""
text = block_result.message.content
text = cast(str, text)
if stop:
text = self.enforce_stop_tokens(text, stop)
yield LLMResultChunk(
model=block_result.model,
prompt_messages=prompt_messages,
system_fingerprint=block_result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=block_result.message,
finish_reason="stop",
usage=block_result.usage,
),
)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
def _handle_chat_generate_response(
self,


@@ -7,7 +7,6 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 200000


@@ -7,8 +7,6 @@
- Qwen/Qwen2.5-Coder-7B-Instruct
- Qwen/Qwen2-VL-72B-Instruct
- Qwen/Qwen2-1.5B-Instruct
- Qwen/Qwen2.5-72B-Instruct-128K
- Vendor-A/Qwen/Qwen2.5-72B-Instruct
- Pro/Qwen/Qwen2-VL-7B-Instruct
- OpenGVLab/InternVL2-26B
- Pro/OpenGVLab/InternVL2-8B


@@ -29,6 +29,9 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel):
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
# {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
if "response_format" in model_parameters:
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
def validate_credentials(self, model: str, credentials: dict) -> None:


@@ -1,51 +0,0 @@
model: Qwen/Qwen2.5-72B-Instruct-128K
label:
en_US: Qwen/Qwen2.5-72B-Instruct-128K
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
- name: max_tokens
use_template: max_tokens
type: int
default: 512
min: 1
max: 4096
help:
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '4.13'
output: '4.13'
unit: '0.000001'
currency: RMB


@@ -1,51 +0,0 @@
model: Vendor-A/Qwen/Qwen2.5-72B-Instruct
label:
en_US: Vendor-A/Qwen/Qwen2.5-72B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
- name: max_tokens
use_template: max_tokens
type: int
default: 512
min: 1
max: 4096
help:
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '1.00'
output: '1.00'
unit: '0.000001'
currency: RMB


@@ -15,7 +15,7 @@ parameter_rules:
type: int
default: 512
min: 1
max: 4096
max: 8192
help:
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.


@@ -1,37 +0,0 @@
model: fishaudio/fish-speech-1.5
model_type: tts
model_properties:
default_voice: 'fishaudio/fish-speech-1.5:alex'
voices:
- mode: "fishaudio/fish-speech-1.5:alex"
name: "Alex男声"
language: [ "zh-Hans", "en-US" ]
- mode: "fishaudio/fish-speech-1.5:benjamin"
name: "Benjamin男声"
language: [ "zh-Hans", "en-US" ]
- mode: "fishaudio/fish-speech-1.5:charles"
name: "Charles男声"
language: [ "zh-Hans", "en-US" ]
- mode: "fishaudio/fish-speech-1.5:david"
name: "David男声"
language: [ "zh-Hans", "en-US" ]
- mode: "fishaudio/fish-speech-1.5:anna"
name: "Anna女声"
language: [ "zh-Hans", "en-US" ]
- mode: "fishaudio/fish-speech-1.5:bella"
name: "Bella女声"
language: [ "zh-Hans", "en-US" ]
- mode: "fishaudio/fish-speech-1.5:claire"
name: "Claire女声"
language: [ "zh-Hans", "en-US" ]
- mode: "fishaudio/fish-speech-1.5:diana"
name: "Diana女声"
language: [ "zh-Hans", "en-US" ]
audio_type: 'mp3'
max_workers: 5
# stream: false
pricing:
input: '0.015'
output: '0'
unit: '0.001'
currency: RMB


@@ -21,7 +21,7 @@ class SparkLLMClient:
domain = api_domain
model_api_configs = {
"spark-lite": {"version": "v1.1", "chat_domain": "lite"},
"spark-lite": {"version": "v1.1", "chat_domain": "general"},
"spark-pro": {"version": "v3.1", "chat_domain": "generalv3"},
"spark-pro-128k": {"version": "pro-128k", "chat_domain": "pro-128k"},
"spark-max": {"version": "v3.5", "chat_domain": "generalv3.5"},


@@ -257,7 +257,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
for index, response in enumerate(responses):
if response.status_code not in {200, HTTPStatus.OK}:
raise ServiceUnavailableError(
f"Failed to invoke model {model}, status code: {response.status_code}, message: {response.message}"
f"Failed to invoke model {model}, status code: {response.status_code}, "
f"message: {response.message}"
)
resp_finish_reason = response.output.choices[0].finish_reason


@@ -146,7 +146,7 @@ class TritonInferenceAILargeLanguageModel(LargeLanguageModel):
elif credentials["completion_type"] == "completion":
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f"completion_type {credentials['completion_type']} is not supported")
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
entity = AIModelEntity(
model=model,


@@ -18,93 +18,72 @@ class ModelConfig(BaseModel):
configs: dict[str, ModelConfig] = {
"Doubao-1.5-vision-pro-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=12288, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.VISION],
),
"Doubao-1.5-pro-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=12288, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
),
"Doubao-1.5-lite-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=12288, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
),
"Doubao-1.5-pro-256k": ModelConfig(
properties=ModelProperties(context_size=262144, max_tokens=12288, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
),
"Doubao-vision-pro-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.VISION],
features=[ModelFeature.VISION],
),
"Doubao-vision-lite-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.VISION],
features=[ModelFeature.VISION],
),
"Doubao-pro-4k": ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
features=[ModelFeature.TOOL_CALL],
),
"Doubao-lite-4k": ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
features=[ModelFeature.TOOL_CALL],
),
"Doubao-pro-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
features=[ModelFeature.TOOL_CALL],
),
"Doubao-lite-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
features=[ModelFeature.TOOL_CALL],
),
"Doubao-pro-256k": ModelConfig(
properties=ModelProperties(context_size=262144, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
features=[],
),
"Doubao-pro-128k": ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
features=[ModelFeature.TOOL_CALL],
),
"Doubao-lite-128k": ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), features=[]
),
"Skylark2-pro-4k": ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), features=[]
),
"Llama3-8B": ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[]
),
"Llama3-70B": ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[]
),
"Moonshot-v1-8k": ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
features=[ModelFeature.TOOL_CALL],
),
"Moonshot-v1-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
features=[ModelFeature.TOOL_CALL],
),
"Moonshot-v1-128k": ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
features=[ModelFeature.TOOL_CALL],
),
"GLM3-130B": ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
features=[ModelFeature.TOOL_CALL],
),
"GLM3-130B-Fin": ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
features=[ModelFeature.TOOL_CALL],
),
"Mistral-7B": ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT),
features=[ModelFeature.AGENT_THOUGHT],
properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT), features=[]
),
}


@@ -118,30 +118,6 @@ model_credential_schema:
type: select
required: true
options:
- label:
en_US: Doubao-1.5-vision-pro-32k
value: Doubao-1.5-vision-pro-32k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-1.5-pro-32k
value: Doubao-1.5-pro-32k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-1.5-lite-32k
value: Doubao-1.5-lite-32k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-1.5-pro-256k
value: Doubao-1.5-pro-256k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-vision-pro-32k
value: Doubao-vision-pro-32k


@@ -41,15 +41,15 @@ class BaiduAccessToken:
resp = response.json()
if "error" in resp:
if resp["error"] == "invalid_client":
raise InvalidAPIKeyError(f"Invalid API key or secret key: {resp['error_description']}")
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
elif resp["error"] == "unknown_error":
raise InternalServerError(f"Internal server error: {resp['error_description']}")
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
elif resp["error"] == "invalid_request":
raise BadRequestError(f"Bad request: {resp['error_description']}")
raise BadRequestError(f'Bad request: {resp["error_description"]}')
elif resp["error"] == "rate_limit_exceeded":
raise RateLimitReachedError(f"Rate limit reached: {resp['error_description']}")
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
else:
raise Exception(f"Unknown error: {resp['error_description']}")
raise Exception(f'Unknown error: {resp["error_description"]}')
return resp["access_token"]


@@ -406,7 +406,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
elif credentials["completion_type"] == "completion":
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f"completion_type {credentials['completion_type']} is not supported")
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
else:
extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials["server_url"],
@@ -472,7 +472,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
api_key = credentials.get("api_key") or "abc"
client = OpenAI(
base_url=f"{credentials['server_url']}/v1",
base_url=f'{credentials["server_url"]}/v1',
api_key=api_key,
max_retries=int(credentials.get("max_retries") or DEFAULT_MAX_RETRIES),
timeout=int(credentials.get("invoke_timeout") or DEFAULT_INVOKE_TIMEOUT),


@@ -87,6 +87,6 @@ class CommonValidator:
if value.lower() not in {"true", "false"}:
raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")
value = value.lower() == "true"
value = True if value.lower() == "true" else False
return value


@@ -6,7 +6,6 @@ from pydantic import BaseModel, ValidationInfo, field_validator
class TracingProviderEnum(Enum):
LANGFUSE = "langfuse"
LANGSMITH = "langsmith"
OPIK = "opik"
class BaseTracingConfig(BaseModel):
@@ -57,36 +56,5 @@ class LangSmithConfig(BaseTracingConfig):
return v
class OpikConfig(BaseTracingConfig):
"""
Model class for Opik tracing config.
"""
api_key: str | None = None
project: str | None = None
workspace: str | None = None
url: str = "https://www.comet.com/opik/api/"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "Default Project"
return v
@field_validator("url")
@classmethod
def url_validator(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://www.comet.com/opik/api/"
if not v.startswith(("https://", "http://")):
raise ValueError("url must start with https:// or http://")
if not v.endswith("/api/"):
raise ValueError("url should ends with /api/")
return v
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"


@@ -1,469 +0,0 @@
import json
import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import Optional, cast
from opik import Opik, Trace
from opik.id_helpers import uuid4_to_uuid7
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__)
def wrap_dict(key_name, data):
"""Make sure that the input data is a dict"""
if not isinstance(data, dict):
return {key_name: data}
return data
def wrap_metadata(metadata, **kwargs):
"""Add common metatada to all Traces and Spans"""
metadata["created_from"] = "dify"
metadata.update(kwargs)
return metadata
def prepare_opik_uuid(user_datetime: Optional[datetime], user_uuid: Optional[str]):
"""Opik needs UUIDv7 while Dify uses UUIDv4 for identifier of most
messages and objects. The type-hints of BaseTraceInfo indicates that
objects start_time and message_id could be null which means we cannot map
it to a UUIDv7. Given that we have no way to identify that object
uniquely, generate a new random one UUIDv7 in that case.
"""
if user_datetime is None:
user_datetime = datetime.now()
if user_uuid is None:
user_uuid = str(uuid.uuid4())
return uuid4_to_uuid7(user_datetime, user_uuid)
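A usage sketch of the helper above (values hypothetical, and assuming, per the docstring's intent, that `uuid4_to_uuid7` maps a given timestamp/UUIDv4 pair deterministically): when both inputs exist, a Dify object keeps a stable Opik identity; only when an input is missing does the id become random.

```python
import uuid
from datetime import datetime

dify_id = str(uuid.uuid4())
created_at = datetime.now()

opik_id = prepare_opik_uuid(created_at, dify_id)  # stable mapping for this object
fallback_id = prepare_opik_uuid(None, None)       # random: nothing to map from
```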
class OpikDataTrace(BaseTraceInstance):
def __init__(
self,
opik_config: OpikConfig,
):
super().__init__(opik_config)
self.opik_client = Opik(
project_name=opik_config.project,
workspace=opik_config.workspace,
host=opik_config.url,
api_key=opik_config.api_key,
)
self.project = opik_config.project
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
def trace(self, trace_info: BaseTraceInfo):
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
dify_trace_id = trace_info.workflow_run_id
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
workflow_metadata = wrap_metadata(
trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
)
root_span_id = None
if trace_info.message_id:
dify_trace_id = trace_info.message_id
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
trace_data = {
"id": opik_trace_id,
"name": TraceTaskName.MESSAGE_TRACE.value,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": workflow_metadata,
"input": wrap_dict("input", trace_info.workflow_run_inputs),
"output": wrap_dict("output", trace_info.workflow_run_outputs),
"tags": ["message", "workflow"],
"project_name": self.project,
}
self.add_trace(trace_data)
root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
span_data = {
"id": root_span_id,
"parent_span_id": None,
"trace_id": opik_trace_id,
"name": TraceTaskName.WORKFLOW_TRACE.value,
"input": wrap_dict("input", trace_info.workflow_run_inputs),
"output": wrap_dict("output", trace_info.workflow_run_outputs),
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": workflow_metadata,
"tags": ["workflow"],
"project_name": self.project,
}
self.add_span(span_data)
else:
trace_data = {
"id": opik_trace_id,
"name": TraceTaskName.MESSAGE_TRACE.value,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": workflow_metadata,
"input": wrap_dict("input", trace_info.workflow_run_inputs),
"output": wrap_dict("output", trace_info.workflow_run_outputs),
"tags": ["workflow"],
"project_name": self.project,
}
self.add_trace(trace_data)
# through workflow_run_id get all_nodes_execution
workflow_nodes_execution_id_records = (
db.session.query(WorkflowNodeExecution.id)
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
.all()
)
for node_execution_id_record in workflow_nodes_execution_id_records:
node_execution = (
db.session.query(
WorkflowNodeExecution.id,
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
)
if not node_execution:
continue
node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id
app_id = node_execution.app_id
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == "llm":
inputs = (
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
)
else:
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = (
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
)
metadata = execution_metadata.copy()
metadata.update(
{
"workflow_run_id": trace_info.workflow_run_id,
"node_execution_id": node_execution_id,
"tenant_id": tenant_id,
"app_id": app_id,
"app_name": node_name,
"node_type": node_type,
"status": status,
}
)
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
provider = None
model = None
total_tokens = 0
completion_tokens = 0
prompt_tokens = 0
if process_data and process_data.get("model_mode") == "chat":
run_type = "llm"
provider = process_data.get("model_provider", None)
model = process_data.get("model_name", "")
metadata.update(
{
"ls_provider": provider,
"ls_model_name": model,
}
)
try:
if outputs.get("usage"):
total_tokens = outputs["usage"].get("total_tokens", 0)
prompt_tokens = outputs["usage"].get("prompt_tokens", 0)
completion_tokens = outputs["usage"].get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
else:
run_type = "tool"
parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
if not total_tokens:
total_tokens = execution_metadata.get("total_tokens", 0)
span_data = {
"trace_id": opik_trace_id,
"id": prepare_opik_uuid(created_at, node_execution_id),
"parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id),
"name": node_type,
"type": run_type,
"start_time": created_at,
"end_time": finished_at,
"metadata": wrap_metadata(metadata),
"input": wrap_dict("input", inputs),
"output": wrap_dict("output", outputs),
"tags": ["node_execution"],
"project_name": self.project,
"usage": {
"total_tokens": total_tokens,
"completion_tokens": completion_tokens,
"prompt_tokens": prompt_tokens,
},
"model": model,
"provider": provider,
}
self.add_span(span_data)
def message_trace(self, trace_info: MessageTraceInfo):
# get message file data
file_list = cast(list[str], trace_info.file_list) or []
message_file_data: Optional[MessageFile] = trace_info.message_file_data
if message_file_data is not None:
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
file_list.append(file_url)
message_data = trace_info.message_data
if message_data is None:
return
metadata = trace_info.metadata
message_id = trace_info.message_id
user_id = message_data.from_account_id
metadata["user_id"] = user_id
metadata["file_list"] = file_list
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id
metadata["end_user_id"] = end_user_id
trace_data = {
"id": prepare_opik_uuid(trace_info.start_time, message_id),
"name": TraceTaskName.MESSAGE_TRACE.value,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(metadata),
"input": trace_info.inputs,
"output": message_data.answer,
"tags": ["message", str(trace_info.conversation_mode)],
"project_name": self.project,
}
trace = self.add_trace(trace_data)
span_data = {
"trace_id": trace.id,
"name": "llm",
"type": "llm",
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(metadata),
"input": {"input": trace_info.inputs},
"output": {"output": message_data.answer},
"tags": ["llm", str(trace_info.conversation_mode)],
"usage": {
"completion_tokens": trace_info.answer_tokens,
"prompt_tokens": trace_info.message_tokens,
"total_tokens": trace_info.total_tokens,
},
"project_name": self.project,
}
self.add_span(span_data)
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
"name": TraceTaskName.MODERATION_TRACE.value,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
"metadata": wrap_metadata(trace_info.metadata),
"input": wrap_dict("input", trace_info.inputs),
"output": {
"action": trace_info.action,
"flagged": trace_info.flagged,
"preset_response": trace_info.preset_response,
"inputs": trace_info.inputs,
},
"tags": ["moderation"],
}
self.add_span(span_data)
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_data = trace_info.message_data
if message_data is None:
return
start_time = trace_info.start_time or message_data.created_at
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
"name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or message_data.updated_at,
"metadata": wrap_metadata(trace_info.metadata),
"input": wrap_dict("input", trace_info.inputs),
"output": wrap_dict("output", trace_info.suggested_question),
"tags": ["suggested_question"],
}
self.add_span(span_data)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
"name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
"metadata": wrap_metadata(trace_info.metadata),
"input": wrap_dict("input", trace_info.inputs),
"output": {"documents": trace_info.documents},
"tags": ["dataset_retrieval"],
}
self.add_span(span_data)
def tool_trace(self, trace_info: ToolTraceInfo):
span_data = {
"trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
"name": trace_info.tool_name,
"type": "tool",
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(trace_info.metadata),
"input": wrap_dict("input", trace_info.tool_inputs),
"output": wrap_dict("output", trace_info.tool_outputs),
"tags": ["tool", trace_info.tool_name],
}
self.add_span(span_data)
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
trace_data = {
"id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(trace_info.metadata),
"input": trace_info.inputs,
"output": trace_info.outputs,
"tags": ["generate_name"],
"project_name": self.project,
}
trace = self.add_trace(trace_data)
span_data = {
"trace_id": trace.id,
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(trace_info.metadata),
"input": wrap_dict("input", trace_info.inputs),
"output": wrap_dict("output", trace_info.outputs),
"tags": ["generate_name"],
}
self.add_span(span_data)
def add_trace(self, opik_trace_data: dict) -> Trace:
try:
trace = self.opik_client.trace(**opik_trace_data)
logger.debug("Opik Trace created successfully")
return trace
except Exception as e:
raise ValueError(f"Opik Failed to create trace: {str(e)}")
def add_span(self, opik_span_data: dict):
try:
self.opik_client.span(**opik_span_data)
logger.debug("Opik Span created successfully")
except Exception as e:
raise ValueError(f"Opik Failed to create span: {str(e)}")
def api_check(self):
try:
self.opik_client.auth_check()
return True
except Exception as e:
logger.info(f"Opik API check failed: {str(e)}", exc_info=True)
raise ValueError(f"Opik API check failed: {str(e)}")
def get_project_url(self):
try:
return self.opik_client.get_project_url(project_name=self.project)
except Exception as e:
logger.info(f"Opik get run url failed: {str(e)}", exc_info=True)
raise ValueError(f"Opik get run url failed: {str(e)}")

View File

@ -17,7 +17,6 @@ from core.ops.entities.config_entity import (
OPS_FILE_PATH,
LangfuseConfig,
LangSmithConfig,
OpikConfig,
TracingProviderEnum,
)
from core.ops.entities.trace_entity import (
@ -33,7 +32,6 @@ from core.ops.entities.trace_entity import (
)
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from core.ops.opik_trace.opik_trace import OpikDataTrace
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
@ -54,12 +52,6 @@ provider_config_map: dict[str, dict[str, Any]] = {
"other_keys": ["project", "endpoint"],
"trace_instance": LangSmithDataTrace,
},
TracingProviderEnum.OPIK.value: {
"config_class": OpikConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "url", "workspace"],
"trace_instance": OpikDataTrace,
},
}
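The mapping above pairs each tracing provider with its config model, its secret fields, and the trace implementation to instantiate. A minimal sketch of how such an entry is typically consumed (build_trace_instance and encrypt are hypothetical helpers, not part of the source):

def build_trace_instance(provider: str, raw_config: dict):
    entry = provider_config_map[provider]
    config = entry["config_class"](**raw_config)  # pydantic validation of the config
    for key in entry["secret_keys"]:
        raw_config[key] = encrypt(raw_config[key])  # hypothetical encrypter for storage
    return entry["trace_instance"](config)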

View File

@ -22,12 +22,7 @@ from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.helper.position_helper import is_filtered
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
CredentialFormSchema,
FormType,
ProviderEntity,
)
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider
from extensions.ext_database import db
@ -840,18 +835,11 @@ class ProviderManager:
:return:
"""
# Get provider model credential secret variables
if ConfigurateMethod.PREDEFINED_MODEL in provider_entity.configurate_methods:
model_credential_secret_variables = self._extract_secret_variables(
provider_entity.provider_credential_schema.credential_form_schemas
if provider_entity.provider_credential_schema
else []
)
else:
model_credential_secret_variables = self._extract_secret_variables(
provider_entity.model_credential_schema.credential_form_schemas
if provider_entity.model_credential_schema
else []
)
model_credential_secret_variables = self._extract_secret_variables(
provider_entity.model_credential_schema.credential_form_schemas
if provider_entity.model_credential_schema
else []
)
model_settings: list[ModelSettings] = []
if not provider_model_settings:

View File

@ -1,104 +0,0 @@
import json
import logging
from typing import Any, Optional
from flask import current_app
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchConfig,
ElasticSearchVector,
ElasticSearchVectorFactory,
)
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class ElasticSearchJaVector(ElasticSearchVector):
def create_collection(
self,
embeddings: list[list[float]],
metadatas: Optional[list[dict[Any, Any]]] = None,
index_params: Optional[dict] = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return
if not self._client.indices.exists(index=self._collection_name):
dim = len(embeddings[0])
settings = {
"analysis": {
"analyzer": {
"ja_analyzer": {
"type": "custom",
"char_filter": [
"icu_normalizer",
"kuromoji_iteration_mark",
],
"tokenizer": "kuromoji_tokenizer",
"filter": [
"kuromoji_baseform",
"kuromoji_part_of_speech",
"ja_stop",
"kuromoji_number",
"kuromoji_stemmer",
],
}
}
}
}
mappings = {
"properties": {
Field.CONTENT_KEY.value: {
"type": "text",
"analyzer": "ja_analyzer",
"search_analyzer": "ja_analyzer",
},
Field.VECTOR.value: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"index": True,
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
},
},
}
}
self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
config = current_app.config
return ElasticSearchJaVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get("ELASTICSEARCH_HOST", "localhost"),
port=config.get("ELASTICSEARCH_PORT", 9200),
username=config.get("ELASTICSEARCH_USERNAME", ""),
password=config.get("ELASTICSEARCH_PASSWORD", ""),
),
attributes=[],
)
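A quick way to sanity-check the analyzer defined above is the _analyze API. A minimal sketch, assuming elasticsearch-py and a node with the analysis-kuromoji and analysis-icu plugins installed; host and index name are illustrative:

from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")
resp = es.indices.analyze(
    index="my_ja_collection",  # an index created with the settings above
    analyzer="ja_analyzer",
    text="関西国際空港",
)
print([t["token"] for t in resp["tokens"]])  # kuromoji segments the compound noun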

View File

@ -6,8 +6,6 @@ class Field(Enum):
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = "vector"
# Sparse Vector aims to support full text search
SPARSE_VECTOR = "sparse_vector"
TEXT_KEY = "text"
PRIMARY_KEY = "id"
DOC_ID = "metadata.doc_id"

View File

@ -258,7 +258,7 @@ class LindormVectorStore(BaseVector):
hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
ivfpq_m = kwargs.pop("ivfpq_m", dimension)
nlist = kwargs.pop("nlist", 1000)
centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", nlist >= 5000)
centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False)
centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
@ -305,7 +305,7 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dic
if method_name == "ivfpq":
ivfpq_m = kwargs["ivfpq_m"]
nlist = kwargs["nlist"]
centroids_use_hnsw = nlist > 10000
centroids_use_hnsw = True if nlist > 10000 else False
centroids_hnsw_m = 24
centroids_hnsw_ef_construct = 500
centroids_hnsw_ef_search = 100

View File

@ -2,7 +2,6 @@ import json
import logging
from typing import Any, Optional
from packaging import version
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException # type: ignore
from pymilvus.milvus_client import IndexParams # type: ignore
@ -21,25 +20,16 @@ logger = logging.getLogger(__name__)
class MilvusConfig(BaseModel):
"""
Configuration class for Milvus connection.
"""
uri: str # Milvus server URI
token: Optional[str] = None # Optional token for authentication
user: str # Username for authentication
password: str # Password for authentication
batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search
uri: str
token: Optional[str] = None
user: str
password: str
batch_size: int = 100
database: str = "default"
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""
Validate the configuration values.
Raises ValueError if required fields are missing.
"""
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get("user"):
@ -49,9 +39,6 @@ class MilvusConfig(BaseModel):
return values
def to_milvus_params(self):
"""
Convert the configuration to a dictionary of Milvus connection parameters.
"""
return {
"uri": self.uri,
"token": self.token,
@ -62,57 +49,26 @@ class MilvusConfig(BaseModel):
class MilvusVector(BaseVector):
"""
Milvus vector storage implementation.
"""
def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = "Session" # Consistency level for Milvus operations
self._fields: list[str] = [] # List of fields in the collection
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
def _check_hybrid_search_support(self) -> bool:
"""
Check if the current Milvus version supports hybrid search.
Returns True if the version is >= 2.5.0, otherwise False.
"""
if not self._client_config.enable_hybrid_search:
return False
try:
milvus_version = self._client.get_server_version()
return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
except Exception as e:
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
return False
self._consistency_level = "Session"
self._fields: list[str] = []
def get_type(self) -> str:
"""
Get the type of vector storage (Milvus).
"""
return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""
Create a collection and add texts with embeddings.
"""
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""
Add texts and their embeddings to the collection.
"""
insert_dict_list = []
for i in range(len(documents)):
insert_dict = {
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
# function will automatically convert the native text into a sparse vector for us.
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
@ -120,11 +76,12 @@ class MilvusVector(BaseVector):
insert_dict_list.append(insert_dict)
# Total insert count
total_count = len(insert_dict_list)
pks: list[str] = []
for i in range(0, total_count, 1000):
# Insert into the collection.
batch_insert_list = insert_dict_list[i : i + 1000]
# Insert into the collection.
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
@ -134,9 +91,6 @@ class MilvusVector(BaseVector):
return pks
def get_ids_by_metadata_field(self, key: str, value: str):
"""
Get document IDs by metadata field key and value.
"""
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
)
@ -146,18 +100,12 @@ class MilvusVector(BaseVector):
return None
def delete_by_metadata_field(self, key: str, value: str):
"""
Delete documents by metadata field key and value.
"""
if self._client.has_collection(self._collection_name):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, ids: list[str]) -> None:
"""
Delete documents by their IDs.
"""
if self._client.has_collection(self._collection_name):
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
@ -167,16 +115,10 @@ class MilvusVector(BaseVector):
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete(self) -> None:
"""
Delete the entire collection.
"""
if self._client.has_collection(self._collection_name):
self._client.drop_collection(self._collection_name, None)
def text_exists(self, id: str) -> bool:
"""
Check if a text with the given ID exists in the collection.
"""
if not self._client.has_collection(self._collection_name):
return False
@ -186,80 +128,32 @@ class MilvusVector(BaseVector):
return len(result) > 0
def field_exists(self, field: str) -> bool:
"""
Check if a field exists in the collection.
"""
return field in self._fields
def _process_search_results(
self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0
) -> list[Document]:
"""
Common method to process search results
:param results: Search results
:param output_fields: Fields to be output
:param score_threshold: Score threshold for filtering
:return: List of documents
"""
docs = []
for result in results[0]:
metadata = result["entity"].get(output_fields[1], {})
metadata["score"] = result["distance"]
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata)
docs.append(doc)
return docs
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search for documents by vector similarity.
"""
# Set search parameters.
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
anns_field=Field.VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
# Organize results.
docs = []
for result in results[0]:
metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""
Search for documents by full-text search (if hybrid search is enabled).
"""
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
return []
results = self._client.search(
collection_name=self._collection_name,
data=[query],
anns_field=Field.SPARSE_VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
# milvus/zilliz doesn't support bm25 search
return []
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
"""
Create a new collection in Milvus with the specified schema and index parameters.
"""
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
@ -267,7 +161,7 @@ class MilvusVector(BaseVector):
return
# Grab the existing collection if it exists
if not self._client.has_collection(self._collection_name):
from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
# Determine embedding dim
@ -276,36 +170,16 @@ class MilvusVector(BaseVector):
if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field, enable_analyzer will be set True to support milvus automatically
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
fields.append(
FieldSchema(
Field.CONTENT_KEY.value,
DataType.VARCHAR,
max_length=65_535,
enable_analyzer=self._hybrid_search_enabled,
)
)
# Create the text field
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
# Create the primary key field
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
# Create the schema for the collection
schema = CollectionSchema(fields)
# Create custom function to support text to sparse vector by BM25
if self._hybrid_search_enabled:
bm25_function = Function(
name="text_bm25_emb",
input_field_names=[Field.CONTENT_KEY.value],
output_field_names=[Field.SPARSE_VECTOR.value],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)
for x in schema.fields:
self._fields.append(x.name)
# Since primary field is auto-id, no need to track it
@ -315,15 +189,10 @@ class MilvusVector(BaseVector):
index_params_obj = IndexParams()
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
index_params_obj.add_index(
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
)
# Create the collection
collection_name = self._collection_name
self._client.create_collection(
collection_name=self._collection_name,
collection_name=collection_name,
schema=schema,
index_params=index_params_obj,
consistency_level=self._consistency_level,
@ -331,22 +200,12 @@ class MilvusVector(BaseVector):
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient:
"""
Initialize and return a Milvus client.
"""
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
return client
class MilvusVectorFactory(AbstractVectorFactory):
"""
Factory class for creating MilvusVector instances.
"""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
"""
Initialize a MilvusVector instance for the given dataset.
"""
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
@ -363,6 +222,5 @@ class MilvusVectorFactory(AbstractVectorFactory):
user=dify_config.MILVUS_USER or "",
password=dify_config.MILVUS_PASSWORD or "",
database=dify_config.MILVUS_DATABASE or "",
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
),
)
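The hybrid-search support removed above follows the Milvus 2.5 full-text-search recipe referenced in its comments: an analyzer-enabled text field, a SPARSE_FLOAT_VECTOR field, and a BM25 function that populates the sparse field at insert time. A minimal standalone sketch, assuming a Milvus >= 2.5 server; the URI, collection name, and data are illustrative:

from pymilvus import (
    CollectionSchema,
    DataType,
    FieldSchema,
    Function,
    FunctionType,
    MilvusClient,
)

client = MilvusClient(uri="http://localhost:19530")  # illustrative URI

fields = [
    FieldSchema("id", DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema("page_content", DataType.VARCHAR, max_length=65_535, enable_analyzer=True),
    FieldSchema("sparse_vector", DataType.SPARSE_FLOAT_VECTOR),
]
schema = CollectionSchema(fields)
# The BM25 function converts raw text into the sparse vector automatically.
schema.add_function(
    Function(
        name="text_bm25_emb",
        input_field_names=["page_content"],
        output_field_names=["sparse_vector"],
        function_type=FunctionType.BM25,
    )
)

index_params = client.prepare_index_params()
index_params.add_index(field_name="sparse_vector", index_type="AUTOINDEX", metric_type="BM25")
client.create_collection(collection_name="demo", schema=schema, index_params=index_params)

client.insert(collection_name="demo", data=[{"page_content": "Dify is an LLM app platform."}])
# Full-text search passes the raw query string against the sparse field.
hits = client.search(collection_name="demo", data=["llm platform"], anns_field="sparse_vector", limit=4)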

View File

@ -409,27 +409,27 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
if not tidb_auth_binding:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
else:
idle_tidb_auth_binding = (
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
else:
new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID or "",
@ -451,6 +451,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
db.session.add(new_tidb_auth_binding)
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
else:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"

View File

@ -90,12 +90,6 @@ class Vector:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
return ElasticSearchVectorFactory
case VectorType.ELASTICSEARCH_JA:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
ElasticSearchJaVectorFactory,
)
return ElasticSearchJaVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory

View File

@ -16,7 +16,6 @@ class VectorType(StrEnum):
TENCENT = "tencent"
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"
ELASTICSEARCH_JA = "elasticsearch-ja"
LINDORM = "lindorm"
COUCHBASE = "couchbase"
BAIDU = "baidu"

View File

@ -31,7 +31,7 @@ class FirecrawlApp:
"markdown": data.get("markdown"),
}
else:
raise Exception(f"Failed to scrape URL. Error: {response_data['error']}")
raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}')
elif response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")

View File

@ -358,7 +358,8 @@ class NotionExtractor(BaseExtractor):
if not data_source_binding:
raise Exception(
f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
f"No notion data source binding found for tenant {tenant_id} "
f"and notion workspace {notion_workspace_id}"
)
return cast(str, data_source_binding.access_token)

View File

@ -23,6 +23,7 @@ class PdfExtractor(BaseExtractor):
self._file_cache_key = file_cache_key
def extract(self) -> list[Document]:
plaintext_file_key = ""
plaintext_file_exists = False
if self._file_cache_key:
try:
@ -38,8 +39,8 @@ class PdfExtractor(BaseExtractor):
text = "\n\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and self._file_cache_key:
storage.save(self._file_cache_key, text.encode("utf-8"))
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode("utf-8"))
return documents
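The fix above makes the save path reuse the exact key that was probed on the read path, which is the usual cache-aside shape. A minimal sketch, under the assumption that the storage backend exposes a matching load alongside the save call visible above:

from typing import Optional

def extract_with_cache(file_cache_key: Optional[str]) -> str:
    if file_cache_key:
        try:
            return storage.load(file_cache_key).decode("utf-8")  # assumed read API
        except Exception:
            pass  # cache miss: fall through to extraction
    text = run_pdf_extraction()  # hypothetical extraction step
    if file_cache_key:
        storage.save(file_cache_key, text.encode("utf-8"))  # write under the probed key
    return text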

View File

@ -112,7 +112,7 @@ class QAIndexProcessor(BaseIndexProcessor):
df = pd.read_csv(file)
text_docs = []
for index, row in df.iterrows():
data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]})
data = Document(page_content=row[0], metadata={"answer": row[1]})
text_docs.append(data)
if len(text_docs) == 0:
raise ValueError("The CSV file is empty.")

View File

@ -127,7 +127,7 @@ class AIPPTGenerateToolAdapter:
response = response.json()
if response.get("code") != 0:
raise Exception(f"Failed to create task: {response.get('msg')}")
raise Exception(f'Failed to create task: {response.get("msg")}')
return response.get("data", {}).get("id")
@ -222,7 +222,7 @@ class AIPPTGenerateToolAdapter:
elif model == "wenxin":
response = response.json()
if response.get("code") != 0:
raise Exception(f"Failed to generate content: {response.get('msg')}")
raise Exception(f'Failed to generate content: {response.get("msg")}')
return response.get("data", "")
@ -254,7 +254,7 @@ class AIPPTGenerateToolAdapter:
response = response.json()
if response.get("code") != 0:
raise Exception(f"Failed to generate ppt: {response.get('msg')}")
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
id = response.get("data", {}).get("id")
cover_url = response.get("data", {}).get("cover_url")
@ -270,7 +270,7 @@ class AIPPTGenerateToolAdapter:
response = response.json()
if response.get("code") != 0:
raise Exception(f"Failed to generate ppt: {response.get('msg')}")
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
export_code = response.get("data")
if not export_code:
@ -290,7 +290,7 @@ class AIPPTGenerateToolAdapter:
response = response.json()
if response.get("code") != 0:
raise Exception(f"Failed to generate ppt: {response.get('msg')}")
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
if response.get("msg") == "导出中":
current_iteration += 1
@ -343,7 +343,7 @@ class AIPPTGenerateToolAdapter:
raise Exception(f"Failed to connect to aippt: {response.text}")
response = response.json()
if response.get("code") != 0:
raise Exception(f"Failed to connect to aippt: {response.get('msg')}")
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
token = response.get("data", {}).get("token")
expire = response.get("data", {}).get("time_expire")
@ -379,7 +379,7 @@ class AIPPTGenerateToolAdapter:
if cls._style_cache[key]["expire"] < now:
del cls._style_cache[key]
key = f"{credentials['aippt_access_key']}#@#{user_id}"
key = f'{credentials["aippt_access_key"]}#@#{user_id}'
if key in cls._style_cache:
return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"]
@ -396,11 +396,11 @@ class AIPPTGenerateToolAdapter:
response = response.json()
if response.get("code") != 0:
raise Exception(f"Failed to connect to aippt: {response.get('msg')}")
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
colors = [
{
"id": f"id-{item.get('id')}",
"id": f'id-{item.get("id")}',
"name": item.get("name"),
"en_name": item.get("en_name", item.get("name")),
}
@ -408,7 +408,7 @@ class AIPPTGenerateToolAdapter:
]
styles = [
{
"id": f"id-{item.get('id')}",
"id": f'id-{item.get("id")}',
"name": item.get("title"),
}
for item in response.get("data", {}).get("suit_style") or []
@ -454,7 +454,7 @@ class AIPPTGenerateToolAdapter:
response = response.json()
if response.get("code") != 0:
raise Exception(f"Failed to connect to aippt: {response.get('msg')}")
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
if len(response.get("data", {}).get("list") or []) > 0:
return response.get("data", {}).get("list")[0].get("id")

View File

@ -14,38 +14,14 @@ class BedrockRetrieveTool(BuiltinTool):
topk: Optional[int] = None
def _bedrock_retrieve(
self,
query_input: str,
knowledge_base_id: str,
num_results: int,
search_type: str,
rerank_model_id: str,
metadata_filter: Optional[dict] = None,
self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
):
try:
retrieval_query = {"text": query_input}
if search_type not in ["HYBRID", "SEMANTIC"]:
raise RuntimeError("search_type should be HYBRID or SEMANTIC")
retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}
retrieval_configuration = {
"vectorSearchConfiguration": {"numberOfResults": num_results, "overrideSearchType": search_type}
}
if rerank_model_id != "default":
model_for_rerank_arn = f"arn:aws:bedrock:us-west-2::foundation-model/{rerank_model_id}"
rerankingConfiguration = {
"bedrockRerankingConfiguration": {
"numberOfRerankedResults": num_results,
"modelConfiguration": {"modelArn": model_for_rerank_arn},
},
"type": "BEDROCK_RERANKING_MODEL",
}
retrieval_configuration["vectorSearchConfiguration"]["rerankingConfiguration"] = rerankingConfiguration
retrieval_configuration["vectorSearchConfiguration"]["numberOfResults"] = num_results * 5
# 如果有元数据过滤条件,则添加到检索配置中
# Add metadata filter to retrieval configuration if present
if metadata_filter:
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter
@ -101,20 +77,15 @@ class BedrockRetrieveTool(BuiltinTool):
if not query:
return self.create_text_message("Please input query")
# 获取元数据过滤条件(如果存在)
# Get metadata filter conditions (if they exist)
metadata_filter_str = tool_parameters.get("metadata_filter")
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None
search_type = tool_parameters.get("search_type")
rerank_model_id = tool_parameters.get("rerank_model_id")
line = 4
retrieved_docs = self._bedrock_retrieve(
query_input=query,
knowledge_base_id=self.knowledge_base_id,
num_results=self.topk,
search_type=search_type,
rerank_model_id=rerank_model_id,
metadata_filter=metadata_filter,
)
@ -138,7 +109,7 @@ class BedrockRetrieveTool(BuiltinTool):
if not parameters.get("query"):
raise ValueError("query is required")
# 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供)
# Optional: Validate if metadata filter is a valid JSON string (if provided)
metadata_filter_str = parameters.get("metadata_filter")
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
raise ValueError("metadata_filter must be a valid JSON object")

View File

@ -59,57 +59,6 @@ parameters:
max: 10
default: 5
- name: search_type
type: select
required: false
label:
en_US: search type
zh_Hans: 搜索类型
pt_BR: search type
human_description:
en_US: search type
zh_Hans: 搜索类型
pt_BR: search type
llm_description: search type
default: SEMANTIC
options:
- value: SEMANTIC
label:
en_US: SEMANTIC
zh_Hans: 语义搜索
- value: HYBRID
label:
en_US: HYBRID
zh_Hans: 混合搜索
form: form
- name: rerank_model_id
type: select
required: false
label:
en_US: rerank model id
zh_Hans: 重拍模型ID
pt_BR: rerank model id
human_description:
en_US: rerank model id
zh_Hans: 重拍模型ID
pt_BR: rerank model id
llm_description: rerank model id
options:
- value: default
label:
en_US: default
zh_Hans: 默认
- value: cohere.rerank-v3-5:0
label:
en_US: cohere.rerank-v3-5:0
zh_Hans: cohere.rerank-v3-5:0
- value: amazon.rerank-v1:0
label:
en_US: amazon.rerank-v1:0
zh_Hans: amazon.rerank-v1:0
form: form
- name: aws_region
type: string
required: false

View File

@ -229,7 +229,8 @@ class NovaReelTool(BuiltinTool):
if async_mode:
return self.create_text_message(
f"Video generation started.\nInvocation ARN: {invocation_arn}\nVideo will be available at: {video_uri}"
f"Video generation started.\nInvocation ARN: {invocation_arn}\n"
f"Video will be available at: {video_uri}"
)
return self._wait_for_completion(bedrock, s3_client, invocation_arn)

View File

@ -65,7 +65,7 @@ class BaiduFieldTranslateTool(BuiltinTool, BaiduTranslateToolBase):
if "trans_result" in result:
result_text = result["trans_result"][0]["dst"]
else:
result_text = f"{result['error_code']}: {result['error_msg']}"
result_text = f'{result["error_code"]}: {result["error_msg"]}'
return self.create_text_message(str(result_text))
except requests.RequestException as e:

View File

@ -52,7 +52,7 @@ class BaiduLanguageTool(BuiltinTool, BaiduTranslateToolBase):
result_text = ""
if result["error_code"] != 0:
result_text = f"{result['error_code']}: {result['error_msg']}"
result_text = f'{result["error_code"]}: {result["error_msg"]}'
else:
result_text = result["data"]["src"]
result_text = self.mapping_result(description_language, result_text)

View File

@ -58,7 +58,7 @@ class BaiduTranslateTool(BuiltinTool, BaiduTranslateToolBase):
if "trans_result" in result:
result_text = result["trans_result"][0]["dst"]
else:
result_text = f"{result['error_code']}: {result['error_msg']}"
result_text = f'{result["error_code"]}: {result["error_msg"]}'
return self.create_text_message(str(result_text))
except requests.RequestException as e:

View File

@ -30,7 +30,7 @@ class BingSearchTool(BuiltinTool):
headers = {"Ocp-Apim-Subscription-Key": subscription_key, "Accept-Language": accept_language}
query = quote(query)
server_url = f"{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={','.join(filters)}"
server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
response = get(server_url, headers=headers)
if response.status_code != 200:
@ -47,23 +47,23 @@ class BingSearchTool(BuiltinTool):
results = []
if search_results:
for result in search_results:
url = f": {result['url']}" if "url" in result else ""
results.append(self.create_text_message(text=f"{result['name']}{url}"))
url = f': {result["url"]}' if "url" in result else ""
results.append(self.create_text_message(text=f'{result["name"]}{url}'))
if entities:
for entity in entities:
url = f": {entity['url']}" if "url" in entity else ""
results.append(self.create_text_message(text=f"{entity.get('name', '')}{url}"))
url = f': {entity["url"]}' if "url" in entity else ""
results.append(self.create_text_message(text=f'{entity.get("name", "")}{url}'))
if news:
for news_item in news:
url = f": {news_item['url']}" if "url" in news_item else ""
results.append(self.create_text_message(text=f"{news_item.get('name', '')}{url}"))
url = f': {news_item["url"]}' if "url" in news_item else ""
results.append(self.create_text_message(text=f'{news_item.get("name", "")}{url}'))
if related_searches:
for related in related_searches:
url = f": {related['displayText']}" if "displayText" in related else ""
results.append(self.create_text_message(text=f"{related.get('displayText', '')}{url}"))
url = f': {related["displayText"]}' if "displayText" in related else ""
results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}'))
return results
elif result_type == "json":
@ -106,29 +106,29 @@ class BingSearchTool(BuiltinTool):
text = ""
if search_results:
for i, result in enumerate(search_results):
text += f"{i + 1}: {result.get('name', '')} - {result.get('snippet', '')}\n"
text += f'{i + 1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
if computation and "expression" in computation and "value" in computation:
text += "\nComputation:\n"
text += f"{computation['expression']} = {computation['value']}\n"
text += f'{computation["expression"]} = {computation["value"]}\n'
if entities:
text += "\nEntities:\n"
for entity in entities:
url = f"- {entity['url']}" if "url" in entity else ""
text += f"{entity.get('name', '')}{url}\n"
url = f'- {entity["url"]}' if "url" in entity else ""
text += f'{entity.get("name", "")}{url}\n'
if news:
text += "\nNews:\n"
for news_item in news:
url = f"- {news_item['url']}" if "url" in news_item else ""
text += f"{news_item.get('name', '')}{url}\n"
url = f'- {news_item["url"]}' if "url" in news_item else ""
text += f'{news_item.get("name", "")}{url}\n'
if related_searches:
text += "\n\nRelated Searches:\n"
for related in related_searches:
url = f"- {related['webSearchUrl']}" if "webSearchUrl" in related else ""
text += f"{related.get('displayText', '')}{url}\n"
url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else ""
text += f'{related.get("displayText", "")}{url}\n'
return self.create_text_message(text=self.summary(user_id=user_id, content=text))

View File

@ -83,5 +83,5 @@ class DIDApp:
if status["status"] == "done":
return status
elif status["status"] == "error" or status["status"] == "rejected":
raise HTTPError(f"Talks {id} failed: {status['status']} {status.get('error', {}).get('description')}")
raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error", {}).get("description")}')
time.sleep(poll_interval)

View File

@ -20,33 +20,33 @@ class SendEmailToolParameters(BaseModel):
encrypt_method: str
def send_mail(params: SendEmailToolParameters):
def send_mail(parmas: SendEmailToolParameters):
timeout = 60
msg = MIMEMultipart("alternative")
msg["From"] = params.email_account
msg["To"] = params.sender_to
msg["Subject"] = params.subject
msg.attach(MIMEText(params.email_content, "plain"))
msg.attach(MIMEText(params.email_content, "html"))
msg["From"] = parmas.email_account
msg["To"] = parmas.sender_to
msg["Subject"] = parmas.subject
msg.attach(MIMEText(parmas.email_content, "plain"))
msg.attach(MIMEText(parmas.email_content, "html"))
ctx = ssl.create_default_context()
if params.encrypt_method.upper() == "SSL":
if parmas.encrypt_method.upper() == "SSL":
try:
with smtplib.SMTP_SSL(params.smtp_server, params.smtp_port, context=ctx, timeout=timeout) as server:
server.login(params.email_account, params.email_password)
server.sendmail(params.email_account, params.sender_to, msg.as_string())
with smtplib.SMTP_SSL(parmas.smtp_server, parmas.smtp_port, context=ctx, timeout=timeout) as server:
server.login(parmas.email_account, parmas.email_password)
server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string())
return True
except Exception as e:
logging.exception("send email failed")
return False
else: # NONE or TLS
try:
with smtplib.SMTP(params.smtp_server, params.smtp_port, timeout=timeout) as server:
if params.encrypt_method.upper() == "TLS":
with smtplib.SMTP(parmas.smtp_server, parmas.smtp_port, timeout=timeout) as server:
if parmas.encrypt_method.upper() == "TLS":
server.starttls(context=ctx)
server.login(params.email_account, params.email_password)
server.sendmail(params.email_account, params.sender_to, msg.as_string())
server.login(parmas.email_account, parmas.email_password)
server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string())
return True
except Exception as e:
logging.exception("send email failed")
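A usage sketch for the helper above; the parameter model's full field list is inferred from the function body, and all values are placeholders:

params = SendEmailToolParameters(
    smtp_server="smtp.example.com",
    smtp_port=465,
    email_account="bot@example.com",
    email_password="app-password",
    sender_to="user@example.com",
    subject="Hello",
    email_content="<p>Hi there</p>",
    encrypt_method="SSL",  # SSL, TLS, or NONE per the branches above
)
ok = send_mail(params)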

View File

@ -74,7 +74,7 @@ class FirecrawlApp:
if response is None:
raise HTTPError("Failed to initiate crawl after multiple retries")
elif response.get("success") == False:
raise HTTPError(f"Failed to crawl: {response.get('error')}")
raise HTTPError(f'Failed to crawl: {response.get("error")}')
job_id: str = response["id"]
if wait:
return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval)
@ -100,7 +100,7 @@ class FirecrawlApp:
if status["status"] == "completed":
return status
elif status["status"] == "failed":
raise HTTPError(f"Job {job_id} failed: {status['error']}")
raise HTTPError(f'Job {job_id} failed: {status["error"]}')
time.sleep(poll_interval)

View File

@ -37,9 +37,8 @@ class GaodeRepositoriesTool(BuiltinTool):
CityCode = City_data["districts"][0]["adcode"]
weatherInfo_response = s.request(
method="GET",
url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json".format(
url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")
),
url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json"
"".format(url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")),
)
weatherInfo_data = weatherInfo_response.json()
if weatherInfo_response.status_code == 200 and weatherInfo_data.get("info") == "OK":

View File

@ -11,21 +11,19 @@ class GitlabFilesTool(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
repository = tool_parameters.get("repository", "")
project = tool_parameters.get("project", "")
repository = tool_parameters.get("repository", "")
branch = tool_parameters.get("branch", "")
path = tool_parameters.get("path", "")
file_path = tool_parameters.get("file_path", "")
if not repository and not project:
return self.create_text_message("Either repository or project is required")
if not project and not repository:
return self.create_text_message("Either project or repository is required")
if not branch:
return self.create_text_message("Branch is required")
if not path and not file_path:
return self.create_text_message("Either path or file_path is required")
if not path:
return self.create_text_message("Path is required")
access_token = self.runtime.credentials.get("access_tokens")
headers = {"PRIVATE-TOKEN": access_token}
site_url = self.runtime.credentials.get("site_url")
if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"):
@ -33,45 +31,33 @@ class GitlabFilesTool(BuiltinTool):
if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"):
site_url = "https://gitlab.com"
if repository:
# URL encode the repository path
identifier = urllib.parse.quote(repository, safe="")
else:
identifier = self.get_project_id(site_url, access_token, project)
if not identifier:
raise Exception(f"Project '{project}' not found.)")
# Get file content
if path:
results = self.fetch_files(site_url, headers, identifier, branch, path)
return [self.create_json_message(item) for item in results]
if repository:
result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True)
else:
result = self.fetch_file(site_url, headers, identifier, branch, file_path)
return [self.create_json_message(result)]
result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False)
@staticmethod
def fetch_file(
site_url: str,
headers: dict[str, str],
identifier: str,
branch: str,
path: str,
) -> dict[str, Any]:
encoded_file_path = urllib.parse.quote(path, safe="")
file_url = f"{site_url}/api/v4/projects/{identifier}/repository/files/{encoded_file_path}/raw?ref={branch}"
file_response = requests.get(file_url, headers=headers)
file_response.raise_for_status()
file_content = file_response.text
return {"path": path, "branch": branch, "content": file_content}
return [self.create_json_message(item) for item in result]
def fetch_files(
self, site_url: str, headers: dict[str, str], identifier: str, branch: str, path: str
self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool
) -> list[dict[str, Any]]:
domain = site_url
headers = {"PRIVATE-TOKEN": access_token}
results = []
try:
tree_url = f"{site_url}/api/v4/projects/{identifier}/repository/tree?path={path}&ref={branch}"
if is_repository:
# URL encode the repository path
encoded_identifier = urllib.parse.quote(identifier, safe="")
tree_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}"
else:
# Get project ID from project name
project_id = self.get_project_id(site_url, access_token, identifier)
if not project_id:
return self.create_text_message(f"Project '{identifier}' not found.")
tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
response = requests.get(tree_url, headers=headers)
response.raise_for_status()
items = response.json()
@ -79,10 +65,26 @@ class GitlabFilesTool(BuiltinTool):
for item in items:
item_path = item["path"]
if item["type"] == "tree": # It's a directory
results.extend(self.fetch_files(site_url, headers, identifier, branch, item_path))
results.extend(
self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository)
)
else: # It's a file
result = self.fetch_file(site_url, headers, identifier, branch, item_path)
results.append(result)
encoded_item_path = urllib.parse.quote(item_path, safe="")
if is_repository:
file_url = (
f"{domain}/api/v4/projects/{encoded_identifier}/repository/files"
f"/{encoded_item_path}/raw?ref={branch}"
)
else:
file_url = (
f"{domain}/api/v4/projects/{project_id}/repository/files"
f"{encoded_item_path}/raw?ref={branch}"
)
file_response = requests.get(file_url, headers=headers)
file_response.raise_for_status()
file_content = file_response.text
results.append({"path": item_path, "branch": branch, "content": file_content})
except requests.RequestException as e:
print(f"Error fetching data from GitLab: {e}")

View File

@ -29,7 +29,7 @@ parameters:
zh_Hans: 项目
human_description:
en_US: project
zh_Hans: 项目(和仓库路径二选一,都填写以仓库路径优先)
zh_Hans: 项目
llm_description: Project for GitLab
form: llm
- name: branch
@ -45,21 +45,12 @@ parameters:
form: llm
- name: path
type: string
required: true
label:
en_US: path
zh_Hans: 文件夹
human_description:
en_US: path
zh_Hans: 文件夹
llm_description: Dir path for GitLab
form: llm
- name: file_path
type: string
label:
en_US: file_path
zh_Hans: 文件路径
human_description:
en_US: file_path
zh_Hans: 文件路径(和文件夹二选一,都填写以文件夹优先)
en_US: path
zh_Hans: 文件路径
llm_description: File path for GitLab
form: llm

View File

@ -110,7 +110,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
result["rows"].append(self.get_row_field_value(row, schema))
return self.create_text_message(json.dumps(result, ensure_ascii=False))
else:
result_text = f'Found {result["total"]} rows in worksheet "{worksheet_name}".'
result_text = f"Found {result['total']} rows in worksheet \"{worksheet_name}\"."
if result["total"] > 0:
result_text += (
f" The following are {min(limit, result['total'])}"

View File

@ -28,4 +28,4 @@ class BaseStabilityAuthorization:
"""
This method is responsible for generating the authorization headers.
"""
return {"Authorization": f"Bearer {credentials.get('api_key', '')}"}
return {"Authorization": f'Bearer {credentials.get("api_key", "")}'}

View File

@ -38,7 +38,7 @@ class VannaProvider(BuiltinToolProviderController):
tool_parameters={
"model": "chinook",
"db_type": "SQLite",
"url": f"{self._get_protocol_and_main_domain(credentials['base_url'])}/Chinook.sqlite",
"url": f'{self._get_protocol_and_main_domain(credentials["base_url"])}/Chinook.sqlite',
"query": "What are the top 10 customers by sales?",
},
)

View File

@ -43,7 +43,7 @@ class SerplyApi:
def parse_results(res: dict) -> str:
"""Process response from Serply Job Search."""
jobs = res.get("jobs", [])
if not res or "jobs" not in res:
if not jobs:
raise ValueError(f"Got error from Serply: {res}")
string = []

View File

@ -43,7 +43,7 @@ class SerplyApi:
def parse_results(res: dict) -> str:
"""Process response from Serply News Search."""
news = res.get("entries", [])
if not res or "entries" not in res:
if not news:
raise ValueError(f"Got error from Serply: {res}")
string = []

View File

@ -43,7 +43,7 @@ class SerplyApi:
def parse_results(res: dict) -> str:
"""Process response from Serply News Search."""
articles = res.get("articles", [])
if not res or "articles" not in res:
if not articles:
raise ValueError(f"Got error from Serply: {res}")
string = []

View File

@ -42,7 +42,7 @@ class SerplyApi:
def parse_results(res: dict) -> str:
"""Process response from Serply Web Search."""
results = res.get("results", [])
if not res or "results" not in res:
if not results:
raise ValueError(f"Got error from Serply: {res}")
string = []

View File

@ -84,9 +84,9 @@ class ApiTool(Tool):
if "api_key_header_prefix" in credentials:
api_key_header_prefix = credentials["api_key_header_prefix"]
if api_key_header_prefix == "basic" and credentials["api_key_value"]:
credentials["api_key_value"] = f"Basic {credentials['api_key_value']}"
credentials["api_key_value"] = f'Basic {credentials["api_key_value"]}'
elif api_key_header_prefix == "bearer" and credentials["api_key_value"]:
credentials["api_key_value"] = f"Bearer {credentials['api_key_value']}"
credentials["api_key_value"] = f'Bearer {credentials["api_key_value"]}'
elif api_key_header_prefix == "custom":
pass

Some files were not shown because too many files have changed in this diff.