mirror of
https://github.com/langgenius/dify.git
synced 2026-02-10 13:25:42 +08:00
Compare commits
2 Commits
0.15.2
...
provider-g
| Author | SHA1 | Date | |
|---|---|---|---|
| df82bccfdf | |||
| 5ba875f96b |
2
.github/actions/setup-poetry/action.yml
vendored
2
.github/actions/setup-poetry/action.yml
vendored
@ -8,7 +8,7 @@ inputs:
|
||||
poetry-version:
|
||||
description: Poetry version to set up
|
||||
required: true
|
||||
default: '2.0.1'
|
||||
default: '1.8.4'
|
||||
poetry-lockfile:
|
||||
description: Path to the Poetry lockfile to restore cache from
|
||||
required: true
|
||||
|
||||
16
.github/workflows/api-tests.yml
vendored
16
.github/workflows/api-tests.yml
vendored
@ -42,23 +42,25 @@ jobs:
|
||||
run: poetry install -C api --with dev
|
||||
|
||||
- name: Check dependencies in pyproject.toml
|
||||
run: poetry run -P api bash dev/pytest/pytest_artifacts.sh
|
||||
run: poetry run -C api bash dev/pytest/pytest_artifacts.sh
|
||||
|
||||
- name: Run Unit tests
|
||||
run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh
|
||||
run: poetry run -C api bash dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Run ModelRuntime
|
||||
run: poetry run -P api bash dev/pytest/pytest_model_runtime.sh
|
||||
run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh
|
||||
|
||||
- name: Run dify config tests
|
||||
run: poetry run -P api python dev/pytest/pytest_config_tests.py
|
||||
run: poetry run -C api python dev/pytest/pytest_config_tests.py
|
||||
|
||||
- name: Run Tool
|
||||
run: poetry run -P api bash dev/pytest/pytest_tools.sh
|
||||
run: poetry run -C api bash dev/pytest/pytest_tools.sh
|
||||
|
||||
- name: Run mypy
|
||||
run: |
|
||||
poetry run -C api python -m mypy --install-types --non-interactive .
|
||||
pushd api
|
||||
poetry run python -m mypy --install-types --non-interactive .
|
||||
popd
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
@ -78,4 +80,4 @@ jobs:
|
||||
ssrf_proxy
|
||||
|
||||
- name: Run Workflow
|
||||
run: poetry run -P api bash dev/pytest/pytest_workflow.sh
|
||||
run: poetry run -C api bash dev/pytest/pytest_workflow.sh
|
||||
|
||||
33
.github/workflows/style.yml
vendored
33
.github/workflows/style.yml
vendored
@ -38,12 +38,12 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: |
|
||||
poetry run -C api ruff --version
|
||||
poetry run -C api ruff check ./
|
||||
poetry run -C api ruff format --check ./
|
||||
poetry run -C api ruff check ./api
|
||||
poetry run -C api ruff format --check ./api
|
||||
|
||||
- name: Dotenv check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: poetry run -P api dotenv-linter ./api/.env.example ./web/.env.example
|
||||
run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
|
||||
|
||||
- name: Lint hints
|
||||
if: failure()
|
||||
@ -82,33 +82,6 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: yarn run lint
|
||||
|
||||
docker-compose-template:
|
||||
name: Docker Compose Template
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v45
|
||||
with:
|
||||
files: |
|
||||
docker/generate_docker_compose
|
||||
docker/.env.example
|
||||
docker/docker-compose-template.yaml
|
||||
docker/docker-compose.yaml
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd docker
|
||||
./generate_docker_compose
|
||||
|
||||
- name: Check for changes
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: git diff --exit-code
|
||||
|
||||
superlinter:
|
||||
name: SuperLinter
|
||||
|
||||
2
.github/workflows/vdb-tests.yml
vendored
2
.github/workflows/vdb-tests.yml
vendored
@ -70,4 +70,4 @@ jobs:
|
||||
tidb
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: poetry run -P api bash dev/pytest/pytest_vdb.sh
|
||||
run: poetry run -C api bash dev/pytest/pytest_vdb.sh
|
||||
|
||||
@ -53,12 +53,10 @@ ignore = [
|
||||
"FURB152", # math-constant
|
||||
"UP007", # non-pep604-annotation
|
||||
"UP032", # f-string
|
||||
"UP045", # non-pep604-annotation-optional
|
||||
"B005", # strip-with-multi-characters
|
||||
"B006", # mutable-argument-default
|
||||
"B007", # unused-loop-control-variable
|
||||
"B026", # star-arg-unpacking-after-keyword-arg
|
||||
"B903", # class-as-data-structure
|
||||
"B904", # raise-without-from-inside-except
|
||||
"B905", # zip-without-explicit-strict
|
||||
"N806", # non-lowercase-variable-in-function
|
||||
|
||||
@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base
|
||||
WORKDIR /app/api
|
||||
|
||||
# Install Poetry
|
||||
ENV POETRY_VERSION=2.0.1
|
||||
ENV POETRY_VERSION=1.8.4
|
||||
|
||||
# if you located in China, you can use aliyun mirror to speed up
|
||||
# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/
|
||||
|
||||
@ -79,5 +79,5 @@
|
||||
2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
|
||||
|
||||
```bash
|
||||
poetry run -P api bash dev/pytest/pytest_all_tests.sh
|
||||
poetry run -C api bash dev/pytest/pytest_all_tests.sh
|
||||
```
|
||||
|
||||
@ -146,7 +146,7 @@ class EndpointConfig(BaseSettings):
|
||||
)
|
||||
|
||||
CONSOLE_WEB_URL: str = Field(
|
||||
description="Base URL for the console web interface,used for frontend references and CORS configuration",
|
||||
description="Base URL for the console web interface," "used for frontend references and CORS configuration",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
@ -181,7 +181,7 @@ class HostedFetchAppTemplateConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
|
||||
description="Mode for fetching app templates: remote, db, or builtin default to remote,",
|
||||
description="Mode for fetching app templates: remote, db, or builtin" " default to remote,",
|
||||
default="remote",
|
||||
)
|
||||
|
||||
|
||||
@ -33,9 +33,3 @@ class MilvusConfig(BaseSettings):
|
||||
description="Name of the Milvus database to connect to (default is 'default')",
|
||||
default="default",
|
||||
)
|
||||
|
||||
MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
|
||||
description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with "
|
||||
"older versions",
|
||||
default=True,
|
||||
)
|
||||
|
||||
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="0.15.2",
|
||||
default="0.14.2",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
||||
@ -56,7 +56,7 @@ class InsertExploreAppListApi(Resource):
|
||||
|
||||
app = App.query.filter(App.id == args["app_id"]).first()
|
||||
if not app:
|
||||
raise NotFound(f"App '{args['app_id']}' is not found")
|
||||
raise NotFound(f'App \'{args["app_id"]}\' is not found')
|
||||
|
||||
site = app.site
|
||||
if not site:
|
||||
|
||||
@ -22,7 +22,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
from models import App, AppMode
|
||||
from models.model import AppMode
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@ -79,7 +79,7 @@ class ChatMessageTextApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model):
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
try:
|
||||
@ -98,13 +98,9 @@ class ChatMessageTextApi(Resource):
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||
if text_to_speech is None:
|
||||
raise ValueError("TTS is not enabled")
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
if app_model.app_model_config is None:
|
||||
raise ValueError("AppModelConfig not found")
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
|
||||
@ -52,12 +52,12 @@ class DatasetListApi(Resource):
|
||||
# provider = request.args.get("provider", default="vendor")
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
tag_ids = request.args.getlist("tag_ids")
|
||||
include_all = request.args.get("include_all", default="false").lower() == "true"
|
||||
|
||||
if ids:
|
||||
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
|
||||
else:
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
|
||||
page, limit, current_user.current_tenant_id, current_user, search, tag_ids
|
||||
)
|
||||
|
||||
# check embedding setting
|
||||
@ -457,7 +457,7 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -619,7 +619,8 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
match vector_type:
|
||||
case (
|
||||
VectorType.RELYT
|
||||
VectorType.MILVUS
|
||||
| VectorType.RELYT
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
@ -639,12 +640,10 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.MYSCALE
|
||||
| VectorType.ORACLE
|
||||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.ELASTICSEARCH_JA
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.TIDB_ON_QDRANT
|
||||
| VectorType.LINDORM
|
||||
| VectorType.COUCHBASE
|
||||
| VectorType.MILVUS
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
@ -684,7 +683,6 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.MYSCALE
|
||||
| VectorType.ORACLE
|
||||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.ELASTICSEARCH_JA
|
||||
| VectorType.COUCHBASE
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.LINDORM
|
||||
|
||||
@ -257,8 +257,7 @@ class DatasetDocumentListApi(Resource):
|
||||
parser.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
|
||||
parser.add_argument(
|
||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||
)
|
||||
@ -350,7 +349,8 @@ class DatasetInitApi(Resource):
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -525,7 +525,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
return response.model_dump(), 200
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
@ -168,7 +168,8 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -216,7 +217,8 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -265,7 +267,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -365,9 +368,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
result = []
|
||||
for index, row in df.iterrows():
|
||||
if document.doc_form == "qa_model":
|
||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||
data = {"content": row[0], "answer": row[1]}
|
||||
else:
|
||||
data = {"content": row.iloc[0]}
|
||||
data = {"content": row[0]}
|
||||
result.append(data)
|
||||
if len(result) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
@ -434,7 +437,8 @@ class ChildChunkAddApi(Resource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
@ -32,7 +32,7 @@ class ConversationListApi(InstalledAppResource):
|
||||
|
||||
pinned = None
|
||||
if "pinned" in args and args["pinned"] is not None:
|
||||
pinned = args["pinned"] == "true"
|
||||
pinned = True if args["pinned"] == "true" else False
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
|
||||
@ -7,4 +7,4 @@ api = ExternalApi(bp)
|
||||
|
||||
from . import index
|
||||
from .app import app, audio, completion, conversation, file, message, workflow
|
||||
from .dataset import dataset, document, hit_testing, segment, upload_file
|
||||
from .dataset import dataset, document, hit_testing, segment
|
||||
|
||||
@ -31,11 +31,8 @@ class DatasetListApi(DatasetApiResource):
|
||||
# provider = request.args.get("provider", default="vendor")
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
tag_ids = request.args.getlist("tag_ids")
|
||||
include_all = request.args.get("include_all", default="false").lower() == "true"
|
||||
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, limit, tenant_id, current_user, search, tag_ids, include_all
|
||||
)
|
||||
datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids)
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
||||
|
||||
@ -53,7 +53,8 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -94,7 +95,8 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -173,7 +175,8 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
@ -1,54 +0,0 @@
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import (
|
||||
DatasetApiResource,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
|
||||
class UploadFileApi(DatasetApiResource):
|
||||
def get(self, tenant_id, dataset_id, document_id):
|
||||
"""Get upload file."""
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check upload file
|
||||
if document.data_source_type != "upload_file":
|
||||
raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
|
||||
data_source_info = document.data_source_info_dict
|
||||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
|
||||
if not upload_file:
|
||||
raise NotFound("UploadFile not found.")
|
||||
else:
|
||||
raise ValueError("Upload file id not found in document data source info.")
|
||||
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"url": url,
|
||||
"download_url": f"{url}&as_attachment=true",
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at.timestamp(),
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
@ -8,8 +8,6 @@ from flask import current_app, request
|
||||
from flask_login import user_logged_in # type: ignore
|
||||
from flask_restful import Resource # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||
|
||||
from extensions.ext_database import db
|
||||
@ -176,7 +174,7 @@ def validate_dataset_token(view=None):
|
||||
return decorator
|
||||
|
||||
|
||||
def validate_and_get_api_token(scope: str | None = None):
|
||||
def validate_and_get_api_token(scope=None):
|
||||
"""
|
||||
Validate and get API token.
|
||||
"""
|
||||
@ -190,29 +188,20 @@ def validate_and_get_api_token(scope: str | None = None):
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Authorization scheme must be 'Bearer'")
|
||||
|
||||
current_time = datetime.now(UTC).replace(tzinfo=None)
|
||||
cutoff_time = current_time - timedelta(minutes=1)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
update_stmt = (
|
||||
update(ApiToken)
|
||||
.where(
|
||||
ApiToken.token == auth_token,
|
||||
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
|
||||
ApiToken.type == scope,
|
||||
)
|
||||
.values(last_used_at=current_time)
|
||||
.returning(ApiToken)
|
||||
api_token = (
|
||||
db.session.query(ApiToken)
|
||||
.filter(
|
||||
ApiToken.token == auth_token,
|
||||
ApiToken.type == scope,
|
||||
)
|
||||
result = session.execute(update_stmt)
|
||||
api_token = result.scalar_one_or_none()
|
||||
.first()
|
||||
)
|
||||
|
||||
if not api_token:
|
||||
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
||||
api_token = session.scalar(stmt)
|
||||
if not api_token:
|
||||
raise Unauthorized("Access token is invalid")
|
||||
else:
|
||||
session.commit()
|
||||
if not api_token:
|
||||
raise Unauthorized("Access token is invalid")
|
||||
|
||||
api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
return api_token
|
||||
|
||||
@ -240,7 +229,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type="service_api",
|
||||
is_anonymous=user_id == "DEFAULT-USER",
|
||||
is_anonymous=True if user_id == "DEFAULT-USER" else False,
|
||||
session_id=user_id,
|
||||
)
|
||||
db.session.add(end_user)
|
||||
|
||||
@ -39,7 +39,7 @@ class ConversationListApi(WebApiResource):
|
||||
|
||||
pinned = None
|
||||
if "pinned" in args and args["pinned"] is not None:
|
||||
pinned = args["pinned"] == "true"
|
||||
pinned = True if args["pinned"] == "true" else False
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
|
||||
@ -172,7 +172,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
|
||||
tool_name=scratchpad.action.action_name if scratchpad.action else "",
|
||||
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
|
||||
tool_invoke_meta={},
|
||||
thought=scratchpad.thought or "",
|
||||
|
||||
@ -167,7 +167,8 @@ class AppQueueManager:
|
||||
else:
|
||||
if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
|
||||
raise TypeError(
|
||||
"Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
|
||||
"Critical Error: Passing SQLAlchemy Model instances "
|
||||
"that cause thread safety issues is not allowed."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -89,7 +89,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_model.id,
|
||||
Conversation.status == "normal",
|
||||
Conversation.is_deleted.is_(False),
|
||||
]
|
||||
|
||||
if isinstance(user, Account):
|
||||
|
||||
@ -145,7 +145,7 @@ class MessageCycleManage:
|
||||
|
||||
# get extension
|
||||
if "." in message_file.url:
|
||||
extension = f".{message_file.url.split('.')[-1]}"
|
||||
extension = f'.{message_file.url.split(".")[-1]}'
|
||||
if len(extension) > 10:
|
||||
extension = ".bin"
|
||||
else:
|
||||
|
||||
@ -62,9 +62,8 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
|
||||
if not api_based_extension:
|
||||
raise ValueError(
|
||||
"[External data tool] API query failed, variable: {}, error: api_based_extension_id is invalid".format(
|
||||
self.variable
|
||||
)
|
||||
"[External data tool] API query failed, variable: {}, "
|
||||
"error: api_based_extension_id is invalid".format(self.variable)
|
||||
)
|
||||
|
||||
# decrypt api_key
|
||||
|
||||
@ -90,7 +90,7 @@ class File(BaseModel):
|
||||
def markdown(self) -> str:
|
||||
url = self.generate_url()
|
||||
if self.type == FileType.IMAGE:
|
||||
text = f""
|
||||
text = f''
|
||||
else:
|
||||
text = f"[{self.filename or url}]({url})"
|
||||
|
||||
|
||||
@ -530,6 +530,7 @@ class IndexingRunner:
|
||||
# chunk nodes by chunk size
|
||||
indexing_start_at = time.perf_counter()
|
||||
tokens = 0
|
||||
chunk_size = 10
|
||||
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
|
||||
# create keyword index
|
||||
create_keyword_thread = threading.Thread(
|
||||
@ -538,22 +539,11 @@ class IndexingRunner:
|
||||
)
|
||||
create_keyword_thread.start()
|
||||
|
||||
max_workers = 10
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = []
|
||||
|
||||
# Distribute documents into multiple groups based on the hash values of page_content
|
||||
# This is done to prevent multiple threads from processing the same document,
|
||||
# Thereby avoiding potential database insertion deadlocks
|
||||
document_groups: list[list[Document]] = [[] for _ in range(max_workers)]
|
||||
for document in documents:
|
||||
hash = helper.generate_text_hash(document.page_content)
|
||||
group_index = int(hash, 16) % max_workers
|
||||
document_groups[group_index].append(document)
|
||||
for chunk_documents in document_groups:
|
||||
if len(chunk_documents) == 0:
|
||||
continue
|
||||
for i in range(0, len(documents), chunk_size):
|
||||
chunk_documents = documents[i : i + chunk_size]
|
||||
futures.append(
|
||||
executor.submit(
|
||||
self._process_chunk,
|
||||
|
||||
@ -131,7 +131,7 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
|
||||
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
||||
"Please help me predict the three most likely questions that human would ask, "
|
||||
"and keeping each question under 20 characters.\n"
|
||||
"MAKE SURE your output is the SAME language as the Assistant's latest response. "
|
||||
"MAKE SURE your output is the SAME language as the Assistant's latest response"
|
||||
"The output must be an array in JSON format following the specified schema:\n"
|
||||
'["question1","question2","question3"]\n'
|
||||
)
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
import logging
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from os.path import abspath, dirname, join
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
|
||||
|
||||
_tokenizer: Any = None
|
||||
_lock = Lock()
|
||||
_executor = ProcessPoolExecutor(max_workers=1)
|
||||
|
||||
|
||||
class GPT2Tokenizer:
|
||||
@ -15,37 +17,22 @@ class GPT2Tokenizer:
|
||||
use gpt2 tokenizer to get num tokens
|
||||
"""
|
||||
_tokenizer = GPT2Tokenizer.get_encoder()
|
||||
tokens = _tokenizer.encode(text)
|
||||
tokens = _tokenizer.encode(text, verbose=False)
|
||||
return len(tokens)
|
||||
|
||||
@staticmethod
|
||||
def get_num_tokens(text: str) -> int:
|
||||
# Because this process needs more cpu resource, we turn this back before we find a better way to handle it.
|
||||
#
|
||||
# future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
|
||||
# result = future.result()
|
||||
# return cast(int, result)
|
||||
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
|
||||
future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
|
||||
result = future.result()
|
||||
return cast(int, result)
|
||||
|
||||
@staticmethod
|
||||
def get_encoder() -> Any:
|
||||
global _tokenizer, _lock
|
||||
with _lock:
|
||||
if _tokenizer is None:
|
||||
# Try to use tiktoken to get the tokenizer because it is faster
|
||||
#
|
||||
try:
|
||||
import tiktoken
|
||||
|
||||
_tokenizer = tiktoken.get_encoding("gpt2")
|
||||
except Exception:
|
||||
from os.path import abspath, dirname, join
|
||||
|
||||
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
|
||||
|
||||
base_path = abspath(__file__)
|
||||
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
|
||||
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
|
||||
logger.info("Fallback to Transformers' GPT-2 tokenizer from tiktoken")
|
||||
base_path = abspath(__file__)
|
||||
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
|
||||
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
|
||||
|
||||
return _tokenizer
|
||||
|
||||
@ -108,7 +108,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
|
||||
if not ai_model_entity:
|
||||
raise CredentialsValidateFailedError(f"Base Model Name {credentials['base_model_name']} is invalid")
|
||||
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
|
||||
|
||||
try:
|
||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||
|
||||
@ -130,7 +130,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
raise CredentialsValidateFailedError("Base Model Name is required")
|
||||
|
||||
if not self._get_ai_model_entity(credentials["base_model_name"], model):
|
||||
raise CredentialsValidateFailedError(f"Base Model Name {credentials['base_model_name']} is invalid")
|
||||
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
|
||||
|
||||
try:
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
|
||||
@ -70,7 +70,7 @@ class BedrockRerankModel(RerankModel):
|
||||
rerankingConfiguration = {
|
||||
"type": "BEDROCK_RERANKING_MODEL",
|
||||
"bedrockRerankingConfiguration": {
|
||||
"numberOfResults": min(top_n, len(text_sources)),
|
||||
"numberOfResults": top_n,
|
||||
"modelConfiguration": {
|
||||
"modelArn": model_package_arn,
|
||||
},
|
||||
|
||||
@ -1,3 +1,2 @@
|
||||
- deepseek-chat
|
||||
- deepseek-coder
|
||||
- deepseek-reasoner
|
||||
|
||||
@ -10,7 +10,7 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 64000
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
||||
@ -10,7 +10,7 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 64000
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
||||
@ -1,21 +0,0 @@
|
||||
model: deepseek-reasoner
|
||||
label:
|
||||
zh_Hans: deepseek-reasoner
|
||||
en_US: deepseek-reasoner
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 64000
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
min: 1
|
||||
max: 8192
|
||||
default: 4096
|
||||
pricing:
|
||||
input: "4"
|
||||
output: "16"
|
||||
unit: "0.000001"
|
||||
currency: RMB
|
||||
@ -24,6 +24,9 @@ class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
# {"response_format": "xx"} need convert to {"response_format": {"type": "xx"}}
|
||||
if "response_format" in model_parameters:
|
||||
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
- gemini-2.0-flash-exp
|
||||
- gemini-2.0-flash-thinking-exp-1219
|
||||
- gemini-2.0-flash-thinking-exp-01-21
|
||||
- gemini-1.5-pro
|
||||
- gemini-1.5-pro-latest
|
||||
- gemini-1.5-pro-001
|
||||
|
||||
@ -1,39 +0,0 @@
|
||||
model: gemini-2.0-flash-thinking-exp-01-21
|
||||
label:
|
||||
en_US: Gemini 2.0 Flash Thinking Exp 01-21
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32767
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -162,9 +162,9 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
|
||||
@staticmethod
|
||||
def _check_endpoint_url_model_repository_name(credentials: dict, model_name: str):
|
||||
try:
|
||||
url = f"{HUGGINGFACE_ENDPOINT_API}{credentials['huggingface_namespace']}"
|
||||
url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
|
||||
headers = {
|
||||
"Authorization": f"Bearer {credentials['huggingfacehub_api_token']}",
|
||||
"Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}',
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
@ -34,7 +34,6 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage
|
||||
|
||||
class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
model_apis = {
|
||||
"minimax-text-01": MinimaxChatCompletionPro,
|
||||
"abab7-chat-preview": MinimaxChatCompletionPro,
|
||||
"abab6.5t-chat": MinimaxChatCompletionPro,
|
||||
"abab6.5s-chat": MinimaxChatCompletionPro,
|
||||
|
||||
@ -1,46 +0,0 @@
|
||||
model: minimax-text-01
|
||||
label:
|
||||
en_US: Minimax-Text-01
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1000192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0.01
|
||||
max: 1
|
||||
default: 0.1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
min: 0.01
|
||||
max: 1
|
||||
default: 0.95
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 2048
|
||||
min: 1
|
||||
max: 1000192
|
||||
- name: mask_sensitive_info
|
||||
type: boolean
|
||||
default: true
|
||||
label:
|
||||
zh_Hans: 隐私保护
|
||||
en_US: Moderate
|
||||
help:
|
||||
zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
|
||||
en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/id..
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
pricing:
|
||||
input: '0.001'
|
||||
output: '0.008'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@ -44,6 +44,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
self._add_custom_parameters(credentials)
|
||||
self._add_function_call(model, credentials)
|
||||
user = user[:32] if user else None
|
||||
# {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
|
||||
if "response_format" in model_parameters:
|
||||
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
@ -622,19 +621,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
|
||||
|
||||
# o1 compatibility
|
||||
block_as_stream = False
|
||||
if model.startswith("o1"):
|
||||
if "max_tokens" in model_parameters:
|
||||
model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
|
||||
del model_parameters["max_tokens"]
|
||||
|
||||
if re.match(r"^o1(-\d{4}-\d{2}-\d{2})?$", model):
|
||||
if stream:
|
||||
block_as_stream = True
|
||||
stream = False
|
||||
if "stream_options" in extra_model_kwargs:
|
||||
del extra_model_kwargs["stream_options"]
|
||||
|
||||
if "stop" in extra_model_kwargs:
|
||||
del extra_model_kwargs["stop"]
|
||||
|
||||
@ -651,45 +642,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
|
||||
|
||||
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
|
||||
|
||||
if block_as_stream:
|
||||
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
|
||||
|
||||
return block_result
|
||||
|
||||
def _handle_chat_block_as_stream_response(
|
||||
self,
|
||||
block_result: LLMResult,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Handle llm chat response
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
text = block_result.message.content
|
||||
text = cast(str, text)
|
||||
|
||||
if stop:
|
||||
text = self.enforce_stop_tokens(text, stop)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=block_result.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=block_result.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=block_result.message,
|
||||
finish_reason="stop",
|
||||
usage=block_result.usage,
|
||||
),
|
||||
)
|
||||
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
|
||||
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
|
||||
@ -7,7 +7,6 @@ features:
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
|
||||
@ -7,8 +7,6 @@
|
||||
- Qwen/Qwen2.5-Coder-7B-Instruct
|
||||
- Qwen/Qwen2-VL-72B-Instruct
|
||||
- Qwen/Qwen2-1.5B-Instruct
|
||||
- Qwen/Qwen2.5-72B-Instruct-128K
|
||||
- Vendor-A/Qwen/Qwen2.5-72B-Instruct
|
||||
- Pro/Qwen/Qwen2-VL-7B-Instruct
|
||||
- OpenGVLab/InternVL2-26B
|
||||
- Pro/OpenGVLab/InternVL2-8B
|
||||
|
||||
@ -29,6 +29,9 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
# {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
|
||||
if "response_format" in model_parameters:
|
||||
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
|
||||
@ -1,51 +0,0 @@
|
||||
model: Qwen/Qwen2.5-72B-Instruct-128K
|
||||
label:
|
||||
en_US: Qwen/Qwen2.5-72B-Instruct-128K
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
|
||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '4.13'
|
||||
output: '4.13'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
||||
@ -1,51 +0,0 @@
|
||||
model: Vendor-A/Qwen/Qwen2.5-72B-Instruct
|
||||
label:
|
||||
en_US: Vendor-A/Qwen/Qwen2.5-72B-Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
|
||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '1.00'
|
||||
output: '1.00'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
||||
@ -15,7 +15,7 @@ parameter_rules:
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
max: 4096
|
||||
max: 8192
|
||||
help:
|
||||
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
|
||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||
|
||||
@ -1,37 +0,0 @@
|
||||
model: fishaudio/fish-speech-1.5
|
||||
model_type: tts
|
||||
model_properties:
|
||||
default_voice: 'fishaudio/fish-speech-1.5:alex'
|
||||
voices:
|
||||
- mode: "fishaudio/fish-speech-1.5:alex"
|
||||
name: "Alex(男声)"
|
||||
language: [ "zh-Hans", "en-US" ]
|
||||
- mode: "fishaudio/fish-speech-1.5:benjamin"
|
||||
name: "Benjamin(男声)"
|
||||
language: [ "zh-Hans", "en-US" ]
|
||||
- mode: "fishaudio/fish-speech-1.5:charles"
|
||||
name: "Charles(男声)"
|
||||
language: [ "zh-Hans", "en-US" ]
|
||||
- mode: "fishaudio/fish-speech-1.5:david"
|
||||
name: "David(男声)"
|
||||
language: [ "zh-Hans", "en-US" ]
|
||||
- mode: "fishaudio/fish-speech-1.5:anna"
|
||||
name: "Anna(女声)"
|
||||
language: [ "zh-Hans", "en-US" ]
|
||||
- mode: "fishaudio/fish-speech-1.5:bella"
|
||||
name: "Bella(女声)"
|
||||
language: [ "zh-Hans", "en-US" ]
|
||||
- mode: "fishaudio/fish-speech-1.5:claire"
|
||||
name: "Claire(女声)"
|
||||
language: [ "zh-Hans", "en-US" ]
|
||||
- mode: "fishaudio/fish-speech-1.5:diana"
|
||||
name: "Diana(女声)"
|
||||
language: [ "zh-Hans", "en-US" ]
|
||||
audio_type: 'mp3'
|
||||
max_workers: 5
|
||||
# stream: false
|
||||
pricing:
|
||||
input: '0.015'
|
||||
output: '0'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@ -21,7 +21,7 @@ class SparkLLMClient:
|
||||
domain = api_domain
|
||||
|
||||
model_api_configs = {
|
||||
"spark-lite": {"version": "v1.1", "chat_domain": "lite"},
|
||||
"spark-lite": {"version": "v1.1", "chat_domain": "general"},
|
||||
"spark-pro": {"version": "v3.1", "chat_domain": "generalv3"},
|
||||
"spark-pro-128k": {"version": "pro-128k", "chat_domain": "pro-128k"},
|
||||
"spark-max": {"version": "v3.5", "chat_domain": "generalv3.5"},
|
||||
|
||||
@ -257,7 +257,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
||||
for index, response in enumerate(responses):
|
||||
if response.status_code not in {200, HTTPStatus.OK}:
|
||||
raise ServiceUnavailableError(
|
||||
f"Failed to invoke model {model}, status code: {response.status_code}, message: {response.message}"
|
||||
f"Failed to invoke model {model}, status code: {response.status_code}, "
|
||||
f"message: {response.message}"
|
||||
)
|
||||
|
||||
resp_finish_reason = response.output.choices[0].finish_reason
|
||||
|
||||
@ -146,7 +146,7 @@ class TritonInferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
elif credentials["completion_type"] == "completion":
|
||||
completion_type = LLMMode.COMPLETION.value
|
||||
else:
|
||||
raise ValueError(f"completion_type {credentials['completion_type']} is not supported")
|
||||
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
|
||||
@ -18,93 +18,72 @@ class ModelConfig(BaseModel):
|
||||
|
||||
|
||||
configs: dict[str, ModelConfig] = {
|
||||
"Doubao-1.5-vision-pro-32k": ModelConfig(
|
||||
properties=ModelProperties(context_size=32768, max_tokens=12288, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.VISION],
|
||||
),
|
||||
"Doubao-1.5-pro-32k": ModelConfig(
|
||||
properties=ModelProperties(context_size=32768, max_tokens=12288, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT],
|
||||
),
|
||||
"Doubao-1.5-lite-32k": ModelConfig(
|
||||
properties=ModelProperties(context_size=32768, max_tokens=12288, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT],
|
||||
),
|
||||
"Doubao-1.5-pro-256k": ModelConfig(
|
||||
properties=ModelProperties(context_size=262144, max_tokens=12288, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT],
|
||||
),
|
||||
"Doubao-vision-pro-32k": ModelConfig(
|
||||
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.VISION],
|
||||
features=[ModelFeature.VISION],
|
||||
),
|
||||
"Doubao-vision-lite-32k": ModelConfig(
|
||||
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.VISION],
|
||||
features=[ModelFeature.VISION],
|
||||
),
|
||||
"Doubao-pro-4k": ModelConfig(
|
||||
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
|
||||
features=[ModelFeature.TOOL_CALL],
|
||||
),
|
||||
"Doubao-lite-4k": ModelConfig(
|
||||
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
|
||||
features=[ModelFeature.TOOL_CALL],
|
||||
),
|
||||
"Doubao-pro-32k": ModelConfig(
|
||||
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
|
||||
features=[ModelFeature.TOOL_CALL],
|
||||
),
|
||||
"Doubao-lite-32k": ModelConfig(
|
||||
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
|
||||
features=[ModelFeature.TOOL_CALL],
|
||||
),
|
||||
"Doubao-pro-256k": ModelConfig(
|
||||
properties=ModelProperties(context_size=262144, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT],
|
||||
features=[],
|
||||
),
|
||||
"Doubao-pro-128k": ModelConfig(
|
||||
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
|
||||
features=[ModelFeature.TOOL_CALL],
|
||||
),
|
||||
"Doubao-lite-128k": ModelConfig(
|
||||
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT],
|
||||
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), features=[]
|
||||
),
|
||||
"Skylark2-pro-4k": ModelConfig(
|
||||
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT],
|
||||
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), features=[]
|
||||
),
|
||||
"Llama3-8B": ModelConfig(
|
||||
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT],
|
||||
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[]
|
||||
),
|
||||
"Llama3-70B": ModelConfig(
|
||||
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT],
|
||||
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[]
|
||||
),
|
||||
"Moonshot-v1-8k": ModelConfig(
|
||||
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
|
||||
features=[ModelFeature.TOOL_CALL],
|
||||
),
|
||||
"Moonshot-v1-32k": ModelConfig(
|
||||
properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
|
||||
features=[ModelFeature.TOOL_CALL],
|
||||
),
|
||||
"Moonshot-v1-128k": ModelConfig(
|
||||
properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
|
||||
features=[ModelFeature.TOOL_CALL],
|
||||
),
|
||||
"GLM3-130B": ModelConfig(
|
||||
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
|
||||
features=[ModelFeature.TOOL_CALL],
|
||||
),
|
||||
"GLM3-130B-Fin": ModelConfig(
|
||||
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
|
||||
features=[ModelFeature.TOOL_CALL],
|
||||
),
|
||||
"Mistral-7B": ModelConfig(
|
||||
properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT],
|
||||
properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT), features=[]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@ -118,30 +118,6 @@ model_credential_schema:
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- label:
|
||||
en_US: Doubao-1.5-vision-pro-32k
|
||||
value: Doubao-1.5-vision-pro-32k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: Doubao-1.5-pro-32k
|
||||
value: Doubao-1.5-pro-32k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: Doubao-1.5-lite-32k
|
||||
value: Doubao-1.5-lite-32k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: Doubao-1.5-pro-256k
|
||||
value: Doubao-1.5-pro-256k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: Doubao-vision-pro-32k
|
||||
value: Doubao-vision-pro-32k
|
||||
|
||||
@ -41,15 +41,15 @@ class BaiduAccessToken:
|
||||
resp = response.json()
|
||||
if "error" in resp:
|
||||
if resp["error"] == "invalid_client":
|
||||
raise InvalidAPIKeyError(f"Invalid API key or secret key: {resp['error_description']}")
|
||||
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
|
||||
elif resp["error"] == "unknown_error":
|
||||
raise InternalServerError(f"Internal server error: {resp['error_description']}")
|
||||
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
|
||||
elif resp["error"] == "invalid_request":
|
||||
raise BadRequestError(f"Bad request: {resp['error_description']}")
|
||||
raise BadRequestError(f'Bad request: {resp["error_description"]}')
|
||||
elif resp["error"] == "rate_limit_exceeded":
|
||||
raise RateLimitReachedError(f"Rate limit reached: {resp['error_description']}")
|
||||
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
|
||||
else:
|
||||
raise Exception(f"Unknown error: {resp['error_description']}")
|
||||
raise Exception(f'Unknown error: {resp["error_description"]}')
|
||||
|
||||
return resp["access_token"]
|
||||
|
||||
|
||||
@ -406,7 +406,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
elif credentials["completion_type"] == "completion":
|
||||
completion_type = LLMMode.COMPLETION.value
|
||||
else:
|
||||
raise ValueError(f"completion_type {credentials['completion_type']} is not supported")
|
||||
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
|
||||
else:
|
||||
extra_args = XinferenceHelper.get_xinference_extra_parameter(
|
||||
server_url=credentials["server_url"],
|
||||
@ -472,7 +472,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
api_key = credentials.get("api_key") or "abc"
|
||||
|
||||
client = OpenAI(
|
||||
base_url=f"{credentials['server_url']}/v1",
|
||||
base_url=f'{credentials["server_url"]}/v1',
|
||||
api_key=api_key,
|
||||
max_retries=int(credentials.get("max_retries") or DEFAULT_MAX_RETRIES),
|
||||
timeout=int(credentials.get("invoke_timeout") or DEFAULT_INVOKE_TIMEOUT),
|
||||
|
||||
@ -87,6 +87,6 @@ class CommonValidator:
|
||||
if value.lower() not in {"true", "false"}:
|
||||
raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")
|
||||
|
||||
value = value.lower() == "true"
|
||||
value = True if value.lower() == "true" else False
|
||||
|
||||
return value
|
||||
|
||||
@ -6,7 +6,6 @@ from pydantic import BaseModel, ValidationInfo, field_validator
|
||||
class TracingProviderEnum(Enum):
|
||||
LANGFUSE = "langfuse"
|
||||
LANGSMITH = "langsmith"
|
||||
OPIK = "opik"
|
||||
|
||||
|
||||
class BaseTracingConfig(BaseModel):
|
||||
@ -57,36 +56,5 @@ class LangSmithConfig(BaseTracingConfig):
|
||||
return v
|
||||
|
||||
|
||||
class OpikConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Opik tracing config.
|
||||
"""
|
||||
|
||||
api_key: str | None = None
|
||||
project: str | None = None
|
||||
workspace: str | None = None
|
||||
url: str = "https://www.comet.com/opik/api/"
|
||||
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = "Default Project"
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def url_validator(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = "https://www.comet.com/opik/api/"
|
||||
if not v.startswith(("https://", "http://")):
|
||||
raise ValueError("url must start with https:// or http://")
|
||||
if not v.endswith("/api/"):
|
||||
raise ValueError("url should ends with /api/")
|
||||
|
||||
return v
|
||||
|
||||
|
||||
OPS_FILE_PATH = "ops_trace/"
|
||||
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
|
||||
|
||||
@ -1,469 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, cast
|
||||
|
||||
from opik import Opik, Trace
|
||||
from opik.id_helpers import uuid4_to_uuid7
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import OpikConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
from models.workflow import WorkflowNodeExecution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def wrap_dict(key_name, data):
|
||||
"""Make sure that the input data is a dict"""
|
||||
if not isinstance(data, dict):
|
||||
return {key_name: data}
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def wrap_metadata(metadata, **kwargs):
|
||||
"""Add common metatada to all Traces and Spans"""
|
||||
metadata["created_from"] = "dify"
|
||||
|
||||
metadata.update(kwargs)
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def prepare_opik_uuid(user_datetime: Optional[datetime], user_uuid: Optional[str]):
|
||||
"""Opik needs UUIDv7 while Dify uses UUIDv4 for identifier of most
|
||||
messages and objects. The type-hints of BaseTraceInfo indicates that
|
||||
objects start_time and message_id could be null which means we cannot map
|
||||
it to a UUIDv7. Given that we have no way to identify that object
|
||||
uniquely, generate a new random one UUIDv7 in that case.
|
||||
"""
|
||||
|
||||
if user_datetime is None:
|
||||
user_datetime = datetime.now()
|
||||
|
||||
if user_uuid is None:
|
||||
user_uuid = str(uuid.uuid4())
|
||||
|
||||
return uuid4_to_uuid7(user_datetime, user_uuid)
|
||||
|
||||
|
||||
class OpikDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
opik_config: OpikConfig,
|
||||
):
|
||||
super().__init__(opik_config)
|
||||
self.opik_client = Opik(
|
||||
project_name=opik_config.project,
|
||||
workspace=opik_config.workspace,
|
||||
host=opik_config.url,
|
||||
api_key=opik_config.api_key,
|
||||
)
|
||||
self.project = opik_config.project
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
if isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
if isinstance(trace_info, ModerationTraceInfo):
|
||||
self.moderation_trace(trace_info)
|
||||
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
if isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
dify_trace_id = trace_info.workflow_run_id
|
||||
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
|
||||
workflow_metadata = wrap_metadata(
|
||||
trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
|
||||
)
|
||||
root_span_id = None
|
||||
|
||||
if trace_info.message_id:
|
||||
dify_trace_id = trace_info.message_id
|
||||
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
|
||||
|
||||
trace_data = {
|
||||
"id": opik_trace_id,
|
||||
"name": TraceTaskName.MESSAGE_TRACE.value,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": workflow_metadata,
|
||||
"input": wrap_dict("input", trace_info.workflow_run_inputs),
|
||||
"output": wrap_dict("output", trace_info.workflow_run_outputs),
|
||||
"tags": ["message", "workflow"],
|
||||
"project_name": self.project,
|
||||
}
|
||||
self.add_trace(trace_data)
|
||||
|
||||
root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
|
||||
span_data = {
|
||||
"id": root_span_id,
|
||||
"parent_span_id": None,
|
||||
"trace_id": opik_trace_id,
|
||||
"name": TraceTaskName.WORKFLOW_TRACE.value,
|
||||
"input": wrap_dict("input", trace_info.workflow_run_inputs),
|
||||
"output": wrap_dict("output", trace_info.workflow_run_outputs),
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": workflow_metadata,
|
||||
"tags": ["workflow"],
|
||||
"project_name": self.project,
|
||||
}
|
||||
self.add_span(span_data)
|
||||
else:
|
||||
trace_data = {
|
||||
"id": opik_trace_id,
|
||||
"name": TraceTaskName.MESSAGE_TRACE.value,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": workflow_metadata,
|
||||
"input": wrap_dict("input", trace_info.workflow_run_inputs),
|
||||
"output": wrap_dict("output", trace_info.workflow_run_outputs),
|
||||
"tags": ["workflow"],
|
||||
"project_name": self.project,
|
||||
}
|
||||
self.add_trace(trace_data)
|
||||
|
||||
# through workflow_run_id get all_nodes_execution
|
||||
workflow_nodes_execution_id_records = (
|
||||
db.session.query(WorkflowNodeExecution.id)
|
||||
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
for node_execution_id_record in workflow_nodes_execution_id_records:
|
||||
node_execution = (
|
||||
db.session.query(
|
||||
WorkflowNodeExecution.id,
|
||||
WorkflowNodeExecution.tenant_id,
|
||||
WorkflowNodeExecution.app_id,
|
||||
WorkflowNodeExecution.title,
|
||||
WorkflowNodeExecution.node_type,
|
||||
WorkflowNodeExecution.status,
|
||||
WorkflowNodeExecution.inputs,
|
||||
WorkflowNodeExecution.outputs,
|
||||
WorkflowNodeExecution.created_at,
|
||||
WorkflowNodeExecution.elapsed_time,
|
||||
WorkflowNodeExecution.process_data,
|
||||
WorkflowNodeExecution.execution_metadata,
|
||||
)
|
||||
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not node_execution:
|
||||
continue
|
||||
|
||||
node_execution_id = node_execution.id
|
||||
tenant_id = node_execution.tenant_id
|
||||
app_id = node_execution.app_id
|
||||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == "llm":
|
||||
inputs = (
|
||||
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
|
||||
)
|
||||
else:
|
||||
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
||||
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = (
|
||||
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
|
||||
)
|
||||
metadata = execution_metadata.copy()
|
||||
metadata.update(
|
||||
{
|
||||
"workflow_run_id": trace_info.workflow_run_id,
|
||||
"node_execution_id": node_execution_id,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": app_id,
|
||||
"app_name": node_name,
|
||||
"node_type": node_type,
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
|
||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
||||
|
||||
provider = None
|
||||
model = None
|
||||
total_tokens = 0
|
||||
completion_tokens = 0
|
||||
prompt_tokens = 0
|
||||
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
run_type = "llm"
|
||||
provider = process_data.get("model_provider", None)
|
||||
model = process_data.get("model_name", "")
|
||||
metadata.update(
|
||||
{
|
||||
"ls_provider": provider,
|
||||
"ls_model_name": model,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
if outputs.get("usage"):
|
||||
total_tokens = outputs["usage"].get("total_tokens", 0)
|
||||
prompt_tokens = outputs["usage"].get("prompt_tokens", 0)
|
||||
completion_tokens = outputs["usage"].get("completion_tokens", 0)
|
||||
except Exception:
|
||||
logger.error("Failed to extract usage", exc_info=True)
|
||||
|
||||
else:
|
||||
run_type = "tool"
|
||||
|
||||
parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
|
||||
|
||||
if not total_tokens:
|
||||
total_tokens = execution_metadata.get("total_tokens", 0)
|
||||
|
||||
span_data = {
|
||||
"trace_id": opik_trace_id,
|
||||
"id": prepare_opik_uuid(created_at, node_execution_id),
|
||||
"parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id),
|
||||
"name": node_type,
|
||||
"type": run_type,
|
||||
"start_time": created_at,
|
||||
"end_time": finished_at,
|
||||
"metadata": wrap_metadata(metadata),
|
||||
"input": wrap_dict("input", inputs),
|
||||
"output": wrap_dict("output", outputs),
|
||||
"tags": ["node_execution"],
|
||||
"project_name": self.project,
|
||||
"usage": {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
},
|
||||
"model": model,
|
||||
"provider": provider,
|
||||
}
|
||||
|
||||
self.add_span(span_data)
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
# get message file data
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
message_file_data: Optional[MessageFile] = trace_info.message_file_data
|
||||
|
||||
if message_file_data is not None:
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
file_list.append(file_url)
|
||||
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
|
||||
metadata = trace_info.metadata
|
||||
message_id = trace_info.message_id
|
||||
|
||||
user_id = message_data.from_account_id
|
||||
metadata["user_id"] = user_id
|
||||
metadata["file_list"] = file_list
|
||||
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
end_user_id = end_user_data.session_id
|
||||
metadata["end_user_id"] = end_user_id
|
||||
|
||||
trace_data = {
|
||||
"id": prepare_opik_uuid(trace_info.start_time, message_id),
|
||||
"name": TraceTaskName.MESSAGE_TRACE.value,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(metadata),
|
||||
"input": trace_info.inputs,
|
||||
"output": message_data.answer,
|
||||
"tags": ["message", str(trace_info.conversation_mode)],
|
||||
"project_name": self.project,
|
||||
}
|
||||
trace = self.add_trace(trace_data)
|
||||
|
||||
span_data = {
|
||||
"trace_id": trace.id,
|
||||
"name": "llm",
|
||||
"type": "llm",
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(metadata),
|
||||
"input": {"input": trace_info.inputs},
|
||||
"output": {"output": message_data.answer},
|
||||
"tags": ["llm", str(trace_info.conversation_mode)],
|
||||
"usage": {
|
||||
"completion_tokens": trace_info.answer_tokens,
|
||||
"prompt_tokens": trace_info.message_tokens,
|
||||
"total_tokens": trace_info.total_tokens,
|
||||
},
|
||||
"project_name": self.project,
|
||||
}
|
||||
self.add_span(span_data)
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
|
||||
"name": TraceTaskName.MODERATION_TRACE.value,
|
||||
"type": "tool",
|
||||
"start_time": start_time,
|
||||
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
"input": wrap_dict("input", trace_info.inputs),
|
||||
"output": {
|
||||
"action": trace_info.action,
|
||||
"flagged": trace_info.flagged,
|
||||
"preset_response": trace_info.preset_response,
|
||||
"inputs": trace_info.inputs,
|
||||
},
|
||||
"tags": ["moderation"],
|
||||
}
|
||||
|
||||
self.add_span(span_data)
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or message_data.created_at
|
||||
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
|
||||
"name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
"type": "tool",
|
||||
"start_time": start_time,
|
||||
"end_time": trace_info.end_time or message_data.updated_at,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
"input": wrap_dict("input", trace_info.inputs),
|
||||
"output": wrap_dict("output", trace_info.suggested_question),
|
||||
"tags": ["suggested_question"],
|
||||
}
|
||||
|
||||
self.add_span(span_data)
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
|
||||
"name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
"type": "tool",
|
||||
"start_time": start_time,
|
||||
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
"input": wrap_dict("input", trace_info.inputs),
|
||||
"output": {"documents": trace_info.documents},
|
||||
"tags": ["dataset_retrieval"],
|
||||
}
|
||||
|
||||
self.add_span(span_data)
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
|
||||
"name": trace_info.tool_name,
|
||||
"type": "tool",
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
"input": wrap_dict("input", trace_info.tool_inputs),
|
||||
"output": wrap_dict("output", trace_info.tool_outputs),
|
||||
"tags": ["tool", trace_info.tool_name],
|
||||
}
|
||||
|
||||
self.add_span(span_data)
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
trace_data = {
|
||||
"id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
|
||||
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
"input": trace_info.inputs,
|
||||
"output": trace_info.outputs,
|
||||
"tags": ["generate_name"],
|
||||
"project_name": self.project,
|
||||
}
|
||||
|
||||
trace = self.add_trace(trace_data)
|
||||
|
||||
span_data = {
|
||||
"trace_id": trace.id,
|
||||
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
"input": wrap_dict("input", trace_info.inputs),
|
||||
"output": wrap_dict("output", trace_info.outputs),
|
||||
"tags": ["generate_name"],
|
||||
}
|
||||
|
||||
self.add_span(span_data)
|
||||
|
||||
def add_trace(self, opik_trace_data: dict) -> Trace:
|
||||
try:
|
||||
trace = self.opik_client.trace(**opik_trace_data)
|
||||
logger.debug("Opik Trace created successfully")
|
||||
return trace
|
||||
except Exception as e:
|
||||
raise ValueError(f"Opik Failed to create trace: {str(e)}")
|
||||
|
||||
def add_span(self, opik_span_data: dict):
|
||||
try:
|
||||
self.opik_client.span(**opik_span_data)
|
||||
logger.debug("Opik Span created successfully")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Opik Failed to create span: {str(e)}")
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
self.opik_client.auth_check()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.info(f"Opik API check failed: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Opik API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
try:
|
||||
return self.opik_client.get_project_url(project_name=self.project)
|
||||
except Exception as e:
|
||||
logger.info(f"Opik get run url failed: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Opik get run url failed: {str(e)}")
|
||||
@ -17,7 +17,6 @@ from core.ops.entities.config_entity import (
|
||||
OPS_FILE_PATH,
|
||||
LangfuseConfig,
|
||||
LangSmithConfig,
|
||||
OpikConfig,
|
||||
TracingProviderEnum,
|
||||
)
|
||||
from core.ops.entities.trace_entity import (
|
||||
@ -33,7 +32,6 @@ from core.ops.entities.trace_entity import (
|
||||
)
|
||||
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
|
||||
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
|
||||
from core.ops.opik_trace.opik_trace import OpikDataTrace
|
||||
from core.ops.utils import get_message_data
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
@ -54,12 +52,6 @@ provider_config_map: dict[str, dict[str, Any]] = {
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": LangSmithDataTrace,
|
||||
},
|
||||
TracingProviderEnum.OPIK.value: {
|
||||
"config_class": OpikConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "url", "workspace"],
|
||||
"trace_instance": OpikDataTrace,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -22,12 +22,7 @@ from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
CredentialFormSchema,
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from extensions import ext_hosting_provider
|
||||
from extensions.ext_database import db
|
||||
@ -840,18 +835,11 @@ class ProviderManager:
|
||||
:return:
|
||||
"""
|
||||
# Get provider model credential secret variables
|
||||
if ConfigurateMethod.PREDEFINED_MODEL in provider_entity.configurate_methods:
|
||||
model_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_entity.provider_credential_schema.credential_form_schemas
|
||||
if provider_entity.provider_credential_schema
|
||||
else []
|
||||
)
|
||||
else:
|
||||
model_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_entity.model_credential_schema.credential_form_schemas
|
||||
if provider_entity.model_credential_schema
|
||||
else []
|
||||
)
|
||||
model_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_entity.model_credential_schema.credential_form_schemas
|
||||
if provider_entity.model_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
model_settings: list[ModelSettings] = []
|
||||
if not provider_model_settings:
|
||||
|
||||
@ -1,104 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
|
||||
ElasticSearchConfig,
|
||||
ElasticSearchVector,
|
||||
ElasticSearchVectorFactory,
|
||||
)
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElasticSearchJaVector(ElasticSearchVector):
|
||||
def create_collection(
|
||||
self,
|
||||
embeddings: list[list[float]],
|
||||
metadatas: Optional[list[dict[Any, Any]]] = None,
|
||||
index_params: Optional[dict] = None,
|
||||
):
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
logger.info(f"Collection {self._collection_name} already exists.")
|
||||
return
|
||||
|
||||
if not self._client.indices.exists(index=self._collection_name):
|
||||
dim = len(embeddings[0])
|
||||
settings = {
|
||||
"analysis": {
|
||||
"analyzer": {
|
||||
"ja_analyzer": {
|
||||
"type": "custom",
|
||||
"char_filter": [
|
||||
"icu_normalizer",
|
||||
"kuromoji_iteration_mark",
|
||||
],
|
||||
"tokenizer": "kuromoji_tokenizer",
|
||||
"filter": [
|
||||
"kuromoji_baseform",
|
||||
"kuromoji_part_of_speech",
|
||||
"ja_stop",
|
||||
"kuromoji_number",
|
||||
"kuromoji_stemmer",
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
mappings = {
|
||||
"properties": {
|
||||
Field.CONTENT_KEY.value: {
|
||||
"type": "text",
|
||||
"analyzer": "ja_analyzer",
|
||||
"search_analyzer": "ja_analyzer",
|
||||
},
|
||||
Field.VECTOR.value: { # Make sure the dimension is correct here
|
||||
"type": "dense_vector",
|
||||
"dims": dim,
|
||||
"index": True,
|
||||
"similarity": "cosine",
|
||||
},
|
||||
Field.METADATA_KEY.value: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings)
|
||||
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
|
||||
class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return ElasticSearchJaVector(
|
||||
index_name=collection_name,
|
||||
config=ElasticSearchConfig(
|
||||
host=config.get("ELASTICSEARCH_HOST", "localhost"),
|
||||
port=config.get("ELASTICSEARCH_PORT", 9200),
|
||||
username=config.get("ELASTICSEARCH_USERNAME", ""),
|
||||
password=config.get("ELASTICSEARCH_PASSWORD", ""),
|
||||
),
|
||||
attributes=[],
|
||||
)
|
||||
@ -6,8 +6,6 @@ class Field(Enum):
|
||||
METADATA_KEY = "metadata"
|
||||
GROUP_KEY = "group_id"
|
||||
VECTOR = "vector"
|
||||
# Sparse Vector aims to support full text search
|
||||
SPARSE_VECTOR = "sparse_vector"
|
||||
TEXT_KEY = "text"
|
||||
PRIMARY_KEY = "id"
|
||||
DOC_ID = "metadata.doc_id"
|
||||
|
||||
@ -258,7 +258,7 @@ class LindormVectorStore(BaseVector):
|
||||
hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
|
||||
ivfpq_m = kwargs.pop("ivfpq_m", dimension)
|
||||
nlist = kwargs.pop("nlist", 1000)
|
||||
centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", nlist >= 5000)
|
||||
centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False)
|
||||
centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
|
||||
centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
|
||||
centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
|
||||
@ -305,7 +305,7 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dic
|
||||
if method_name == "ivfpq":
|
||||
ivfpq_m = kwargs["ivfpq_m"]
|
||||
nlist = kwargs["nlist"]
|
||||
centroids_use_hnsw = nlist > 10000
|
||||
centroids_use_hnsw = True if nlist > 10000 else False
|
||||
centroids_hnsw_m = 24
|
||||
centroids_hnsw_ef_construct = 500
|
||||
centroids_hnsw_ef_search = 100
|
||||
|
||||
@ -2,7 +2,6 @@ import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pymilvus import MilvusClient, MilvusException # type: ignore
|
||||
from pymilvus.milvus_client import IndexParams # type: ignore
|
||||
@ -21,25 +20,16 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MilvusConfig(BaseModel):
|
||||
"""
|
||||
Configuration class for Milvus connection.
|
||||
"""
|
||||
|
||||
uri: str # Milvus server URI
|
||||
token: Optional[str] = None # Optional token for authentication
|
||||
user: str # Username for authentication
|
||||
password: str # Password for authentication
|
||||
batch_size: int = 100 # Batch size for operations
|
||||
database: str = "default" # Database name
|
||||
enable_hybrid_search: bool = False # Flag to enable hybrid search
|
||||
uri: str
|
||||
token: Optional[str] = None
|
||||
user: str
|
||||
password: str
|
||||
batch_size: int = 100
|
||||
database: str = "default"
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
"""
|
||||
Validate the configuration values.
|
||||
Raises ValueError if required fields are missing.
|
||||
"""
|
||||
if not values.get("uri"):
|
||||
raise ValueError("config MILVUS_URI is required")
|
||||
if not values.get("user"):
|
||||
@ -49,9 +39,6 @@ class MilvusConfig(BaseModel):
|
||||
return values
|
||||
|
||||
def to_milvus_params(self):
|
||||
"""
|
||||
Convert the configuration to a dictionary of Milvus connection parameters.
|
||||
"""
|
||||
return {
|
||||
"uri": self.uri,
|
||||
"token": self.token,
|
||||
@ -62,57 +49,26 @@ class MilvusConfig(BaseModel):
|
||||
|
||||
|
||||
class MilvusVector(BaseVector):
|
||||
"""
|
||||
Milvus vector storage implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, collection_name: str, config: MilvusConfig):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = self._init_client(config)
|
||||
self._consistency_level = "Session" # Consistency level for Milvus operations
|
||||
self._fields: list[str] = [] # List of fields in the collection
|
||||
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
|
||||
|
||||
def _check_hybrid_search_support(self) -> bool:
|
||||
"""
|
||||
Check if the current Milvus version supports hybrid search.
|
||||
Returns True if the version is >= 2.5.0, otherwise False.
|
||||
"""
|
||||
if not self._client_config.enable_hybrid_search:
|
||||
return False
|
||||
|
||||
try:
|
||||
milvus_version = self._client.get_server_version()
|
||||
return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
|
||||
return False
|
||||
self._consistency_level = "Session"
|
||||
self._fields: list[str] = []
|
||||
|
||||
def get_type(self) -> str:
|
||||
"""
|
||||
Get the type of vector storage (Milvus).
|
||||
"""
|
||||
return VectorType.MILVUS
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""
|
||||
Create a collection and add texts with embeddings.
|
||||
"""
|
||||
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
|
||||
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
|
||||
self.create_collection(embeddings, metadatas, index_params)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""
|
||||
Add texts and their embeddings to the collection.
|
||||
"""
|
||||
insert_dict_list = []
|
||||
for i in range(len(documents)):
|
||||
insert_dict = {
|
||||
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
|
||||
# function will automatically convert the native text into a sparse vector for us.
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i],
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
@ -120,11 +76,12 @@ class MilvusVector(BaseVector):
|
||||
insert_dict_list.append(insert_dict)
|
||||
# Total insert count
|
||||
total_count = len(insert_dict_list)
|
||||
|
||||
pks: list[str] = []
|
||||
|
||||
for i in range(0, total_count, 1000):
|
||||
# Insert into the collection.
|
||||
batch_insert_list = insert_dict_list[i : i + 1000]
|
||||
# Insert into the collection.
|
||||
try:
|
||||
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
|
||||
pks.extend(ids)
|
||||
@ -134,9 +91,6 @@ class MilvusVector(BaseVector):
|
||||
return pks
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
"""
|
||||
Get document IDs by metadata field key and value.
|
||||
"""
|
||||
result = self._client.query(
|
||||
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
|
||||
)
|
||||
@ -146,18 +100,12 @@ class MilvusVector(BaseVector):
|
||||
return None
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
"""
|
||||
Delete documents by metadata field key and value.
|
||||
"""
|
||||
if self._client.has_collection(self._collection_name):
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
"""
|
||||
Delete documents by their IDs.
|
||||
"""
|
||||
if self._client.has_collection(self._collection_name):
|
||||
result = self._client.query(
|
||||
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
|
||||
@ -167,16 +115,10 @@ class MilvusVector(BaseVector):
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
|
||||
def delete(self) -> None:
|
||||
"""
|
||||
Delete the entire collection.
|
||||
"""
|
||||
if self._client.has_collection(self._collection_name):
|
||||
self._client.drop_collection(self._collection_name, None)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
"""
|
||||
Check if a text with the given ID exists in the collection.
|
||||
"""
|
||||
if not self._client.has_collection(self._collection_name):
|
||||
return False
|
||||
|
||||
@ -186,80 +128,32 @@ class MilvusVector(BaseVector):
|
||||
|
||||
return len(result) > 0
|
||||
|
||||
def field_exists(self, field: str) -> bool:
|
||||
"""
|
||||
Check if a field exists in the collection.
|
||||
"""
|
||||
return field in self._fields
|
||||
|
||||
def _process_search_results(
|
||||
self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Common method to process search results
|
||||
|
||||
:param results: Search results
|
||||
:param output_fields: Fields to be output
|
||||
:param score_threshold: Score threshold for filtering
|
||||
:return: List of documents
|
||||
"""
|
||||
docs = []
|
||||
for result in results[0]:
|
||||
metadata = result["entity"].get(output_fields[1], {})
|
||||
metadata["score"] = result["distance"]
|
||||
|
||||
if result["distance"] > score_threshold:
|
||||
doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Search for documents by vector similarity.
|
||||
"""
|
||||
# Set search parameters.
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
data=[query_vector],
|
||||
anns_field=Field.VECTOR.value,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
)
|
||||
|
||||
return self._process_search_results(
|
||||
results,
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||
)
|
||||
# Organize results.
|
||||
docs = []
|
||||
for result in results[0]:
|
||||
metadata = result["entity"].get(Field.METADATA_KEY.value)
|
||||
metadata["score"] = result["distance"]
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if result["distance"] > score_threshold:
|
||||
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Search for documents by full-text search (if hybrid search is enabled).
|
||||
"""
|
||||
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
|
||||
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
|
||||
return []
|
||||
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
data=[query],
|
||||
anns_field=Field.SPARSE_VECTOR.value,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
)
|
||||
|
||||
return self._process_search_results(
|
||||
results,
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||
)
|
||||
# milvus/zilliz doesn't support bm25 search
|
||||
return []
|
||||
|
||||
def create_collection(
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
Create a new collection in Milvus with the specified schema and index parameters.
|
||||
"""
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
@ -267,7 +161,7 @@ class MilvusVector(BaseVector):
|
||||
return
|
||||
# Grab the existing collection if it exists
|
||||
if not self._client.has_collection(self._collection_name):
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
|
||||
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
|
||||
|
||||
# Determine embedding dim
|
||||
@ -276,36 +170,16 @@ class MilvusVector(BaseVector):
|
||||
if metadatas:
|
||||
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
|
||||
|
||||
# Create the text field, enable_analyzer will be set True to support milvus automatically
|
||||
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
|
||||
fields.append(
|
||||
FieldSchema(
|
||||
Field.CONTENT_KEY.value,
|
||||
DataType.VARCHAR,
|
||||
max_length=65_535,
|
||||
enable_analyzer=self._hybrid_search_enabled,
|
||||
)
|
||||
)
|
||||
# Create the text field
|
||||
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
|
||||
# Create the primary key field
|
||||
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
|
||||
# Create the vector field, supports binary or float vectors
|
||||
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
|
||||
# Create Sparse Vector Index for the collection
|
||||
if self._hybrid_search_enabled:
|
||||
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
|
||||
|
||||
# Create the schema for the collection
|
||||
schema = CollectionSchema(fields)
|
||||
|
||||
# Create custom function to support text to sparse vector by BM25
|
||||
if self._hybrid_search_enabled:
|
||||
bm25_function = Function(
|
||||
name="text_bm25_emb",
|
||||
input_field_names=[Field.CONTENT_KEY.value],
|
||||
output_field_names=[Field.SPARSE_VECTOR.value],
|
||||
function_type=FunctionType.BM25,
|
||||
)
|
||||
schema.add_function(bm25_function)
|
||||
|
||||
for x in schema.fields:
|
||||
self._fields.append(x.name)
|
||||
# Since primary field is auto-id, no need to track it
|
||||
@ -315,15 +189,10 @@ class MilvusVector(BaseVector):
|
||||
index_params_obj = IndexParams()
|
||||
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
|
||||
|
||||
# Create Sparse Vector Index for the collection
|
||||
if self._hybrid_search_enabled:
|
||||
index_params_obj.add_index(
|
||||
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
|
||||
)
|
||||
|
||||
# Create the collection
|
||||
collection_name = self._collection_name
|
||||
self._client.create_collection(
|
||||
collection_name=self._collection_name,
|
||||
collection_name=collection_name,
|
||||
schema=schema,
|
||||
index_params=index_params_obj,
|
||||
consistency_level=self._consistency_level,
|
||||
@ -331,22 +200,12 @@ class MilvusVector(BaseVector):
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def _init_client(self, config) -> MilvusClient:
|
||||
"""
|
||||
Initialize and return a Milvus client.
|
||||
"""
|
||||
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
|
||||
return client
|
||||
|
||||
|
||||
class MilvusVectorFactory(AbstractVectorFactory):
|
||||
"""
|
||||
Factory class for creating MilvusVector instances.
|
||||
"""
|
||||
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
|
||||
"""
|
||||
Initialize a MilvusVector instance for the given dataset.
|
||||
"""
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
@ -363,6 +222,5 @@ class MilvusVectorFactory(AbstractVectorFactory):
|
||||
user=dify_config.MILVUS_USER or "",
|
||||
password=dify_config.MILVUS_PASSWORD or "",
|
||||
database=dify_config.MILVUS_DATABASE or "",
|
||||
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
|
||||
),
|
||||
)
|
||||
|
||||
@ -409,27 +409,27 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
|
||||
)
|
||||
if not tidb_auth_binding:
|
||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||
tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if tidb_auth_binding:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
idle_tidb_auth_binding = (
|
||||
idle_tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
if idle_tidb_auth_binding:
|
||||
idle_tidb_auth_binding.active = True
|
||||
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
||||
db.session.commit()
|
||||
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
||||
else:
|
||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||
tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||
.limit(1)
|
||||
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if idle_tidb_auth_binding:
|
||||
idle_tidb_auth_binding.active = True
|
||||
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
||||
db.session.commit()
|
||||
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
||||
if tidb_auth_binding:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
new_cluster = TidbService.create_tidb_serverless_cluster(
|
||||
dify_config.TIDB_PROJECT_ID or "",
|
||||
@ -451,6 +451,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
db.session.add(new_tidb_auth_binding)
|
||||
db.session.commit()
|
||||
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
|
||||
@ -90,12 +90,6 @@ class Vector:
|
||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
|
||||
return ElasticSearchVectorFactory
|
||||
case VectorType.ELASTICSEARCH_JA:
|
||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
|
||||
ElasticSearchJaVectorFactory,
|
||||
)
|
||||
|
||||
return ElasticSearchJaVectorFactory
|
||||
case VectorType.TIDB_VECTOR:
|
||||
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
|
||||
|
||||
|
||||
@ -16,7 +16,6 @@ class VectorType(StrEnum):
|
||||
TENCENT = "tencent"
|
||||
ORACLE = "oracle"
|
||||
ELASTICSEARCH = "elasticsearch"
|
||||
ELASTICSEARCH_JA = "elasticsearch-ja"
|
||||
LINDORM = "lindorm"
|
||||
COUCHBASE = "couchbase"
|
||||
BAIDU = "baidu"
|
||||
|
||||
@ -31,7 +31,7 @@ class FirecrawlApp:
|
||||
"markdown": data.get("markdown"),
|
||||
}
|
||||
else:
|
||||
raise Exception(f"Failed to scrape URL. Error: {response_data['error']}")
|
||||
raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}')
|
||||
|
||||
elif response.status_code in {402, 409, 500}:
|
||||
error_message = response.json().get("error", "Unknown error occurred")
|
||||
|
||||
@ -358,7 +358,8 @@ class NotionExtractor(BaseExtractor):
|
||||
|
||||
if not data_source_binding:
|
||||
raise Exception(
|
||||
f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
|
||||
f"No notion data source binding found for tenant {tenant_id} "
|
||||
f"and notion workspace {notion_workspace_id}"
|
||||
)
|
||||
|
||||
return cast(str, data_source_binding.access_token)
|
||||
|
||||
@ -23,6 +23,7 @@ class PdfExtractor(BaseExtractor):
|
||||
self._file_cache_key = file_cache_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
plaintext_file_key = ""
|
||||
plaintext_file_exists = False
|
||||
if self._file_cache_key:
|
||||
try:
|
||||
@ -38,8 +39,8 @@ class PdfExtractor(BaseExtractor):
|
||||
text = "\n\n".join(text_list)
|
||||
|
||||
# save plaintext file for caching
|
||||
if not plaintext_file_exists and self._file_cache_key:
|
||||
storage.save(self._file_cache_key, text.encode("utf-8"))
|
||||
if not plaintext_file_exists and plaintext_file_key:
|
||||
storage.save(plaintext_file_key, text.encode("utf-8"))
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
@ -112,7 +112,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
df = pd.read_csv(file)
|
||||
text_docs = []
|
||||
for index, row in df.iterrows():
|
||||
data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]})
|
||||
data = Document(page_content=row[0], metadata={"answer": row[1]})
|
||||
text_docs.append(data)
|
||||
if len(text_docs) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
|
||||
@ -127,7 +127,7 @@ class AIPPTGenerateToolAdapter:
|
||||
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f"Failed to create task: {response.get('msg')}")
|
||||
raise Exception(f'Failed to create task: {response.get("msg")}')
|
||||
|
||||
return response.get("data", {}).get("id")
|
||||
|
||||
@ -222,7 +222,7 @@ class AIPPTGenerateToolAdapter:
|
||||
elif model == "wenxin":
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f"Failed to generate content: {response.get('msg')}")
|
||||
raise Exception(f'Failed to generate content: {response.get("msg")}')
|
||||
|
||||
return response.get("data", "")
|
||||
|
||||
@ -254,7 +254,7 @@ class AIPPTGenerateToolAdapter:
|
||||
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f"Failed to generate ppt: {response.get('msg')}")
|
||||
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
|
||||
|
||||
id = response.get("data", {}).get("id")
|
||||
cover_url = response.get("data", {}).get("cover_url")
|
||||
@ -270,7 +270,7 @@ class AIPPTGenerateToolAdapter:
|
||||
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f"Failed to generate ppt: {response.get('msg')}")
|
||||
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
|
||||
|
||||
export_code = response.get("data")
|
||||
if not export_code:
|
||||
@ -290,7 +290,7 @@ class AIPPTGenerateToolAdapter:
|
||||
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f"Failed to generate ppt: {response.get('msg')}")
|
||||
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
|
||||
|
||||
if response.get("msg") == "导出中":
|
||||
current_iteration += 1
|
||||
@ -343,7 +343,7 @@ class AIPPTGenerateToolAdapter:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f"Failed to connect to aippt: {response.get('msg')}")
|
||||
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
|
||||
|
||||
token = response.get("data", {}).get("token")
|
||||
expire = response.get("data", {}).get("time_expire")
|
||||
@ -379,7 +379,7 @@ class AIPPTGenerateToolAdapter:
|
||||
if cls._style_cache[key]["expire"] < now:
|
||||
del cls._style_cache[key]
|
||||
|
||||
key = f"{credentials['aippt_access_key']}#@#{user_id}"
|
||||
key = f'{credentials["aippt_access_key"]}#@#{user_id}'
|
||||
if key in cls._style_cache:
|
||||
return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"]
|
||||
|
||||
@ -396,11 +396,11 @@ class AIPPTGenerateToolAdapter:
|
||||
response = response.json()
|
||||
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f"Failed to connect to aippt: {response.get('msg')}")
|
||||
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
|
||||
|
||||
colors = [
|
||||
{
|
||||
"id": f"id-{item.get('id')}",
|
||||
"id": f'id-{item.get("id")}',
|
||||
"name": item.get("name"),
|
||||
"en_name": item.get("en_name", item.get("name")),
|
||||
}
|
||||
@ -408,7 +408,7 @@ class AIPPTGenerateToolAdapter:
|
||||
]
|
||||
styles = [
|
||||
{
|
||||
"id": f"id-{item.get('id')}",
|
||||
"id": f'id-{item.get("id")}',
|
||||
"name": item.get("title"),
|
||||
}
|
||||
for item in response.get("data", {}).get("suit_style") or []
|
||||
@ -454,7 +454,7 @@ class AIPPTGenerateToolAdapter:
|
||||
response = response.json()
|
||||
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f"Failed to connect to aippt: {response.get('msg')}")
|
||||
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
|
||||
|
||||
if len(response.get("data", {}).get("list") or []) > 0:
|
||||
return response.get("data", {}).get("list")[0].get("id")
|
||||
|
||||
@ -14,38 +14,14 @@ class BedrockRetrieveTool(BuiltinTool):
|
||||
topk: int = None
|
||||
|
||||
def _bedrock_retrieve(
|
||||
self,
|
||||
query_input: str,
|
||||
knowledge_base_id: str,
|
||||
num_results: int,
|
||||
search_type: str,
|
||||
rerank_model_id: str,
|
||||
metadata_filter: Optional[dict] = None,
|
||||
self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
|
||||
):
|
||||
try:
|
||||
retrieval_query = {"text": query_input}
|
||||
|
||||
if search_type not in ["HYBRID", "SEMANTIC"]:
|
||||
raise RuntimeException("search_type should be HYBRID or SEMANTIC")
|
||||
retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}
|
||||
|
||||
retrieval_configuration = {
|
||||
"vectorSearchConfiguration": {"numberOfResults": num_results, "overrideSearchType": search_type}
|
||||
}
|
||||
|
||||
if rerank_model_id != "default":
|
||||
model_for_rerank_arn = f"arn:aws:bedrock:us-west-2::foundation-model/{rerank_model_id}"
|
||||
rerankingConfiguration = {
|
||||
"bedrockRerankingConfiguration": {
|
||||
"numberOfRerankedResults": num_results,
|
||||
"modelConfiguration": {"modelArn": model_for_rerank_arn},
|
||||
},
|
||||
"type": "BEDROCK_RERANKING_MODEL",
|
||||
}
|
||||
|
||||
retrieval_configuration["vectorSearchConfiguration"]["rerankingConfiguration"] = rerankingConfiguration
|
||||
retrieval_configuration["vectorSearchConfiguration"]["numberOfResults"] = num_results * 5
|
||||
|
||||
# 如果有元数据过滤条件,则添加到检索配置中
|
||||
# Add metadata filter to retrieval configuration if present
|
||||
if metadata_filter:
|
||||
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter
|
||||
|
||||
@ -101,20 +77,15 @@ class BedrockRetrieveTool(BuiltinTool):
|
||||
if not query:
|
||||
return self.create_text_message("Please input query")
|
||||
|
||||
# 获取元数据过滤条件(如果存在)
|
||||
# Get metadata filter conditions (if they exist)
|
||||
metadata_filter_str = tool_parameters.get("metadata_filter")
|
||||
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None
|
||||
|
||||
search_type = tool_parameters.get("search_type")
|
||||
rerank_model_id = tool_parameters.get("rerank_model_id")
|
||||
|
||||
line = 4
|
||||
retrieved_docs = self._bedrock_retrieve(
|
||||
query_input=query,
|
||||
knowledge_base_id=self.knowledge_base_id,
|
||||
num_results=self.topk,
|
||||
search_type=search_type,
|
||||
rerank_model_id=rerank_model_id,
|
||||
metadata_filter=metadata_filter,
|
||||
)
|
||||
|
||||
@ -138,7 +109,7 @@ class BedrockRetrieveTool(BuiltinTool):
|
||||
if not parameters.get("query"):
|
||||
raise ValueError("query is required")
|
||||
|
||||
# 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供)
|
||||
# Optional: Validate if metadata filter is a valid JSON string (if provided)
|
||||
metadata_filter_str = parameters.get("metadata_filter")
|
||||
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
|
||||
raise ValueError("metadata_filter must be a valid JSON object")
|
||||
|
||||
@ -59,57 +59,6 @@ parameters:
|
||||
max: 10
|
||||
default: 5
|
||||
|
||||
- name: search_type
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: search type
|
||||
zh_Hans: 搜索类型
|
||||
pt_BR: search type
|
||||
human_description:
|
||||
en_US: search type
|
||||
zh_Hans: 搜索类型
|
||||
pt_BR: search type
|
||||
llm_description: search type
|
||||
default: SEMANTIC
|
||||
options:
|
||||
- value: SEMANTIC
|
||||
label:
|
||||
en_US: SEMANTIC
|
||||
zh_Hans: 语义搜索
|
||||
- value: HYBRID
|
||||
label:
|
||||
en_US: HYBRID
|
||||
zh_Hans: 混合搜索
|
||||
form: form
|
||||
|
||||
- name: rerank_model_id
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: rerank model id
|
||||
zh_Hans: 重拍模型ID
|
||||
pt_BR: rerank model id
|
||||
human_description:
|
||||
en_US: rerank model id
|
||||
zh_Hans: 重拍模型ID
|
||||
pt_BR: rerank model id
|
||||
llm_description: rerank model id
|
||||
options:
|
||||
- value: default
|
||||
label:
|
||||
en_US: default
|
||||
zh_Hans: 默认
|
||||
- value: cohere.rerank-v3-5:0
|
||||
label:
|
||||
en_US: cohere.rerank-v3-5:0
|
||||
zh_Hans: cohere.rerank-v3-5:0
|
||||
- value: amazon.rerank-v1:0
|
||||
label:
|
||||
en_US: amazon.rerank-v1:0
|
||||
zh_Hans: amazon.rerank-v1:0
|
||||
form: form
|
||||
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
|
||||
@ -229,7 +229,8 @@ class NovaReelTool(BuiltinTool):
|
||||
|
||||
if async_mode:
|
||||
return self.create_text_message(
|
||||
f"Video generation started.\nInvocation ARN: {invocation_arn}\nVideo will be available at: {video_uri}"
|
||||
f"Video generation started.\nInvocation ARN: {invocation_arn}\n"
|
||||
f"Video will be available at: {video_uri}"
|
||||
)
|
||||
|
||||
return self._wait_for_completion(bedrock, s3_client, invocation_arn)
|
||||
|
||||
@ -65,7 +65,7 @@ class BaiduFieldTranslateTool(BuiltinTool, BaiduTranslateToolBase):
|
||||
if "trans_result" in result:
|
||||
result_text = result["trans_result"][0]["dst"]
|
||||
else:
|
||||
result_text = f"{result['error_code']}: {result['error_msg']}"
|
||||
result_text = f'{result["error_code"]}: {result["error_msg"]}'
|
||||
|
||||
return self.create_text_message(str(result_text))
|
||||
except requests.RequestException as e:
|
||||
|
||||
@ -52,7 +52,7 @@ class BaiduLanguageTool(BuiltinTool, BaiduTranslateToolBase):
|
||||
|
||||
result_text = ""
|
||||
if result["error_code"] != 0:
|
||||
result_text = f"{result['error_code']}: {result['error_msg']}"
|
||||
result_text = f'{result["error_code"]}: {result["error_msg"]}'
|
||||
else:
|
||||
result_text = result["data"]["src"]
|
||||
result_text = self.mapping_result(description_language, result_text)
|
||||
|
||||
@ -58,7 +58,7 @@ class BaiduTranslateTool(BuiltinTool, BaiduTranslateToolBase):
|
||||
if "trans_result" in result:
|
||||
result_text = result["trans_result"][0]["dst"]
|
||||
else:
|
||||
result_text = f"{result['error_code']}: {result['error_msg']}"
|
||||
result_text = f'{result["error_code"]}: {result["error_msg"]}'
|
||||
|
||||
return self.create_text_message(str(result_text))
|
||||
except requests.RequestException as e:
|
||||
|
||||
@ -30,7 +30,7 @@ class BingSearchTool(BuiltinTool):
|
||||
headers = {"Ocp-Apim-Subscription-Key": subscription_key, "Accept-Language": accept_language}
|
||||
|
||||
query = quote(query)
|
||||
server_url = f"{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={','.join(filters)}"
|
||||
server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
|
||||
response = get(server_url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
@ -47,23 +47,23 @@ class BingSearchTool(BuiltinTool):
|
||||
results = []
|
||||
if search_results:
|
||||
for result in search_results:
|
||||
url = f": {result['url']}" if "url" in result else ""
|
||||
results.append(self.create_text_message(text=f"{result['name']}{url}"))
|
||||
url = f': {result["url"]}' if "url" in result else ""
|
||||
results.append(self.create_text_message(text=f'{result["name"]}{url}'))
|
||||
|
||||
if entities:
|
||||
for entity in entities:
|
||||
url = f": {entity['url']}" if "url" in entity else ""
|
||||
results.append(self.create_text_message(text=f"{entity.get('name', '')}{url}"))
|
||||
url = f': {entity["url"]}' if "url" in entity else ""
|
||||
results.append(self.create_text_message(text=f'{entity.get("name", "")}{url}'))
|
||||
|
||||
if news:
|
||||
for news_item in news:
|
||||
url = f": {news_item['url']}" if "url" in news_item else ""
|
||||
results.append(self.create_text_message(text=f"{news_item.get('name', '')}{url}"))
|
||||
url = f': {news_item["url"]}' if "url" in news_item else ""
|
||||
results.append(self.create_text_message(text=f'{news_item.get("name", "")}{url}'))
|
||||
|
||||
if related_searches:
|
||||
for related in related_searches:
|
||||
url = f": {related['displayText']}" if "displayText" in related else ""
|
||||
results.append(self.create_text_message(text=f"{related.get('displayText', '')}{url}"))
|
||||
url = f': {related["displayText"]}' if "displayText" in related else ""
|
||||
results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}'))
|
||||
|
||||
return results
|
||||
elif result_type == "json":
|
||||
@ -106,29 +106,29 @@ class BingSearchTool(BuiltinTool):
|
||||
text = ""
|
||||
if search_results:
|
||||
for i, result in enumerate(search_results):
|
||||
text += f"{i + 1}: {result.get('name', '')} - {result.get('snippet', '')}\n"
|
||||
text += f'{i + 1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
|
||||
|
||||
if computation and "expression" in computation and "value" in computation:
|
||||
text += "\nComputation:\n"
|
||||
text += f"{computation['expression']} = {computation['value']}\n"
|
||||
text += f'{computation["expression"]} = {computation["value"]}\n'
|
||||
|
||||
if entities:
|
||||
text += "\nEntities:\n"
|
||||
for entity in entities:
|
||||
url = f"- {entity['url']}" if "url" in entity else ""
|
||||
text += f"{entity.get('name', '')}{url}\n"
|
||||
url = f'- {entity["url"]}' if "url" in entity else ""
|
||||
text += f'{entity.get("name", "")}{url}\n'
|
||||
|
||||
if news:
|
||||
text += "\nNews:\n"
|
||||
for news_item in news:
|
||||
url = f"- {news_item['url']}" if "url" in news_item else ""
|
||||
text += f"{news_item.get('name', '')}{url}\n"
|
||||
url = f'- {news_item["url"]}' if "url" in news_item else ""
|
||||
text += f'{news_item.get("name", "")}{url}\n'
|
||||
|
||||
if related_searches:
|
||||
text += "\n\nRelated Searches:\n"
|
||||
for related in related_searches:
|
||||
url = f"- {related['webSearchUrl']}" if "webSearchUrl" in related else ""
|
||||
text += f"{related.get('displayText', '')}{url}\n"
|
||||
url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else ""
|
||||
text += f'{related.get("displayText", "")}{url}\n'
|
||||
|
||||
return self.create_text_message(text=self.summary(user_id=user_id, content=text))
|
||||
|
||||
|
||||
@ -83,5 +83,5 @@ class DIDApp:
|
||||
if status["status"] == "done":
|
||||
return status
|
||||
elif status["status"] == "error" or status["status"] == "rejected":
|
||||
raise HTTPError(f"Talks {id} failed: {status['status']} {status.get('error', {}).get('description')}")
|
||||
raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error", {}).get("description")}')
|
||||
time.sleep(poll_interval)
|
||||
|
||||
@ -20,33 +20,33 @@ class SendEmailToolParameters(BaseModel):
|
||||
encrypt_method: str
|
||||
|
||||
|
||||
def send_mail(params: SendEmailToolParameters):
|
||||
def send_mail(parmas: SendEmailToolParameters):
|
||||
timeout = 60
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["From"] = params.email_account
|
||||
msg["To"] = params.sender_to
|
||||
msg["Subject"] = params.subject
|
||||
msg.attach(MIMEText(params.email_content, "plain"))
|
||||
msg.attach(MIMEText(params.email_content, "html"))
|
||||
msg["From"] = parmas.email_account
|
||||
msg["To"] = parmas.sender_to
|
||||
msg["Subject"] = parmas.subject
|
||||
msg.attach(MIMEText(parmas.email_content, "plain"))
|
||||
msg.attach(MIMEText(parmas.email_content, "html"))
|
||||
|
||||
ctx = ssl.create_default_context()
|
||||
|
||||
if params.encrypt_method.upper() == "SSL":
|
||||
if parmas.encrypt_method.upper() == "SSL":
|
||||
try:
|
||||
with smtplib.SMTP_SSL(params.smtp_server, params.smtp_port, context=ctx, timeout=timeout) as server:
|
||||
server.login(params.email_account, params.email_password)
|
||||
server.sendmail(params.email_account, params.sender_to, msg.as_string())
|
||||
with smtplib.SMTP_SSL(parmas.smtp_server, parmas.smtp_port, context=ctx, timeout=timeout) as server:
|
||||
server.login(parmas.email_account, parmas.email_password)
|
||||
server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string())
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.exception("send email failed")
|
||||
return False
|
||||
else: # NONE or TLS
|
||||
try:
|
||||
with smtplib.SMTP(params.smtp_server, params.smtp_port, timeout=timeout) as server:
|
||||
if params.encrypt_method.upper() == "TLS":
|
||||
with smtplib.SMTP(parmas.smtp_server, parmas.smtp_port, timeout=timeout) as server:
|
||||
if parmas.encrypt_method.upper() == "TLS":
|
||||
server.starttls(context=ctx)
|
||||
server.login(params.email_account, params.email_password)
|
||||
server.sendmail(params.email_account, params.sender_to, msg.as_string())
|
||||
server.login(parmas.email_account, parmas.email_password)
|
||||
server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string())
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.exception("send email failed")
|
||||
|
||||
@ -74,7 +74,7 @@ class FirecrawlApp:
|
||||
if response is None:
|
||||
raise HTTPError("Failed to initiate crawl after multiple retries")
|
||||
elif response.get("success") == False:
|
||||
raise HTTPError(f"Failed to crawl: {response.get('error')}")
|
||||
raise HTTPError(f'Failed to crawl: {response.get("error")}')
|
||||
job_id: str = response["id"]
|
||||
if wait:
|
||||
return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval)
|
||||
@ -100,7 +100,7 @@ class FirecrawlApp:
|
||||
if status["status"] == "completed":
|
||||
return status
|
||||
elif status["status"] == "failed":
|
||||
raise HTTPError(f"Job {job_id} failed: {status['error']}")
|
||||
raise HTTPError(f'Job {job_id} failed: {status["error"]}')
|
||||
time.sleep(poll_interval)
|
||||
|
||||
|
||||
|
||||
@ -37,9 +37,8 @@ class GaodeRepositoriesTool(BuiltinTool):
|
||||
CityCode = City_data["districts"][0]["adcode"]
|
||||
weatherInfo_response = s.request(
|
||||
method="GET",
|
||||
url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json".format(
|
||||
url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")
|
||||
),
|
||||
url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json"
|
||||
"".format(url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")),
|
||||
)
|
||||
weatherInfo_data = weatherInfo_response.json()
|
||||
if weatherInfo_response.status_code == 200 and weatherInfo_data.get("info") == "OK":
|
||||
|
||||
@ -11,21 +11,19 @@ class GitlabFilesTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
repository = tool_parameters.get("repository", "")
|
||||
project = tool_parameters.get("project", "")
|
||||
repository = tool_parameters.get("repository", "")
|
||||
branch = tool_parameters.get("branch", "")
|
||||
path = tool_parameters.get("path", "")
|
||||
file_path = tool_parameters.get("file_path", "")
|
||||
|
||||
if not repository and not project:
|
||||
return self.create_text_message("Either repository or project is required")
|
||||
if not project and not repository:
|
||||
return self.create_text_message("Either project or repository is required")
|
||||
if not branch:
|
||||
return self.create_text_message("Branch is required")
|
||||
if not path and not file_path:
|
||||
return self.create_text_message("Either path or file_path is required")
|
||||
if not path:
|
||||
return self.create_text_message("Path is required")
|
||||
|
||||
access_token = self.runtime.credentials.get("access_tokens")
|
||||
headers = {"PRIVATE-TOKEN": access_token}
|
||||
site_url = self.runtime.credentials.get("site_url")
|
||||
|
||||
if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"):
|
||||
@ -33,45 +31,33 @@ class GitlabFilesTool(BuiltinTool):
|
||||
if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"):
|
||||
site_url = "https://gitlab.com"
|
||||
|
||||
if repository:
|
||||
# URL encode the repository path
|
||||
identifier = urllib.parse.quote(repository, safe="")
|
||||
else:
|
||||
identifier = self.get_project_id(site_url, access_token, project)
|
||||
if not identifier:
|
||||
raise Exception(f"Project '{project}' not found.)")
|
||||
|
||||
# Get file content
|
||||
if path:
|
||||
results = self.fetch_files(site_url, headers, identifier, branch, path)
|
||||
return [self.create_json_message(item) for item in results]
|
||||
if repository:
|
||||
result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True)
|
||||
else:
|
||||
result = self.fetch_file(site_url, headers, identifier, branch, file_path)
|
||||
return [self.create_json_message(result)]
|
||||
result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False)
|
||||
|
||||
@staticmethod
|
||||
def fetch_file(
|
||||
site_url: str,
|
||||
headers: dict[str, str],
|
||||
identifier: str,
|
||||
branch: str,
|
||||
path: str,
|
||||
) -> dict[str, Any]:
|
||||
encoded_file_path = urllib.parse.quote(path, safe="")
|
||||
file_url = f"{site_url}/api/v4/projects/{identifier}/repository/files/{encoded_file_path}/raw?ref={branch}"
|
||||
|
||||
file_response = requests.get(file_url, headers=headers)
|
||||
file_response.raise_for_status()
|
||||
file_content = file_response.text
|
||||
return {"path": path, "branch": branch, "content": file_content}
|
||||
return [self.create_json_message(item) for item in result]
|
||||
|
||||
def fetch_files(
|
||||
self, site_url: str, headers: dict[str, str], identifier: str, branch: str, path: str
|
||||
self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool
|
||||
) -> list[dict[str, Any]]:
|
||||
domain = site_url
|
||||
headers = {"PRIVATE-TOKEN": access_token}
|
||||
results = []
|
||||
|
||||
try:
|
||||
tree_url = f"{site_url}/api/v4/projects/{identifier}/repository/tree?path={path}&ref={branch}"
|
||||
if is_repository:
|
||||
# URL encode the repository path
|
||||
encoded_identifier = urllib.parse.quote(identifier, safe="")
|
||||
tree_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}"
|
||||
else:
|
||||
# Get project ID from project name
|
||||
project_id = self.get_project_id(site_url, access_token, identifier)
|
||||
if not project_id:
|
||||
return self.create_text_message(f"Project '{identifier}' not found.")
|
||||
tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
|
||||
|
||||
response = requests.get(tree_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
items = response.json()
|
||||
@ -79,10 +65,26 @@ class GitlabFilesTool(BuiltinTool):
|
||||
for item in items:
|
||||
item_path = item["path"]
|
||||
if item["type"] == "tree": # It's a directory
|
||||
results.extend(self.fetch_files(site_url, headers, identifier, branch, item_path))
|
||||
results.extend(
|
||||
self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository)
|
||||
)
|
||||
else: # It's a file
|
||||
result = self.fetch_file(site_url, headers, identifier, branch, item_path)
|
||||
results.append(result)
|
||||
encoded_item_path = urllib.parse.quote(item_path, safe="")
|
||||
if is_repository:
|
||||
file_url = (
|
||||
f"{domain}/api/v4/projects/{encoded_identifier}/repository/files"
|
||||
f"/{encoded_item_path}/raw?ref={branch}"
|
||||
)
|
||||
else:
|
||||
file_url = (
|
||||
f"{domain}/api/v4/projects/{project_id}/repository/files"
|
||||
f"{encoded_item_path}/raw?ref={branch}"
|
||||
)
|
||||
|
||||
file_response = requests.get(file_url, headers=headers)
|
||||
file_response.raise_for_status()
|
||||
file_content = file_response.text
|
||||
results.append({"path": item_path, "branch": branch, "content": file_content})
|
||||
except requests.RequestException as e:
|
||||
print(f"Error fetching data from GitLab: {e}")
|
||||
|
||||
|
||||
@ -29,7 +29,7 @@ parameters:
|
||||
zh_Hans: 项目
|
||||
human_description:
|
||||
en_US: project
|
||||
zh_Hans: 项目(和仓库路径二选一,都填写以仓库路径优先)
|
||||
zh_Hans: 项目
|
||||
llm_description: Project for GitLab
|
||||
form: llm
|
||||
- name: branch
|
||||
@ -45,21 +45,12 @@ parameters:
|
||||
form: llm
|
||||
- name: path
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: path
|
||||
zh_Hans: 文件夹
|
||||
human_description:
|
||||
en_US: path
|
||||
zh_Hans: 文件夹
|
||||
llm_description: Dir path for GitLab
|
||||
form: llm
|
||||
- name: file_path
|
||||
type: string
|
||||
label:
|
||||
en_US: file_path
|
||||
zh_Hans: 文件路径
|
||||
human_description:
|
||||
en_US: file_path
|
||||
zh_Hans: 文件路径(和文件夹二选一,都填写以文件夹优先)
|
||||
en_US: path
|
||||
zh_Hans: 文件路径
|
||||
llm_description: File path for GitLab
|
||||
form: llm
|
||||
|
||||
@ -110,7 +110,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
|
||||
result["rows"].append(self.get_row_field_value(row, schema))
|
||||
return self.create_text_message(json.dumps(result, ensure_ascii=False))
|
||||
else:
|
||||
result_text = f'Found {result["total"]} rows in worksheet "{worksheet_name}".'
|
||||
result_text = f"Found {result['total']} rows in worksheet \"{worksheet_name}\"."
|
||||
if result["total"] > 0:
|
||||
result_text += (
|
||||
f" The following are {min(limit, result['total'])}"
|
||||
|
||||
@ -28,4 +28,4 @@ class BaseStabilityAuthorization:
|
||||
"""
|
||||
This method is responsible for generating the authorization headers.
|
||||
"""
|
||||
return {"Authorization": f"Bearer {credentials.get('api_key', '')}"}
|
||||
return {"Authorization": f'Bearer {credentials.get("api_key", "")}'}
|
||||
|
||||
@ -38,7 +38,7 @@ class VannaProvider(BuiltinToolProviderController):
|
||||
tool_parameters={
|
||||
"model": "chinook",
|
||||
"db_type": "SQLite",
|
||||
"url": f"{self._get_protocol_and_main_domain(credentials['base_url'])}/Chinook.sqlite",
|
||||
"url": f'{self._get_protocol_and_main_domain(credentials["base_url"])}/Chinook.sqlite',
|
||||
"query": "What are the top 10 customers by sales?",
|
||||
},
|
||||
)
|
||||
|
||||
@ -43,7 +43,7 @@ class SerplyApi:
|
||||
def parse_results(res: dict) -> str:
|
||||
"""Process response from Serply Job Search."""
|
||||
jobs = res.get("jobs", [])
|
||||
if not res or "jobs" not in res:
|
||||
if not jobs:
|
||||
raise ValueError(f"Got error from Serply: {res}")
|
||||
|
||||
string = []
|
||||
|
||||
@ -43,7 +43,7 @@ class SerplyApi:
|
||||
def parse_results(res: dict) -> str:
|
||||
"""Process response from Serply News Search."""
|
||||
news = res.get("entries", [])
|
||||
if not res or "entries" not in res:
|
||||
if not news:
|
||||
raise ValueError(f"Got error from Serply: {res}")
|
||||
|
||||
string = []
|
||||
|
||||
@ -43,7 +43,7 @@ class SerplyApi:
|
||||
def parse_results(res: dict) -> str:
|
||||
"""Process response from Serply News Search."""
|
||||
articles = res.get("articles", [])
|
||||
if not res or "articles" not in res:
|
||||
if not articles:
|
||||
raise ValueError(f"Got error from Serply: {res}")
|
||||
|
||||
string = []
|
||||
|
||||
@ -42,7 +42,7 @@ class SerplyApi:
|
||||
def parse_results(res: dict) -> str:
|
||||
"""Process response from Serply Web Search."""
|
||||
results = res.get("results", [])
|
||||
if not res or "results" not in res:
|
||||
if not results:
|
||||
raise ValueError(f"Got error from Serply: {res}")
|
||||
|
||||
string = []
|
||||
|
||||
@ -84,9 +84,9 @@ class ApiTool(Tool):
|
||||
if "api_key_header_prefix" in credentials:
|
||||
api_key_header_prefix = credentials["api_key_header_prefix"]
|
||||
if api_key_header_prefix == "basic" and credentials["api_key_value"]:
|
||||
credentials["api_key_value"] = f"Basic {credentials['api_key_value']}"
|
||||
credentials["api_key_value"] = f'Basic {credentials["api_key_value"]}'
|
||||
elif api_key_header_prefix == "bearer" and credentials["api_key_value"]:
|
||||
credentials["api_key_value"] = f"Bearer {credentials['api_key_value']}"
|
||||
credentials["api_key_value"] = f'Bearer {credentials["api_key_value"]}'
|
||||
elif api_key_header_prefix == "custom":
|
||||
pass
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user