Mirror of https://github.com/langgenius/dify.git (synced 2026-02-06 11:45:35 +08:00)

Compare commits: 73 commits, 0.11.2 ... feat/suppo

Commit SHA1s:

a023315f61, 4dd32eb089, d3051eed48, ed55de888a, da601f0bef, 08ac36812b, 556de444e8, 3750200c5e, c5f7d650b5, 535c72cad7, 8a83edc1b5, 5b415a6227, 5172f0bf39, d9579f418d, 3579bbd1c4, 817b85001f, e8868a7fb9, 2cd9ac60f1, 464f384cea, 8b16f07eb0, fefda40acf, 8c2f62fb92, 1a6b961b5f, 01014a6a84, cb0c55daa7, 82575a7aea, 80da0c5830, 83b6abf4ad, ea0ebc020c, f358db9f02, 94c9cadbd8, 2ae6460f46, 0067b16d1e, ec9f6220c9, af53e2b6b0, b42b333a72, 99b0369f1b, d6ea1e2f12, 4d6b45427c, 1be8365684, c3d11c8ff6, 8ff65abbc6, bf4b6e5f80, 25fda7adc5, f3af7b5f35, 33cfc56ad0, 464cc26ccf, d18754afdd, beb7953d38, fbfc811a44, 7e66e5a713, 07b5bbae06, 3087913b74, 904ea05bf6, 6f4885d86d, 2dc29cfee3, bd05df5cc5, ee1f14621a, 58a9d9eb9a, bc1013dacf, 9f195df103, 1cc7dc6360, 328965ed7c, 133de9a087, 7261384655, 4718071cbb, 22be0816aa, 49e88322de, 14f3d44c37, 0ba17ec116, 79d59c004b, 873e9720e9, de6d3e493c
CONTRIBUTING.md

@@ -1,6 +1,8 @@

# CONTRIBUTING

So you're looking to contribute to Dify - that's awesome, we can't wait to see what you do. As a startup with limited headcount and funding, we have grand ambitions to design the most intuitive workflow for building and managing LLM applications. Any help from the community counts, truly.

We need to be nimble and ship fast given where we are, but we also want to make sure that contributors like you get as smooth an experience at contributing as possible. We've assembled this contribution guide for that purpose, aiming at getting you familiarized with the codebase & how we work with contributors, so you could quickly jump to the fun part.

This guide, like Dify itself, is a constant work in progress. We highly appreciate your understanding if at times it lags behind the actual project, and welcome any feedback for us to improve.

@@ -10,14 +12,12 @@ In terms of licensing, please take a minute to read our short [License and Contr

[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:

-### Feature requests:
+### Feature requests

* If you're opening a new feature request, we'd like you to explain what the proposed feature achieves, and include as much context as possible. [@perzeusss](https://github.com/perzeuss) has made a solid [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) that helps you draft out your needs. Feel free to give it a try.

* If you want to pick one up from the existing issues, simply drop a comment below it saying so.

A team member working in the related direction will be looped in. If all looks good, they will give the go-ahead for you to start coding. We ask that you hold off working on the feature until then, so none of your work goes to waste should we propose changes.

Depending on whichever area the proposed feature falls under, you might talk to different team members. Here's a rundown of the areas each of our team members are working on at the moment:

@@ -40,7 +40,7 @@ In terms of licensing, please take a minute to read our short [License and Contr

| Non-core features and minor enhancements | Low Priority |
| Valuable but not immediate | Future-Feature |

-### Anything else (e.g. bug report, performance optimization, typo correction):
+### Anything else (e.g. bug report, performance optimization, typo correction)

* Start coding right away.

@@ -52,7 +52,6 @@ In terms of licensing, please take a minute to read our short [License and Contr

| Non-critical bugs, performance boosts | Medium Priority |
| Minor fixes (typos, confusing but working UI) | Low Priority |

## Installing

Here are the steps to set up Dify for development:

@@ -63,7 +62,7 @@ Here are the steps to set up Dify for development:

Clone the forked repository from your terminal:

-```
+```shell
git clone git@github.com:<github_username>/dify.git
```

@@ -71,11 +70,11 @@ git clone git@github.com:<github_username>/dify.git

Dify requires the following dependencies to build, make sure they're installed on your system:

-- [Docker](https://www.docker.com/)
-- [Docker Compose](https://docs.docker.com/compose/install/)
-- [Node.js v18.x (LTS)](http://nodejs.org)
-- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) version 3.10.x
+* [Docker](https://www.docker.com/)
+* [Docker Compose](https://docs.docker.com/compose/install/)
+* [Node.js v18.x (LTS)](http://nodejs.org)
+* [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
+* [Python](https://www.python.org/) version 3.11.x or 3.12.x

### 4. Installations

@@ -85,7 +84,7 @@ Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) fo

### 5. Visit dify in your browser

To validate your set up, head over to [http://localhost:3000](http://localhost:3000) (the default, or your self-configured URL and port) in your browser. You should now see Dify up and running.

## Developing

@@ -97,9 +96,9 @@ To help you quickly navigate where your contribution fits, a brief, annotated ou

### Backend

Dify’s backend is written in Python using [Flask](https://flask.palletsprojects.com/en/3.0.x/). It uses [SQLAlchemy](https://www.sqlalchemy.org/) for ORM and [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) for task queueing. Authorization logic goes via Flask-login.

-```
+```text
[api/]
├── constants // Constant settings used throughout code base.
├── controllers // API route definitions and request handling logic.

@@ -121,7 +120,7 @@ Dify’s backend is written in Python using [Flask](https://flask.palletsproject

The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typescript and uses [Tailwind CSS](https://tailwindcss.com/) for styling. [React-i18next](https://react.i18next.com/) is used for internationalization.

-```
+```text
[web/]
├── app // layouts, pages, and components
│   ├── (commonLayout) // common layout used throughout the app

@@ -149,10 +148,10 @@ The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typ

## Submitting your PR

At last, time to open a pull request (PR) to our repo. For major features, we first merge them into the `deploy/dev` branch for testing, before they go into the `main` branch. If you run into issues like merge conflicts or don't know how to open a pull request, check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests).

And that's it! Once your PR is merged, you will be featured as a contributor in our [README](https://github.com/langgenius/dify/blob/main/README.md).

## Getting Help

If you ever get stuck or got a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
api/.env.example

@@ -42,6 +42,11 @@ REDIS_SENTINEL_USERNAME=

REDIS_SENTINEL_PASSWORD=
REDIS_SENTINEL_SOCKET_TIMEOUT=0.1

+# redis Cluster configuration.
+REDIS_USE_CLUSTERS=false
+REDIS_CLUSTERS=
+REDIS_CLUSTERS_PASSWORD=

# PostgreSQL database configuration
DB_USERNAME=postgres
DB_PASSWORD=difyai123456

@@ -234,6 +239,10 @@ ANALYTICDB_ACCOUNT=testaccount

ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
+ANALYTICDB_HOST=gp-test.aliyuncs.com
+ANALYTICDB_PORT=5432
+ANALYTICDB_MIN_CONNECTION=1
+ANALYTICDB_MAX_CONNECTION=5

# OpenSearch configuration
OPENSEARCH_HOST=127.0.0.1
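For reference, the new `REDIS_CLUSTERS` value is a comma-separated list of host:port pairs (per the field description added to `RedisConfig` below). A hypothetical filled-in example; the node addresses and password are placeholders, not defaults shipped by the project:

```
REDIS_USE_CLUSTERS=true
REDIS_CLUSTERS=redis-node-1:6379,redis-node-2:6379,redis-node-3:6379
REDIS_CLUSTERS_PASSWORD=changeme
```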
api/Dockerfile

@@ -1,5 +1,5 @@

# base image
-FROM python:3.10-slim-bookworm AS base
+FROM python:3.12-slim-bookworm AS base

WORKDIR /app/api
api/README.md

@@ -18,12 +18,17 @@

   ```

2. Copy `.env.example` to `.env`

+   ```cli
+   cp .env.example .env
+   ```

3. Generate a `SECRET_KEY` in the `.env` file.

-   bash for Linux
+   ```bash for Linux
   sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
   ```

-   bash for Mac
+   ```bash for Mac
   secret_key=$(openssl rand -base64 42)
   sed -i '' "/^SECRET_KEY=/c\\

@@ -37,18 +42,10 @@

5. Install dependencies

   ```bash
-   poetry env use 3.10
+   poetry env use 3.12
   poetry install
   ```

-   In case of contributors missing to update dependencies for `pyproject.toml`, you can perform the following shell instead.
-
-   ```bash
-   poetry shell # activate current environment
-   poetry add $(cat requirements.txt) # install dependencies of production and update pyproject.toml
-   poetry add $(cat requirements-dev.txt) --group dev # install dependencies of development and update pyproject.toml
-   ```

6. Run migrate

   Before the first launch, migrate the database to the latest version.

@@ -84,5 +81,3 @@

   ```bash
   poetry run -C api bash dev/pytest/pytest_all_tests.sh
   ```
@@ -27,7 +27,6 @@ class DifyConfig(

        # read from dotenv format config file
        env_file=".env",
        env_file_encoding="utf-8",
        frozen=True,
        # ignore extra attributes
        extra="ignore",
    )
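For readers unfamiliar with pydantic-settings, these options mean the config model reads a dotenv file at startup and silently skips unknown keys. A minimal standalone illustration (class and field names here are made up, assuming pydantic-settings v2):

```python
from pydantic_settings import BaseSettings, SettingsConfigDict


class ExampleConfig(BaseSettings):
    model_config = SettingsConfigDict(
        env_file=".env",            # read values from a dotenv-format file
        env_file_encoding="utf-8",
        extra="ignore",             # unknown keys in .env are skipped, not errors
    )

    SECRET_KEY: str = ""            # illustrative field, not from the diff


cfg = ExampleConfig()  # picks up SECRET_KEY from .env if present
```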
api/configs/middleware/cache/redis_config.py (+15 lines)

@@ -68,3 +68,18 @@ class RedisConfig(BaseSettings):

        description="Socket timeout in seconds for Redis Sentinel connections",
        default=0.1,
    )
+
+    REDIS_USE_CLUSTERS: bool = Field(
+        description="Enable Redis Clusters mode for high availability",
+        default=False,
+    )
+
+    REDIS_CLUSTERS: Optional[str] = Field(
+        description="Comma-separated list of Redis Clusters nodes (host:port)",
+        default=None,
+    )
+
+    REDIS_CLUSTERS_PASSWORD: Optional[str] = Field(
+        description="Password for Redis Clusters authentication (if required)",
+        default=None,
+    )
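These fields only declare the settings; a consumer still has to split the node list. A minimal sketch of how the comma-separated value could feed redis-py's cluster client; the helper below is illustrative and not part of this diff:

```python
from redis.cluster import ClusterNode, RedisCluster


def parse_cluster_nodes(redis_clusters: str) -> list[ClusterNode]:
    """Turn 'host1:6379,host2:6379' into redis-py ClusterNode objects."""
    nodes = []
    for item in redis_clusters.split(","):
        host, _, port = item.strip().rpartition(":")
        nodes.append(ClusterNode(host, int(port)))
    return nodes


# Connecting requires a live cluster, so only the parsing runs here:
nodes = parse_cluster_nodes("redis-node-1:6379,redis-node-2:6379")
# client = RedisCluster(startup_nodes=nodes, password=None)  # password: REDIS_CLUSTERS_PASSWORD
```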
@@ -1,6 +1,6 @@

from typing import Optional

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, PositiveInt


class AnalyticdbConfig(BaseModel):

@@ -40,3 +40,11 @@ class AnalyticdbConfig(BaseModel):

        description="The password for accessing the specified namespace within the AnalyticDB instance"
        " (if namespace feature is enabled).",
    )
+    ANALYTICDB_HOST: Optional[str] = Field(
+        default=None, description="The host of the AnalyticDB instance you want to connect to."
+    )
+    ANALYTICDB_PORT: PositiveInt = Field(
+        default=5432, description="The port of the AnalyticDB instance you want to connect to."
+    )
+    ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1, description="Min connection of the AnalyticDB database.")
+    ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5, description="Max connection of the AnalyticDB database.")
@@ -2,6 +2,7 @@ from flask import Blueprint

from libs.external_api import ExternalApi

+from .app.app_import import AppImportApi, AppImportConfirmApi
from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi

@@ -17,6 +18,10 @@ api.add_resource(FileSupportTypeApi, "/files/support-type")

api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")

+# Import App
+api.add_resource(AppImportApi, "/apps/imports")
+api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")

# Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version

@@ -57,6 +62,7 @@ from .datasets import (

    external,
    hit_testing,
    website,
+    fta_test,
)

# Import explore controllers
@@ -1,7 +1,10 @@

import uuid
+from typing import cast

from flask_login import current_user
from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
+from sqlalchemy import select
+from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort

from controllers.console import api

@@ -13,13 +16,15 @@ from controllers.console.wraps import (

    setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
+from extensions.ext_database import db
from fields.app_fields import (
    app_detail_fields,
    app_detail_fields_with_site,
    app_pagination_fields,
)
from libs.login import login_required
-from services.app_dsl_service import AppDslService
+from models import Account, App
+from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService

ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]

@@ -92,61 +97,6 @@ class AppListApi(Resource):

        return app, 201


-class AppImportApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @marshal_with(app_detail_fields_with_site)
-    @cloud_edition_billing_resource_check("apps")
-    def post(self):
-        """Import app"""
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("data", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("name", type=str, location="json")
-        parser.add_argument("description", type=str, location="json")
-        parser.add_argument("icon_type", type=str, location="json")
-        parser.add_argument("icon", type=str, location="json")
-        parser.add_argument("icon_background", type=str, location="json")
-        args = parser.parse_args()
-
-        app = AppDslService.import_and_create_new_app(
-            tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user
-        )
-
-        return app, 201
-
-
-class AppImportFromUrlApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @marshal_with(app_detail_fields_with_site)
-    @cloud_edition_billing_resource_check("apps")
-    def post(self):
-        """Import app from url"""
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("url", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("name", type=str, location="json")
-        parser.add_argument("description", type=str, location="json")
-        parser.add_argument("icon", type=str, location="json")
-        parser.add_argument("icon_background", type=str, location="json")
-        args = parser.parse_args()
-
-        app = AppDslService.import_and_create_new_app_from_url(
-            tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user
-        )
-
-        return app, 201


class AppApi(Resource):
    @setup_required
    @login_required

@@ -224,10 +174,24 @@ class AppCopyApi(Resource):

        parser.add_argument("icon_background", type=str, location="json")
        args = parser.parse_args()

-        data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
-        app = AppDslService.import_and_create_new_app(
-            tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user
-        )
+        with Session(db.engine) as session:
+            import_service = AppDslService(session)
+            yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
+            account = cast(Account, current_user)
+            result = import_service.import_app(
+                account=account,
+                import_mode=ImportMode.YAML_CONTENT.value,
+                yaml_content=yaml_content,
+                name=args.get("name"),
+                description=args.get("description"),
+                icon_type=args.get("icon_type"),
+                icon=args.get("icon"),
+                icon_background=args.get("icon_background"),
+            )
+            session.commit()
+
+            stmt = select(App).where(App.id == result.app.id)
+            app = session.scalar(stmt)

        return app, 201

@@ -368,8 +332,6 @@ class AppTraceApi(Resource):

api.add_resource(AppListApi, "/apps")
-api.add_resource(AppImportApi, "/apps/import")
-api.add_resource(AppImportFromUrlApi, "/apps/import/url")
api.add_resource(AppApi, "/apps/<uuid:app_id>")
api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")
api/controllers/console/app/app_import.py (new file, 90 lines)

@@ -0,0 +1,90 @@

from typing import cast

from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

from controllers.console.wraps import (
    account_initialization_required,
    setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_import_fields
from libs.login import login_required
from models import Account
from services.app_dsl_service import AppDslService, ImportStatus


class AppImportApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(app_import_fields)
    def post(self):
        # Check user role first
        if not current_user.is_editor:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("mode", type=str, required=True, location="json")
        parser.add_argument("yaml_content", type=str, location="json")
        parser.add_argument("yaml_url", type=str, location="json")
        parser.add_argument("name", type=str, location="json")
        parser.add_argument("description", type=str, location="json")
        parser.add_argument("icon_type", type=str, location="json")
        parser.add_argument("icon", type=str, location="json")
        parser.add_argument("icon_background", type=str, location="json")
        parser.add_argument("app_id", type=str, location="json")
        args = parser.parse_args()

        # Create service with session
        with Session(db.engine) as session:
            import_service = AppDslService(session)
            # Import app
            account = cast(Account, current_user)
            result = import_service.import_app(
                account=account,
                import_mode=args["mode"],
                yaml_content=args.get("yaml_content"),
                yaml_url=args.get("yaml_url"),
                name=args.get("name"),
                description=args.get("description"),
                icon_type=args.get("icon_type"),
                icon=args.get("icon"),
                icon_background=args.get("icon_background"),
                app_id=args.get("app_id"),
            )
            session.commit()

        # Return appropriate status code based on result
        status = result.status
        if status == ImportStatus.FAILED.value:
            return result.model_dump(mode="json"), 400
        elif status == ImportStatus.PENDING.value:
            return result.model_dump(mode="json"), 202
        return result.model_dump(mode="json"), 200


class AppImportConfirmApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(app_import_fields)
    def post(self, import_id):
        # Check user role first
        if not current_user.is_editor:
            raise Forbidden()

        # Create service with session
        with Session(db.engine) as session:
            import_service = AppDslService(session)
            # Confirm import
            account = cast(Account, current_user)
            result = import_service.confirm_import(import_id=import_id, account=account)
            session.commit()

        # Return appropriate status code based on result
        if result.status == ImportStatus.FAILED.value:
            return result.model_dump(mode="json"), 400
        return result.model_dump(mode="json"), 200
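Taken together with the routes registered in the console blueprint above, the new import flow is two-step: submit, then confirm if the result is pending. A hypothetical client-side sketch; the host, token, mode string, and `id` field are assumptions, not confirmed by the diff:

```python
import requests

BASE = "http://localhost:5001/console/api"  # placeholder host
HEADERS = {"Authorization": "Bearer <session-token>", "Content-Type": "application/json"}

# Step 1: submit the DSL. A 202 with a pending status means confirmation is required.
resp = requests.post(f"{BASE}/apps/imports", headers=HEADERS, json={
    "mode": "yaml-content",        # assumed ImportMode.YAML_CONTENT value
    "yaml_content": "<app DSL YAML>",
})
result = resp.json()

# Step 2: confirm a pending import (e.g. after a DSL-version warning).
if resp.status_code == 202:
    requests.post(f"{BASE}/apps/imports/{result['id']}/confirm", headers=HEADERS)  # 'id' assumed
```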
@@ -20,7 +20,6 @@ from libs.helper import TimestampField, uuid_value

from libs.login import current_user, login_required
from models import App
from models.model import AppMode
-from services.app_dsl_service import AppDslService
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.workflow_service import WorkflowService

@@ -126,31 +125,6 @@ class DraftWorkflowApi(Resource):

        }


-class DraftWorkflowImportApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
-    @marshal_with(workflow_fields)
-    def post(self, app_model: App):
-        """
-        Import draft workflow
-        """
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("data", type=str, required=True, nullable=False, location="json")
-        args = parser.parse_args()
-
-        workflow = AppDslService.import_and_overwrite_workflow(
-            app_model=app_model, data=args["data"], account=current_user
-        )
-
-        return workflow


class AdvancedChatDraftWorkflowRunApi(Resource):
    @setup_required
    @login_required

@@ -453,7 +427,6 @@ class ConvertToWorkflowApi(Resource):

api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
-api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import")
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
api/controllers/console/datasets/fta_test.py (new file, 145 lines)

@@ -0,0 +1,145 @@

import json

import requests
from flask import Response
from flask_restful import Resource, reqparse
from sqlalchemy import text

from controllers.console import api
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.fta import ComponentFailure, ComponentFailureStats


class FATTestApi(Resource):
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("log_process_data", nullable=False, required=True, type=str, location="args")
        args = parser.parse_args()
        print(args["log_process_data"])
        # Extract the JSON string from the text field
        json_str = args["log_process_data"].strip("```json\\n").strip("```").strip().replace("\\n", "")
        log_data = json.loads(json_str)
        db.session.query(ComponentFailure).delete()
        for data in log_data:
            if not isinstance(data, dict):
                raise TypeError("Data must be a dictionary.")

            required_keys = {"Date", "Component", "FailureMode", "Cause", "RepairAction", "Technician"}
            if not required_keys.issubset(data.keys()):
                raise ValueError(f"Data dictionary must contain the following keys: {required_keys}")

            try:
                # Clear existing stats
                component_failure = ComponentFailure(
                    Date=data["Date"],
                    Component=data["Component"],
                    FailureMode=data["FailureMode"],
                    Cause=data["Cause"],
                    RepairAction=data["RepairAction"],
                    Technician=data["Technician"],
                )
                db.session.add(component_failure)
                db.session.commit()
            except Exception as e:
                print(e)
        # Clear existing stats
        db.session.query(ComponentFailureStats).delete()

        # Insert calculated statistics
        try:
            db.session.execute(
                text("""
                INSERT INTO component_failure_stats ("Component", "FailureMode", "Cause", "PossibleAction", "Probability", "MTBF")
                SELECT
                    cf."Component",
                    cf."FailureMode",
                    cf."Cause",
                    cf."RepairAction" as "PossibleAction",
                    COUNT(*) * 1.0 / (SELECT COUNT(*) FROM component_failure WHERE "Component" = cf."Component") AS "Probability",
                    COALESCE(AVG(EXTRACT(EPOCH FROM (next_failure_date::timestamp - cf."Date"::timestamp)) / 86400.0), 0) AS "MTBF"
                FROM (
                    SELECT
                        "Component",
                        "FailureMode",
                        "Cause",
                        "RepairAction",
                        "Date",
                        LEAD("Date") OVER (PARTITION BY "Component", "FailureMode", "Cause" ORDER BY "Date") AS next_failure_date
                    FROM
                        component_failure
                ) cf
                GROUP BY
                    cf."Component", cf."FailureMode", cf."Cause", cf."RepairAction";
                """)
            )
            db.session.commit()
        except Exception as e:
            db.session.rollback()
            print(f"Error during stats calculation: {e}")
        # output format
        # [
        #     (17, 'Hydraulic system', 'Leak', 'Hose rupture', 'Replaced hydraulic hose', 0.3333333333333333, None),
        #     (18, 'Hydraulic system', 'Leak', 'Seal Wear', 'Replaced the faulty seal', 0.3333333333333333, None),
        #     (19, 'Hydraulic system', 'Pressure drop', 'Fluid leak', 'Replaced hydraulic fluid and seals', 0.3333333333333333, None)
        # ]

        component_failure_stats = db.session.query(ComponentFailureStats).all()
        # Convert stats to list of tuples format
        stats_list = []
        for stat in component_failure_stats:
            stats_list.append(
                (
                    stat.StatID,
                    stat.Component,
                    stat.FailureMode,
                    stat.Cause,
                    stat.PossibleAction,
                    stat.Probability,
                    stat.MTBF,
                )
            )
        return {"data": stats_list}, 200


# generate-fault-tree
class GenerateFaultTreeApi(Resource):
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("llm_text", nullable=False, required=True, type=str, location="args")
        args = parser.parse_args()
        entities = args["llm_text"].replace("```", "").replace("\\n", "\n")
        print(entities)
        request_data = {"fault_tree_text": entities}
        url = "https://fta.cognitech-dev.live/generate-fault-tree"
        headers = {"accept": "application/json", "Content-Type": "application/json"}

        response = requests.post(url, json=request_data, headers=headers)
        print(response.json())
        return {"data": response.json()}, 200


class ExtractSVGApi(Resource):
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("svg_text", nullable=False, required=True, type=str, location="args")
        args = parser.parse_args()
        # svg_text = ''.join(args["svg_text"].splitlines())
        svg_text = args["svg_text"].replace("\n", "")
        svg_text = svg_text.replace('&quot;', '"')
        print(svg_text)
        svg_text_json = json.loads(svg_text)
        svg_content = svg_text_json.get("data").get("svg_content")[0]
        svg_content = svg_content.replace("\n", "").replace('&quot;', '"')
        file_key = "fta_svg/" + "fat.svg"
        if storage.exists(file_key):
            storage.delete(file_key)
        storage.save(file_key, svg_content.encode("utf-8"))
        generator = storage.load(file_key, stream=True)

        return Response(generator, mimetype="image/svg+xml")


api.add_resource(FATTestApi, "/fta/db-handler")
api.add_resource(GenerateFaultTreeApi, "/fta/generate-fault-tree")
api.add_resource(ExtractSVGApi, "/fta/extract-svg")
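To make the stats query concrete: for each (Component, FailureMode, Cause) group, LEAD pairs every failure with the next one in date order, and MTBF is the average gap in days (EPOCH seconds divided by 86400). For example, failures on Jan 1, Jan 11, and Jan 31 give gaps of 10 and 20 days, so MTBF = 15.0. Probability is that group's row count divided by all failures recorded for the component. The most recent failure in each group has no successor, so its NULL gap simply drops out of the average.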
@@ -45,7 +45,7 @@ class RemoteFileUploadApi(Resource):

        resp = ssrf_proxy.head(url=url)
        if resp.status_code != httpx.codes.OK:
-            resp = ssrf_proxy.get(url=url, timeout=3)
+            resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
        resp.raise_for_status()

        file_info = helpers.guess_file_info_from_response(resp)
@@ -1,3 +1,5 @@

+from urllib import parse

from flask_login import current_user
from flask_restful import Resource, abort, marshal_with, reqparse

@@ -57,11 +59,12 @@ class MemberInviteEmailApi(Resource):

                token = RegisterService.invite_new_member(
                    inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
                )
+                encoded_invitee_email = parse.quote(invitee_email)
                invitation_results.append(
                    {
                        "status": "success",
                        "email": invitee_email,
-                        "url": f"{console_web_url}/activate?email={invitee_email}&token={token}",
+                        "url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
                    }
                )
            except AccountAlreadyInTenantError:
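URL-encoding the invitee address matters because characters that are legal in email addresses (notably `+`) change meaning inside a query string. A quick standalone illustration of the fix:

```python
from urllib import parse

email = "dev+test@example.com"
# Unencoded, the '+' would be decoded back as a space by the activation page.
print(f"/activate?email={email}")               # /activate?email=dev+test@example.com
print(f"/activate?email={parse.quote(email)}")  # /activate?email=dev%2Btest%40example.com
```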
@@ -114,16 +114,9 @@ class BaseAgentRunner(AppRunner):

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
-        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
-            self.stream_tool_call = True
-        else:
-            self.stream_tool_call = False
-
-        # check if model supports vision
-        if model_schema and ModelFeature.VISION in (model_schema.features or []):
-            self.files = application_generate_entity.files
-        else:
-            self.files = []
+        features = model_schema.features if model_schema and model_schema.features else []
+        self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
+        self.files = application_generate_entity.files if ModelFeature.VISION in features else []
        self.query = None
        self._current_thoughts: list[PromptMessage] = []

@@ -250,7 +243,7 @@ class BaseAgentRunner(AppRunner):

        update prompt message tool
        """
        # try to get tool runtime parameters
-        tool_runtime_parameters = tool.get_runtime_parameters() or []
+        tool_runtime_parameters = tool.get_runtime_parameters()

        for parameter in tool_runtime_parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
@@ -11,7 +11,7 @@ from core.provider_manager import ProviderManager

class ModelConfigConverter:
    @classmethod
-    def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
+    def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
        """
        Convert app model config dict to entity.
        :param app_config: app config

@@ -38,27 +38,23 @@ class ModelConfigConverter:

        )

        if model_credentials is None:
-            if not skip_check:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            else:
-                model_credentials = {}
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")

-        if not skip_check:
-            # check model
-            provider_model = provider_model_bundle.configuration.get_provider_model(
-                model=model_config.model, model_type=ModelType.LLM
-            )
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_config.model, model_type=ModelType.LLM
+        )

-            if provider_model is None:
-                model_name = model_config.model
-                raise ValueError(f"Model {model_name} not exist.")
+        if provider_model is None:
+            model_name = model_config.model
+            raise ValueError(f"Model {model_name} not exist.")

-            if provider_model.status == ModelStatus.NO_CONFIGURE:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            elif provider_model.status == ModelStatus.NO_PERMISSION:
-                raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
-            elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
-                raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

        # model config
        completion_params = model_config.parameters

@@ -76,7 +72,7 @@ class ModelConfigConverter:

        model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)

-        if not skip_check and not model_schema:
+        if not model_schema:
            raise ValueError(f"Model {model_name} not exist.")

        return ModelConfigWithCredentialsEntity(
@@ -16,9 +16,7 @@ class FileUploadConfigManager:

        file_upload_dict = config.get("file_upload")
        if file_upload_dict:
            if file_upload_dict.get("enabled"):
-                transform_methods = file_upload_dict.get("allowed_file_upload_methods") or file_upload_dict.get(
-                    "allowed_upload_methods", []
-                )
+                transform_methods = file_upload_dict.get("allowed_file_upload_methods", [])
                data = {
                    "image_config": {
                        "number_limits": file_upload_dict["number_limits"],
@@ -33,8 +33,8 @@ class BaseAppGenerator:

                tenant_id=app_config.tenant_id,
                config=FileUploadConfig(
                    allowed_file_types=entity_dictionary[k].allowed_file_types,
-                    allowed_extensions=entity_dictionary[k].allowed_file_extensions,
-                    allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
+                    allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
+                    allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
                ),
            )
            for k, v in user_inputs.items()

@@ -47,8 +47,8 @@ class BaseAppGenerator:

                tenant_id=app_config.tenant_id,
                config=FileUploadConfig(
                    allowed_file_types=entity_dictionary[k].allowed_file_types,
-                    allowed_extensions=entity_dictionary[k].allowed_file_extensions,
-                    allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
+                    allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
+                    allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
                ),
            )
            for k, v in user_inputs.items()
@@ -217,9 +217,12 @@ class WorkflowCycleManage:

        ).total_seconds()
        db.session.commit()

-        db.session.refresh(workflow_run)
        db.session.close()

+        with Session(db.engine, expire_on_commit=False) as session:
+            session.add(workflow_run)
+            session.refresh(workflow_run)

        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(

@@ -381,7 +384,7 @@ class WorkflowCycleManage:

                id=workflow_run.id,
                workflow_id=workflow_run.workflow_id,
                sequence_number=workflow_run.sequence_number,
-                inputs=workflow_run.inputs_dict or {},
+                inputs=workflow_run.inputs_dict,
                created_at=int(workflow_run.created_at.timestamp()),
            ),
        )

@@ -428,7 +431,7 @@ class WorkflowCycleManage:

                created_by=created_by,
                created_at=int(workflow_run.created_at.timestamp()),
                finished_at=int(workflow_run.finished_at.timestamp()),
-                files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}),
+                files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
            ),
        )
@@ -1,9 +1,16 @@

import base64
+import tempfile
+from pathlib import Path

from configs import dify_config
from core.file import file_repository
from core.helper import ssrf_proxy
-from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
+from core.model_runtime.entities import (
+    AudioPromptMessageContent,
+    DocumentPromptMessageContent,
+    ImagePromptMessageContent,
+    VideoPromptMessageContent,
+)
from extensions.ext_database import db
from extensions.ext_storage import storage

@@ -13,6 +20,38 @@ from .models import File, FileTransferMethod, FileType

from .tool_file_parser import ToolFileParser


+def download_to_target_path(f: File, temp_dir: str, /):
+    if f.transfer_method == FileTransferMethod.TOOL_FILE:
+        tool_file = file_repository.get_tool_file(session=db.session(), file=f)
+        suffix = Path(tool_file.file_key).suffix
+        target_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
+        _download_file_to_target_path(tool_file.file_key, target_path)
+        return target_path
+    elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
+        upload_file = file_repository.get_upload_file(session=db.session(), file=f)
+        suffix = Path(upload_file.key).suffix
+        target_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
+        _download_file_to_target_path(upload_file.key, target_path)
+        return target_path
+    else:
+        raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
+
+
+def _download_file_to_target_path(path: str, target_path: str, /):
+    """
+    Download and return the contents of a file as bytes.
+
+    This function loads the file from storage and ensures it's in bytes format.
+
+    Args:
+        path (str): The path to the file in storage.
+        target_path (str): The path to the target file.
+    Raises:
+        ValueError: If the loaded file is not a bytes object.
+    """
+    storage.download(path, target_path)


def get_attr(*, file: File, attr: FileAttribute):
    match attr:
        case FileAttribute.TYPE:

@@ -29,35 +68,17 @@ def get_attr(*, file: File, attr: FileAttribute):

            return file.remote_url
        case FileAttribute.EXTENSION:
            return file.extension
        case _:
            raise ValueError(f"Invalid file attribute: {attr}")


def to_prompt_message_content(
    f: File,
    /,
    *,
-    image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
+    image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
):
-    """
-    Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
-
-    This function takes a File object and converts it to an appropriate PromptMessageContent
-    object, which can be used as a prompt for image or audio-based AI models.
-
-    Args:
-        f (File): The File object to convert.
-        detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
-            If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
-
-    Returns:
-        Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
-
-    Raises:
-        ValueError: If the file type is not supported or if required data is missing.
-    """
    match f.type:
        case FileType.IMAGE:
+            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
            if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
                data = _to_url(f)
            else:

@@ -65,7 +86,7 @@ def to_prompt_message_content(

            return ImagePromptMessageContent(data=data, detail=image_detail_config)
        case FileType.AUDIO:
-            encoded_string = _file_to_encoded_string(f)
+            encoded_string = _get_encoded_string(f)
            if f.extension is None:
                raise ValueError("Missing file extension")
            return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))

@@ -74,9 +95,20 @@ def to_prompt_message_content(

                data = _to_url(f)
            else:
                data = _to_base64_data_string(f)
+            if f.extension is None:
+                raise ValueError("Missing file extension")
            return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
+        case FileType.DOCUMENT:
+            data = _get_encoded_string(f)
+            if f.mime_type is None:
+                raise ValueError("Missing file mime_type")
+            return DocumentPromptMessageContent(
+                encode_format="base64",
+                mime_type=f.mime_type,
+                data=data,
+            )
        case _:
-            raise ValueError("file type f.type is not supported")
+            raise ValueError(f"file type {f.type} is not supported")


def download(f: File, /):

@@ -118,21 +150,16 @@ def _get_encoded_string(f: File, /):

        case FileTransferMethod.REMOTE_URL:
            response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
            response.raise_for_status()
-            content = response.content
-            encoded_string = base64.b64encode(content).decode("utf-8")
-            return encoded_string
+            data = response.content
        case FileTransferMethod.LOCAL_FILE:
            upload_file = file_repository.get_upload_file(session=db.session(), file=f)
            data = _download_file_content(upload_file.key)
-            encoded_string = base64.b64encode(data).decode("utf-8")
-            return encoded_string
        case FileTransferMethod.TOOL_FILE:
            tool_file = file_repository.get_tool_file(session=db.session(), file=f)
            data = _download_file_content(tool_file.file_key)
-            encoded_string = base64.b64encode(data).decode("utf-8")
-            return encoded_string
        case _:
            raise ValueError(f"Unsupported transfer method: {f.transfer_method}")

+    encoded_string = base64.b64encode(data).decode("utf-8")
+    return encoded_string


def _to_base64_data_string(f: File, /):

@@ -140,18 +167,6 @@ def _to_base64_data_string(f: File, /):

    return f"data:{f.mime_type};base64,{encoded_string}"


-def _file_to_encoded_string(f: File, /):
-    match f.type:
-        case FileType.IMAGE:
-            return _to_base64_data_string(f)
-        case FileType.VIDEO:
-            return _to_base64_data_string(f)
-        case FileType.AUDIO:
-            return _get_encoded_string(f)
-        case _:
-            raise ValueError(f"file type {f.type} is not supported")


def _to_url(f: File, /):
    if f.transfer_method == FileTransferMethod.REMOTE_URL:
        if f.remote_url is None:
@@ -28,8 +28,8 @@ class FileUploadConfig(BaseModel):

    image_config: Optional[ImageConfig] = None
    allowed_file_types: Sequence[FileType] = Field(default_factory=list)
-    allowed_extensions: Sequence[str] = Field(default_factory=list)
-    allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
+    allowed_file_extensions: Sequence[str] = Field(default_factory=list)
+    allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
    number_limits: int = 0
@@ -39,6 +39,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

    )

    retries = 0
+    stream = kwargs.pop("stream", False)
    while retries <= max_retries:
        try:
            if dify_config.SSRF_PROXY_ALL_URL:

@@ -52,6 +53,8 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

            response = client.request(method=method, url=url, **kwargs)

            if response.status_code not in STATUS_FORCELIST:
+                if stream:
+                    return response.iter_bytes()
                return response
            else:
                logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list")
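The effect of the two additions: callers may now pass `stream=True`, it is popped from kwargs before reaching httpx, and a successful response comes back as a byte iterator rather than a `Response`. A hypothetical caller sketch (the URL is a placeholder):

```python
# Stream a large remote body through the SSRF proxy instead of buffering a Response.
buf = bytearray()
for chunk in make_request("GET", "https://example.com/big-file.bin", stream=True):
    buf.extend(chunk)  # iter_bytes() yields the body incrementally
```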
@@ -29,6 +29,8 @@ from core.rag.splitter.fixed_text_splitter import (

    FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
+from core.tools.utils.text_processing_utils import remove_leading_symbols
+from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage

@@ -278,6 +280,19 @@ class IndexingRunner:

                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)

+                # delete image files and related db records
+                image_upload_file_ids = get_image_upload_file_ids(document.page_content)
+                for upload_file_id in image_upload_file_ids:
+                    image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
+                    try:
+                        storage.delete(image_file.key)
+                    except Exception:
+                        logging.exception(
+                            "Delete image_files failed while indexing_estimate, \
+                            image_upload_file_is: {}".format(upload_file_id)
+                        )
+                    db.session.delete(image_file)

        if doc_form and doc_form == "qa_model":
            if len(preview_texts) > 0:
                # qa model document

@@ -500,11 +515,7 @@ class IndexingRunner:

                document_node.metadata["doc_hash"] = hash
                # delete Splitter character
                page_content = document_node.page_content
-                if page_content.startswith(".") or page_content.startswith("。"):
-                    page_content = page_content[1:]
-                else:
-                    page_content = page_content
-                document_node.page_content = page_content
+                document_node.page_content = remove_leading_symbols(page_content)

                if document_node.page_content:
                    split_documents.append(document_node)
@@ -1,8 +1,8 @@

+from collections.abc import Sequence
from typing import Optional

from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
-from core.file.models import FileType
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
    AssistantPromptMessage,

@@ -27,7 +27,7 @@ class TokenBufferMemory:

    def get_history_prompt_messages(
        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
-    ) -> list[PromptMessage]:
+    ) -> Sequence[PromptMessage]:
        """
        Get history prompt messages.
        :param max_token_limit: max token limit

@@ -102,12 +102,11 @@ class TokenBufferMemory:

            prompt_message_contents: list[PromptMessageContent] = []
            prompt_message_contents.append(TextPromptMessageContent(data=message.query))
            for file in file_objs:
-                if file.type in {FileType.IMAGE, FileType.AUDIO}:
-                    prompt_message = file_manager.to_prompt_message_content(
-                        file,
-                        image_detail_config=detail,
-                    )
-                    prompt_message_contents.append(prompt_message)
+                prompt_message = file_manager.to_prompt_message_content(
+                    file,
+                    image_detail_config=detail,
+                )
+                prompt_message_contents.append(prompt_message)

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
@@ -100,10 +100,10 @@ class ModelInstance:

    def invoke_llm(
        self,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
        model_parameters: Optional[dict] = None,
        tools: Sequence[PromptMessageTool] | None = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,
@@ -1,4 +1,5 @@

from abc import ABC, abstractmethod
+from collections.abc import Sequence
from typing import Optional

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk

@@ -31,7 +32,7 @@ class Callback(ABC):

        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> None:

@@ -60,7 +61,7 @@ class Callback(ABC):

        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ):

@@ -90,7 +91,7 @@ class Callback(ABC):

        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> None:

@@ -120,7 +121,7 @@ class Callback(ABC):

        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> None:
@@ -2,6 +2,7 @@ from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsa

from .message_entities import (
    AssistantPromptMessage,
    AudioPromptMessageContent,
+    DocumentPromptMessageContent,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContent,

@@ -37,4 +38,5 @@ __all__ = [

    "LLMResultChunk",
    "LLMResultChunkDelta",
    "AudioPromptMessageContent",
+    "DocumentPromptMessageContent",
]
@@ -1,6 +1,7 @@

from abc import ABC
+from collections.abc import Sequence
from enum import Enum
-from typing import Optional
+from typing import Literal, Optional

from pydantic import BaseModel, Field, field_validator

@@ -48,7 +49,7 @@ class PromptMessageFunction(BaseModel):

    function: PromptMessageTool


-class PromptMessageContentType(Enum):
+class PromptMessageContentType(str, Enum):
    """
    Enum class for prompt message content type.
    """

@@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):

    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
+    DOCUMENT = "document"


class PromptMessageContent(BaseModel):

@@ -101,13 +103,20 @@ class ImagePromptMessageContent(PromptMessageContent):

    detail: DETAIL = DETAIL.LOW


+class DocumentPromptMessageContent(PromptMessageContent):
+    type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
+    encode_format: Literal["base64"]
+    mime_type: str
+    data: str


class PromptMessage(ABC, BaseModel):
    """
    Model class for prompt message.
    """

    role: PromptMessageRole
-    content: Optional[str | list[PromptMessageContent]] = None
+    content: Optional[str | Sequence[PromptMessageContent]] = None
    name: Optional[str] = None

    def is_empty(self) -> bool:
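Putting the new entity together with the file_manager change above: a minimal sketch of building a document part for a user prompt message. The file path is a placeholder, and the base64 step mirrors what `_get_encoded_string` does internally:

```python
import base64

from core.model_runtime.entities import DocumentPromptMessageContent, UserPromptMessage

# Illustrative only: encode a local PDF the same way the file manager would.
with open("report.pdf", "rb") as fp:
    encoded = base64.b64encode(fp.read()).decode("utf-8")

part = DocumentPromptMessageContent(
    encode_format="base64",        # the only value the Literal type allows
    mime_type="application/pdf",
    data=encoded,
)
message = UserPromptMessage(content=[part])  # content accepts a sequence of parts
```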
@@ -87,6 +87,9 @@ class ModelFeature(Enum):

    AGENT_THOUGHT = "agent-thought"
    VISION = "vision"
    STREAM_TOOL_CALL = "stream-tool-call"
+    DOCUMENT = "document"
+    VIDEO = "video"
+    AUDIO = "audio"


class DefaultParameterName(str, Enum):
@@ -2,7 +2,7 @@ import logging

import re
import time
from abc import abstractmethod
-from collections.abc import Generator, Mapping
+from collections.abc import Generator, Mapping, Sequence
from typing import Optional, Union

from pydantic import ConfigDict

@@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):

        prompt_messages: list[PromptMessage],
        model_parameters: Optional[dict] = None,
        tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,

@@ -169,7 +169,7 @@ class LargeLanguageModel(AIModel):

        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,

@@ -212,7 +212,7 @@ if you are not sure about the structure.

        )

        model_parameters.pop("response_format")
-        stop = stop or []
+        stop = list(stop) if stop is not None else []
        stop.extend(["\n```", "```\n"])
        block_prompts = block_prompts.replace("{{block}}", code_block)

@@ -408,7 +408,7 @@ if you are not sure about the structure.

        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,

@@ -479,7 +479,7 @@ if you are not sure about the structure.

        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:

@@ -601,7 +601,7 @@ if you are not sure about the structure.

        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,

@@ -647,7 +647,7 @@ if you are not sure about the structure.

        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,

@@ -694,7 +694,7 @@ if you are not sure about the structure.

        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,

@@ -742,7 +742,7 @@ if you are not sure about the structure.

        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,
@@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 200000

@@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 200000
@@ -1,7 +1,7 @@
import base64
import io
import json
from collections.abc import Generator
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast

import anthropic
@@ -21,9 +21,9 @@ from httpx import Timeout
from PIL import Image

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities import (
AssistantPromptMessage,
DocumentPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
@@ -33,6 +33,7 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@@ -86,10 +87,10 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -130,9 +131,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
# Add the new header for claude-3-5-sonnet-20240620 model
extra_headers = {}
if model == "claude-3-5-sonnet-20240620":
if model_parameters.get("max_tokens") > 4096:
if model_parameters.get("max_tokens", 0) > 4096:
extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"

if any(
isinstance(content, DocumentPromptMessageContent)
for prompt_message in prompt_messages
if isinstance(prompt_message.content, list)
for content in prompt_message.content
):
extra_headers["anthropic-beta"] = "pdfs-2024-09-25"

if tools:
extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools]
response = client.beta.tools.messages.create(
@@ -325,14 +334,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
assistant_prompt_message.tool_calls.append(tool_call)

# calculate num tokens
if response.usage:
# transform usage
prompt_tokens = response.usage.input_tokens
completion_tokens = response.usage.output_tokens
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
prompt_tokens = (response.usage and response.usage.input_tokens) or self.get_num_tokens(
model, credentials, prompt_messages
)

completion_tokens = (response.usage and response.usage.output_tokens) or self.get_num_tokens(
model, credentials, [assistant_prompt_message]
)

# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -505,6 +513,21 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
"source": {"type": "base64", "media_type": mime_type, "data": base64_data},
}
sub_messages.append(sub_message_dict)
elif isinstance(message_content, DocumentPromptMessageContent):
if message_content.mime_type != "application/pdf":
raise ValueError(
f"Unsupported document type {message_content.mime_type}, "
"only support application/pdf"
)
sub_message_dict = {
"type": "document",
"source": {
"type": message_content.encode_format,
"media_type": message_content.mime_type,
"data": message_content.data,
},
}
sub_messages.append(sub_message_dict)
prompt_message_dicts.append({"role": "user", "content": sub_messages})
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
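
As the Anthropic hunks above show, the provider now opts into the pdfs-2024-09-25 beta whenever any prompt message carries a DocumentPromptMessageContent, and each PDF part is serialized into Anthropic's document block. A sketch of the resulting sub-message shape, with illustrative values:

    # Illustrative payload built from a DocumentPromptMessageContent part:
    # "type"/"media_type"/"data" come from encode_format, mime_type, and data.
    sub_message_dict = {
        "type": "document",
        "source": {
            "type": "base64",
            "media_type": "application/pdf",  # anything else raises ValueError
            "data": "<base64-encoded PDF bytes>",
        },
    }

Only application/pdf is accepted on this path; other mime types raise before the request is sent.
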
@@ -2,13 +2,11 @@
import base64
import json
import logging
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast

# 3rd import
import boto3
import requests
from botocore.config import Config
from botocore.exceptions import (
ClientError,
@@ -439,22 +437,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
url = message_content.data
image_content = requests.get(url).content
if "?" in url:
url = url.split("?")[0]
mime_type, _ = mimetypes.guess_type(url)
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
image_content = base64.b64decode(base64_data)
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
image_content = base64.b64decode(base64_data)

if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(

@@ -691,8 +691,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
base_model_schema = cast(AIModelEntity, base_model_schema)

base_model_schema_features = base_model_schema.features or []
base_model_schema_model_properties = base_model_schema.model_properties or {}
base_model_schema_parameters_rules = base_model_schema.parameter_rules or []
base_model_schema_model_properties = base_model_schema.model_properties
base_model_schema_parameters_rules = base_model_schema.parameter_rules

entity = AIModelEntity(
model=model,
@@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
@@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

@@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
@@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

@@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
@@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

@@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
@@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

@@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
@@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

@@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
@@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

@@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
@@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

@@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 2097152
@@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

@@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 2097152
@@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

@@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 2097152
@@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

@@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 2097152
@@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

@@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 2097152
@@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

@@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 2097152
@@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

@@ -7,9 +7,10 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 2097152
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
@@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
@@ -0,0 +1,38 @@
model: gemini-exp-1121
label:
en_US: Gemini exp 1121
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

@@ -32,3 +32,4 @@ pricing:
output: '0.00'
unit: '0.000001'
currency: USD
deprecated: true

@@ -36,3 +36,4 @@ pricing:
output: '0.00'
unit: '0.000001'
currency: USD
deprecated: true

@@ -0,0 +1,38 @@
model: learnlm-1.5-pro-experimental
label:
en_US: LearnLM 1.5 Pro Experimental
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD
@@ -1,7 +1,6 @@
import base64
import io
import json
import logging
from collections.abc import Generator
from typing import Optional, Union, cast

@@ -17,6 +16,7 @@ from PIL import Image
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
DocumentPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
@@ -36,16 +36,20 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)

GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
""" # noqa: E501
GOOGLE_AVAILABLE_MIMETYPE = [
"application/pdf",
"application/x-javascript",
"text/javascript",
"application/x-python",
"text/x-python",
"text/plain",
"text/html",
"text/css",
"text/md",
"text/csv",
"text/xml",
"text/rtf",
]


class GoogleLargeLanguageModel(LargeLanguageModel):
@@ -155,7 +159,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):

try:
ping_message = SystemPromptMessage(content="ping")
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})

except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@@ -184,7 +188,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
config_kwargs = model_parameters.copy()
config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
if schema := config_kwargs.pop("json_schema", None):
try:
schema = json.loads(schema)
except:
raise exceptions.InvalidArgument("Invalid JSON Schema")
if tools:
raise exceptions.InvalidArgument("gemini not support use Tools and JSON Schema at same time")
config_kwargs["response_schema"] = schema
config_kwargs["response_mime_type"] = "application/json"

if stop:
config_kwargs["stop_sequences"] = stop
@@ -374,6 +386,12 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
glm_content["parts"].append(blob)
elif c.type == PromptMessageContentType.DOCUMENT:
message_content = cast(DocumentPromptMessageContent, c)
if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
raise ValueError(f"Unsupported mime type {message_content.mime_type}")
blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
glm_content["parts"].append(blob)

return glm_content
elif isinstance(message, AssistantPromptMessage):
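
For Gemini, the hunks above reuse the image inline_data shape for documents, gated by GOOGLE_AVAILABLE_MIMETYPE. Roughly, with illustrative values:

    # Illustrative: a DOCUMENT part becomes an inline_data blob appended to
    # the parts list of the message being converted for the Gemini API.
    blob = {"inline_data": {"mime_type": "application/pdf", "data": "<base64 data>"}}
    glm_content["parts"].append(blob)

Unlike the Anthropic path, any mime type in GOOGLE_AVAILABLE_MIMETYPE (PDF, plain text, CSV, HTML, and so on) is allowed.
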
@ -22,6 +22,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
@ -86,6 +87,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
@ -153,6 +155,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
@ -196,6 +199,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
if completion_type is LLMMode.CHAT:
|
||||
endpoint_url = urljoin(endpoint_url, "api/chat")
|
||||
data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
|
||||
if tools:
|
||||
data["tools"] = [self._convert_prompt_message_tool_to_dict(tool) for tool in tools]
|
||||
else:
|
||||
endpoint_url = urljoin(endpoint_url, "api/generate")
|
||||
first_prompt_message = prompt_messages[0]
|
||||
@ -232,7 +237,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
if stream:
|
||||
return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
|
||||
|
||||
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
|
||||
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages, tools)
|
||||
|
||||
def _handle_generate_response(
|
||||
self,
|
||||
@ -241,6 +246,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
completion_type: LLMMode,
|
||||
response: requests.Response,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm completion response
|
||||
@ -253,14 +259,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
:return: llm result
|
||||
"""
|
||||
response_json = response.json()
|
||||
|
||||
tool_calls = []
|
||||
if completion_type is LLMMode.CHAT:
|
||||
message = response_json.get("message", {})
|
||||
response_content = message.get("content", "")
|
||||
response_tool_calls = message.get("tool_calls", [])
|
||||
tool_calls = [self._extract_response_tool_call(tool_call) for tool_call in response_tool_calls]
|
||||
else:
|
||||
response_content = response_json["response"]
|
||||
|
||||
assistant_message = AssistantPromptMessage(content=response_content)
|
||||
assistant_message = AssistantPromptMessage(content=response_content, tool_calls=tool_calls)
|
||||
|
||||
if "prompt_eval_count" in response_json and "eval_count" in response_json:
|
||||
# transform usage
|
||||
@ -405,9 +413,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
chunk_index += 1
|
||||
|
||||
def _convert_prompt_message_tool_to_dict(self, tool: PromptMessageTool) -> dict:
|
||||
"""
|
||||
Convert PromptMessageTool to dict for Ollama API
|
||||
|
||||
:param tool: tool
|
||||
:return: tool dict
|
||||
"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
},
|
||||
}
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict for Ollama API
|
||||
|
||||
:param message: prompt message
|
||||
:return: message dict
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
@ -432,6 +459,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {"role": "tool", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
@ -452,6 +482,29 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _extract_response_tool_call(self, response_tool_call: dict) -> AssistantPromptMessage.ToolCall:
|
||||
"""
|
||||
Extract response tool call
|
||||
"""
|
||||
tool_call = None
|
||||
if response_tool_call and "function" in response_tool_call:
|
||||
# Convert arguments to JSON string if it's a dict
|
||||
arguments = response_tool_call.get("function").get("arguments")
|
||||
if isinstance(arguments, dict):
|
||||
arguments = json.dumps(arguments)
|
||||
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.get("function").get("name"),
|
||||
arguments=arguments,
|
||||
)
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call.get("function").get("name"),
|
||||
type="function",
|
||||
function=function,
|
||||
)
|
||||
|
||||
return tool_call
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
Get customizable model schema.
|
||||
@ -461,10 +514,15 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
:return: model schema
|
||||
"""
|
||||
extras = {}
|
||||
extras = {
|
||||
"features": [],
|
||||
}
|
||||
|
||||
if "vision_support" in credentials and credentials["vision_support"] == "true":
|
||||
extras["features"] = [ModelFeature.VISION]
|
||||
extras["features"].append(ModelFeature.VISION)
|
||||
if "function_call_support" in credentials and credentials["function_call_support"] == "true":
|
||||
extras["features"].append(ModelFeature.TOOL_CALL)
|
||||
extras["features"].append(ModelFeature.MULTI_TOOL_CALL)
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
|
||||
@ -96,3 +96,22 @@ model_credential_schema:
|
||||
label:
|
||||
en_US: 'No'
|
||||
zh_Hans: 否
|
||||
- variable: function_call_support
|
||||
label:
|
||||
zh_Hans: 是否支持函数调用
|
||||
en_US: Function call support
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
default: 'false'
|
||||
type: radio
|
||||
required: false
|
||||
options:
|
||||
- value: 'true'
|
||||
label:
|
||||
en_US: 'Yes'
|
||||
zh_Hans: 是
|
||||
- value: 'false'
|
||||
label:
|
||||
en_US: 'No'
|
||||
zh_Hans: 否
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
- gpt-4o
|
||||
- gpt-4o-2024-05-13
|
||||
- gpt-4o-2024-08-06
|
||||
- gpt-4o-2024-11-20
|
||||
- chatgpt-4o-latest
|
||||
- gpt-4o-mini
|
||||
- gpt-4o-mini-2024-07-18
|
||||
|
||||
@ -0,0 +1,47 @@
|
||||
model: gpt-4o-2024-11-20
|
||||
label:
|
||||
zh_Hans: gpt-4o-2024-11-20
|
||||
en_US: gpt-4o-2024-11-20
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 16384
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
- json_schema
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '2.50'
|
||||
output: '10.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -7,7 +7,7 @@ features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
|
||||
@ -615,19 +615,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
|
||||
|
||||
# o1 compatibility
|
||||
block_as_stream = False
|
||||
if model.startswith("o1"):
|
||||
if "max_tokens" in model_parameters:
|
||||
model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
|
||||
del model_parameters["max_tokens"]
|
||||
|
||||
if stream:
|
||||
block_as_stream = True
|
||||
stream = False
|
||||
|
||||
if "stream_options" in extra_model_kwargs:
|
||||
del extra_model_kwargs["stream_options"]
|
||||
|
||||
if "stop" in extra_model_kwargs:
|
||||
del extra_model_kwargs["stop"]
|
||||
|
||||
@ -644,47 +636,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
|
||||
|
||||
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
|
||||
|
||||
if block_as_stream:
|
||||
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
|
||||
|
||||
return block_result
|
||||
|
||||
def _handle_chat_block_as_stream_response(
|
||||
self,
|
||||
block_result: LLMResult,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Handle llm chat response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
text = block_result.message.content
|
||||
text = cast(str, text)
|
||||
|
||||
if stop:
|
||||
text = self.enforce_stop_tokens(text, stop)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=block_result.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=block_result.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
finish_reason="stop",
|
||||
usage=block_result.usage,
|
||||
),
|
||||
)
|
||||
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
|
||||
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
@ -1178,8 +1130,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
base_model_schema = model_map[base_model]
|
||||
|
||||
base_model_schema_features = base_model_schema.features or []
|
||||
base_model_schema_model_properties = base_model_schema.model_properties or {}
|
||||
base_model_schema_parameters_rules = base_model_schema.parameter_rules or []
|
||||
base_model_schema_model_properties = base_model_schema.model_properties
|
||||
base_model_schema_parameters_rules = base_model_schema.parameter_rules
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
|
||||
@ -64,7 +64,7 @@ class OAICompatRerankModel(RerankModel):
|
||||
|
||||
# TODO: Do we need truncate docs to avoid llama.cpp return error?
|
||||
|
||||
data = {"model": model_name, "query": query, "documents": docs, "top_n": top_n}
|
||||
data = {"model": model_name, "query": query, "documents": docs, "top_n": top_n, "return_documents": True}
|
||||
|
||||
try:
|
||||
response = post(str(URL(url) / "rerank"), headers=headers, data=dumps(data), timeout=60)
|
||||
@ -83,7 +83,13 @@ class OAICompatRerankModel(RerankModel):
|
||||
index = result["index"]
|
||||
|
||||
# Retrieve document text (fallback if llama.cpp rerank doesn't return it)
|
||||
text = result.get("document", {}).get("text", docs[index])
|
||||
text = docs[index]
|
||||
document = result.get("document", {})
|
||||
if document:
|
||||
if isinstance(document, dict):
|
||||
text = document.get("text", docs[index])
|
||||
elif isinstance(document, str):
|
||||
text = document
|
||||
|
||||
# Normalize the score
|
||||
normalized_score = (result["relevance_score"] - min_score) / score_range
|
||||
|
||||
@ -37,13 +37,14 @@ class OpenLLMGenerateMessage:
|
||||
class OpenLLMGenerate:
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
server_url: str,
|
||||
model_name: str,
|
||||
stream: bool,
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str],
|
||||
stop: list[str] | None = None,
|
||||
prompt_messages: list[OpenLLMGenerateMessage],
|
||||
user: str,
|
||||
user: str | None = None,
|
||||
) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]:
|
||||
if not server_url:
|
||||
raise InvalidAuthenticationError("Invalid server URL")
|
||||
|
||||
@ -45,19 +45,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._update_credential(model, credentials)
|
||||
|
||||
block_as_stream = False
|
||||
if model.startswith("openai/o1"):
|
||||
block_as_stream = True
|
||||
stop = None
|
||||
|
||||
# invoke block as stream
|
||||
if stream and block_as_stream:
|
||||
return self._generate_block_as_stream(
|
||||
model, credentials, prompt_messages, model_parameters, tools, stop, user
|
||||
)
|
||||
else:
|
||||
return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def _generate_block_as_stream(
|
||||
self,
|
||||
@ -69,9 +57,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
stop: Optional[list[str]] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Generator:
|
||||
resp: LLMResult = super()._generate(
|
||||
model, credentials, prompt_messages, model_parameters, tools, stop, False, user
|
||||
)
|
||||
resp = super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, False, user)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
|
||||
@ -6,6 +6,7 @@ model_type: llm
|
||||
features:
|
||||
- vision
|
||||
- agent-thought
|
||||
- video
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
|
||||
@ -6,6 +6,7 @@ model_type: llm
|
||||
features:
|
||||
- vision
|
||||
- agent-thought
|
||||
- video
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
|
||||
@ -6,6 +6,7 @@ model_type: llm
|
||||
features:
|
||||
- vision
|
||||
- agent-thought
|
||||
- video
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
|
||||
@ -6,6 +6,7 @@ model_type: llm
|
||||
features:
|
||||
- vision
|
||||
- agent-thought
|
||||
- video
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8000
|
||||
|
||||
@ -65,6 +65,8 @@ class GTERerankModel(RerankModel):
|
||||
)
|
||||
|
||||
rerank_documents = []
|
||||
if not response.output:
|
||||
return RerankResult(model=model, docs=rerank_documents)
|
||||
for _, result in enumerate(response.output.results):
|
||||
# format document
|
||||
rerank_document = RerankDocument(
|
||||
|
||||
@ -22,7 +22,7 @@ def get_model_config(credentials: dict) -> ModelConfig:
|
||||
return ModelConfig(
|
||||
properties=ModelProperties(
|
||||
context_size=int(credentials.get("context_size", 0)),
|
||||
max_chunks=int(credentials.get("max_chunks", 0)),
|
||||
max_chunks=int(credentials.get("max_chunks", 1)),
|
||||
)
|
||||
)
|
||||
return model_configs
|
||||
|
||||
@ -6,6 +6,7 @@ model_properties:
|
||||
mode: chat
|
||||
features:
|
||||
- vision
|
||||
- video
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
|
||||
|
||||
@ -62,5 +65,5 @@ class KeywordsModeration(Moderation):
|
||||
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
|
||||
return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
|
||||
|
||||
def _check_keywords_in_value(self, keywords_list, value) -> bool:
|
||||
return any(keyword.lower() in value.lower() for keyword in keywords_list)
|
||||
def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:
|
||||
return any(keyword.lower() in str(value).lower() for keyword in keywords_list)
|
||||
|
||||
@ -49,6 +49,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run")
|
||||
input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
|
||||
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
|
||||
dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
|
||||
@ -25,7 +25,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
||||
LangSmithRunType,
|
||||
LangSmithRunUpdateModel,
|
||||
)
|
||||
from core.ops.utils import filter_none_values
|
||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
from models.workflow import WorkflowNodeExecution
|
||||
@ -62,6 +62,16 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_id = trace_info.message_id or trace_info.workflow_app_log_id or trace_info.workflow_run_id
|
||||
message_dotted_order = (
|
||||
generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
|
||||
)
|
||||
workflow_dotted_order = generate_dotted_order(
|
||||
trace_info.workflow_app_log_id or trace_info.workflow_run_id,
|
||||
trace_info.workflow_data.created_at,
|
||||
message_dotted_order,
|
||||
)
|
||||
|
||||
if trace_info.message_id:
|
||||
message_run = LangSmithRunModel(
|
||||
id=trace_info.message_id,
|
||||
@ -76,6 +86,8 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
},
|
||||
tags=["message", "workflow"],
|
||||
error=trace_info.error,
|
||||
trace_id=trace_id,
|
||||
dotted_order=message_dotted_order,
|
||||
)
|
||||
self.add_run(message_run)
|
||||
|
||||
@ -95,6 +107,8 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
error=trace_info.error,
|
||||
tags=["workflow"],
|
||||
parent_run_id=trace_info.message_id or None,
|
||||
trace_id=trace_id,
|
||||
dotted_order=workflow_dotted_order,
|
||||
)
|
||||
|
||||
self.add_run(langsmith_run)
|
||||
@ -177,6 +191,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
else:
|
||||
run_type = LangSmithRunType.tool
|
||||
|
||||
node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order)
|
||||
langsmith_run = LangSmithRunModel(
|
||||
total_tokens=node_total_tokens,
|
||||
name=node_type,
|
||||
@ -191,6 +206,9 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
},
|
||||
parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
|
||||
tags=["node_execution"],
|
||||
id=node_execution_id,
|
||||
trace_id=trace_id,
|
||||
dotted_order=node_dotted_order,
|
||||
)
|
||||
|
||||
self.add_run(langsmith_run)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
@ -43,3 +44,19 @@ def replace_text_with_content(data):
|
||||
return [replace_text_with_content(item) for item in data]
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def generate_dotted_order(
|
||||
run_id: str, start_time: Union[str, datetime], parent_dotted_order: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
generate dotted_order for langsmith
|
||||
"""
|
||||
start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
|
||||
timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"
|
||||
current_segment = f"{timestamp}{run_id}"
|
||||
|
||||
if parent_dotted_order is None:
|
||||
return current_segment
|
||||
|
||||
return f"{parent_dotted_order}.{current_segment}"
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from core.model_runtime.entities import (
|
||||
@ -14,7 +15,7 @@ from core.prompt.simple_prompt_transform import ModelMode
|
||||
|
||||
class PromptMessageUtil:
|
||||
@staticmethod
|
||||
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]:
|
||||
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]:
|
||||
"""
|
||||
Prompt messages to prompt for saving.
|
||||
:param model_mode: model mode
|
||||
|
||||
@ -12,7 +12,7 @@ class CleanProcessor:
|
||||
# Unicode U+FFFE
|
||||
text = re.sub("\ufffe", "", text)
|
||||
|
||||
rules = process_rule["rules"] if process_rule else None
|
||||
rules = process_rule["rules"] if process_rule else {}
|
||||
if "pre_processing_rules" in rules:
|
||||
pre_processing_rules = rules["pre_processing_rules"]
|
||||
for pre_processing_rule in pre_processing_rules:
|
||||
|
||||
@ -1,310 +1,62 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
_import_err_msg = (
|
||||
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
|
||||
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
|
||||
)
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
|
||||
AnalyticdbVectorOpenAPI,
|
||||
AnalyticdbVectorOpenAPIConfig,
|
||||
)
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class AnalyticdbConfig(BaseModel):
|
||||
access_key_id: str
|
||||
access_key_secret: str
|
||||
region_id: str
|
||||
instance_id: str
|
||||
account: str
|
||||
account_password: str
|
||||
namespace: str = ("dify",)
|
||||
namespace_password: str = (None,)
|
||||
metrics: str = ("cosine",)
|
||||
read_timeout: int = 60000
|
||||
|
||||
def to_analyticdb_client_params(self):
|
||||
return {
|
||||
"access_key_id": self.access_key_id,
|
||||
"access_key_secret": self.access_key_secret,
|
||||
"region_id": self.region_id,
|
||||
"read_timeout": self.read_timeout,
|
||||
}
|
||||
|
||||
|
||||
class AnalyticdbVector(BaseVector):
|
||||
def __init__(self, collection_name: str, config: AnalyticdbConfig):
|
||||
self._collection_name = collection_name.lower()
|
||||
try:
|
||||
from alibabacloud_gpdb20160503.client import Client
|
||||
from alibabacloud_tea_openapi import models as open_api_models
|
||||
except:
|
||||
raise ImportError(_import_err_msg)
|
||||
self.config = config
|
||||
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
|
||||
self._client = Client(self._client_config)
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
cache_key = f"vector_indexing_{self.config.instance_id}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self.config.instance_id}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
self._initialize_vector_database()
|
||||
self._create_namespace_if_not_exists()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def _initialize_vector_database(self) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.InitVectorDatabaseRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
)
|
||||
self._client.init_vector_database(request)
|
||||
|
||||
def _create_namespace_if_not_exists(self) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
|
||||
try:
|
||||
request = gpdb_20160503_models.DescribeNamespaceRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
)
|
||||
self._client.describe_namespace(request)
|
||||
except TeaException as e:
|
||||
if e.statusCode == 404:
|
||||
request = gpdb_20160503_models.CreateNamespaceRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
)
|
||||
self._client.create_namespace(request)
|
||||
else:
|
||||
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
|
||||
|
||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
try:
|
||||
request = gpdb_20160503_models.DescribeCollectionRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
)
|
||||
self._client.describe_collection(request)
|
||||
except TeaException as e:
|
||||
if e.statusCode == 404:
|
||||
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
|
||||
full_text_retrieval_fields = "page_content"
|
||||
request = gpdb_20160503_models.CreateCollectionRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
namespace=self.config.namespace,
|
||||
collection=self._collection_name,
|
||||
dimension=embedding_dimension,
|
||||
metrics=self.config.metrics,
|
||||
metadata=metadata,
|
||||
full_text_retrieval_fields=full_text_retrieval_fields,
|
||||
)
|
||||
self._client.create_collection(request)
|
||||
else:
|
||||
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
def __init__(
|
||||
self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
|
||||
):
|
||||
super().__init__(collection_name)
|
||||
if api_config is not None:
|
||||
self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
|
||||
else:
|
||||
self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.ANALYTICDB
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection_if_not_exists(dimension)
|
||||
self.add_texts(texts, embeddings)
|
||||
self.analyticdb_vector._create_collection_if_not_exists(dimension)
|
||||
self.analyticdb_vector.add_texts(texts, embeddings)
|
||||
|
||||
    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
        for doc, embedding in zip(documents, embeddings, strict=True):
            metadata = {
                "ref_doc_id": doc.metadata["doc_id"],
                "page_content": doc.page_content,
                "metadata_": json.dumps(doc.metadata),
            }
            rows.append(
                gpdb_20160503_models.UpsertCollectionDataRequestRows(
                    vector=embedding,
                    metadata=metadata,
                )
            )
        request = gpdb_20160503_models.UpsertCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            rows=rows,
        )
        self._client.upsert_collection_data(request)
    def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self.analyticdb_vector.add_texts(texts, embeddings)

    def text_exists(self, id: str) -> bool:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            metrics=self.config.metrics,
            include_values=True,
            vector=None,
            content=None,
            top_k=1,
            filter=f"ref_doc_id='{id}'",
        )
        response = self._client.query_collection_data(request)
        return len(response.body.matches.match) > 0
        return self.analyticdb_vector.text_exists(id)
    def delete_by_ids(self, ids: list[str]) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        ids_str = ",".join(f"'{id}'" for id in ids)
        ids_str = f"({ids_str})"
        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"ref_doc_id IN {ids_str}",
        )
        self._client.delete_collection_data(request)
        self.analyticdb_vector.delete_by_ids(ids)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
        )
        self._client.delete_collection_data(request)
        self.analyticdb_vector.delete_by_metadata_field(key, value)
    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        score_threshold = kwargs.get("score_threshold") or 0.0
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=query_vector,
            content=None,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        return documents
        return self.analyticdb_vector.search_by_vector(query_vector, **kwargs)
    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=None,
            content=query,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    vector=match.metadata.get("vector"),
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        return documents
        return self.analyticdb_vector.search_by_full_text(query, **kwargs)
    def delete(self) -> None:
        try:
            from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

            request = gpdb_20160503_models.DeleteCollectionRequest(
                collection=self._collection_name,
                dbinstance_id=self.config.instance_id,
                namespace=self.config.namespace,
                namespace_password=self.config.namespace_password,
                region_id=self.config.region_id,
            )
            self._client.delete_collection(request)
        except Exception as e:
            raise e
        self.analyticdb_vector.delete()

class AnalyticdbVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector:
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = class_prefix.lower()
@@ -313,26 +65,9 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))

        # handle optional params
        if dify_config.ANALYTICDB_KEY_ID is None:
            raise ValueError("ANALYTICDB_KEY_ID should not be None")
        if dify_config.ANALYTICDB_KEY_SECRET is None:
            raise ValueError("ANALYTICDB_KEY_SECRET should not be None")
        if dify_config.ANALYTICDB_REGION_ID is None:
            raise ValueError("ANALYTICDB_REGION_ID should not be None")
        if dify_config.ANALYTICDB_INSTANCE_ID is None:
            raise ValueError("ANALYTICDB_INSTANCE_ID should not be None")
        if dify_config.ANALYTICDB_ACCOUNT is None:
            raise ValueError("ANALYTICDB_ACCOUNT should not be None")
        if dify_config.ANALYTICDB_PASSWORD is None:
            raise ValueError("ANALYTICDB_PASSWORD should not be None")
        if dify_config.ANALYTICDB_NAMESPACE is None:
            raise ValueError("ANALYTICDB_NAMESPACE should not be None")
        if dify_config.ANALYTICDB_NAMESPACE_PASSWORD is None:
            raise ValueError("ANALYTICDB_NAMESPACE_PASSWORD should not be None")
        return AnalyticdbVector(
            collection_name,
            AnalyticdbConfig(
        if dify_config.ANALYTICDB_HOST is None:
            # implemented through OpenAPI
            apiConfig = AnalyticdbVectorOpenAPIConfig(
                access_key_id=dify_config.ANALYTICDB_KEY_ID,
                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
                region_id=dify_config.ANALYTICDB_REGION_ID,
@@ -341,5 +76,22 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
                account_password=dify_config.ANALYTICDB_PASSWORD,
                namespace=dify_config.ANALYTICDB_NAMESPACE,
                namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
            ),
        )
            sqlConfig = None
        else:
            # implemented through sql
            sqlConfig = AnalyticdbVectorBySqlConfig(
                host=dify_config.ANALYTICDB_HOST,
                port=dify_config.ANALYTICDB_PORT,
                account=dify_config.ANALYTICDB_ACCOUNT,
                account_password=dify_config.ANALYTICDB_PASSWORD,
                min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
                max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
                namespace=dify_config.ANALYTICDB_NAMESPACE,
            )
            apiConfig = None
        return AnalyticdbVector(
            collection_name,
            apiConfig,
            sqlConfig,
        )
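To make the two code paths above concrete, a minimal usage sketch; `dataset`, `documents`, `document_embeddings`, and `query_vector` are assumed to already exist, and the search settings shown are invented:

# Illustration only. With dify_config.ANALYTICDB_HOST unset, init_vector builds an
# AnalyticdbVectorOpenAPIConfig and the facade delegates every call to
# AnalyticdbVectorOpenAPI; with a host configured, it builds an
# AnalyticdbVectorBySqlConfig and delegates to AnalyticdbVectorBySql over psycopg2.
vector = AnalyticdbVectorFactory().init_vector(dataset, [], embeddings)
vector.create(texts=documents, embeddings=document_embeddings)  # routed to the selected backend
hits = vector.search_by_vector(query_vector, top_k=4, score_threshold=0.2)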
@@ -0,0 +1,309 @@
import json
from typing import Any

from pydantic import BaseModel, model_validator

_import_err_msg = (
    "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
    "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)

from core.rag.models.document import Document
from extensions.ext_redis import redis_client


class AnalyticdbVectorOpenAPIConfig(BaseModel):
    access_key_id: str
    access_key_secret: str
    region_id: str
    instance_id: str
    account: str
    account_password: str
    namespace: str = "dify"
    namespace_password: str | None = None
    metrics: str = "cosine"
    read_timeout: int = 60000

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        if not values["access_key_id"]:
            raise ValueError("config ANALYTICDB_KEY_ID is required")
        if not values["access_key_secret"]:
            raise ValueError("config ANALYTICDB_KEY_SECRET is required")
        if not values["region_id"]:
            raise ValueError("config ANALYTICDB_REGION_ID is required")
        if not values["instance_id"]:
            raise ValueError("config ANALYTICDB_INSTANCE_ID is required")
        if not values["account"]:
            raise ValueError("config ANALYTICDB_ACCOUNT is required")
        if not values["account_password"]:
            raise ValueError("config ANALYTICDB_PASSWORD is required")
        if not values["namespace_password"]:
            raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
        return values

    def to_analyticdb_client_params(self):
        return {
            "access_key_id": self.access_key_id,
            "access_key_secret": self.access_key_secret,
            "region_id": self.region_id,
            "read_timeout": self.read_timeout,
        }


class AnalyticdbVectorOpenAPI:
    def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
        try:
            from alibabacloud_gpdb20160503.client import Client
            from alibabacloud_tea_openapi import models as open_api_models
        except ImportError:
            raise ImportError(_import_err_msg)
        self._collection_name = collection_name.lower()
        self.config = config
        self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
        self._client = Client(self._client_config)
        self._initialize()

    def _initialize(self) -> None:
        cache_key = f"vector_initialize_{self.config.instance_id}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            database_exist_cache_key = f"vector_initialize_{self.config.instance_id}"
            if redis_client.get(database_exist_cache_key):
                return
            self._initialize_vector_database()
            self._create_namespace_if_not_exists()
            redis_client.set(database_exist_cache_key, 1, ex=3600)

    def _initialize_vector_database(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.InitVectorDatabaseRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            manager_account=self.config.account,
            manager_account_password=self.config.account_password,
        )
        self._client.init_vector_database(request)

    def _create_namespace_if_not_exists(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException

        try:
            request = gpdb_20160503_models.DescribeNamespaceRequest(
                dbinstance_id=self.config.instance_id,
                region_id=self.config.region_id,
                namespace=self.config.namespace,
                manager_account=self.config.account,
                manager_account_password=self.config.account_password,
            )
            self._client.describe_namespace(request)
        except TeaException as e:
            if e.statusCode == 404:
                request = gpdb_20160503_models.CreateNamespaceRequest(
                    dbinstance_id=self.config.instance_id,
                    region_id=self.config.region_id,
                    manager_account=self.config.account,
                    manager_account_password=self.config.account_password,
                    namespace=self.config.namespace,
                    namespace_password=self.config.namespace_password,
                )
                self._client.create_namespace(request)
            else:
                raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")

    def _create_collection_if_not_exists(self, embedding_dimension: int):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException

        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return
            try:
                request = gpdb_20160503_models.DescribeCollectionRequest(
                    dbinstance_id=self.config.instance_id,
                    region_id=self.config.region_id,
                    namespace=self.config.namespace,
                    namespace_password=self.config.namespace_password,
                    collection=self._collection_name,
                )
                self._client.describe_collection(request)
            except TeaException as e:
                if e.statusCode == 404:
                    metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
                    full_text_retrieval_fields = "page_content"
                    request = gpdb_20160503_models.CreateCollectionRequest(
                        dbinstance_id=self.config.instance_id,
                        region_id=self.config.region_id,
                        manager_account=self.config.account,
                        manager_account_password=self.config.account_password,
                        namespace=self.config.namespace,
                        collection=self._collection_name,
                        dimension=embedding_dimension,
                        metrics=self.config.metrics,
                        metadata=metadata,
                        full_text_retrieval_fields=full_text_retrieval_fields,
                    )
                    self._client.create_collection(request)
                else:
                    raise ValueError(f"failed to create collection {self._collection_name}: {e}")
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
        for doc, embedding in zip(documents, embeddings, strict=True):
            metadata = {
                "ref_doc_id": doc.metadata["doc_id"],
                "page_content": doc.page_content,
                "metadata_": json.dumps(doc.metadata),
            }
            rows.append(
                gpdb_20160503_models.UpsertCollectionDataRequestRows(
                    vector=embedding,
                    metadata=metadata,
                )
            )
        request = gpdb_20160503_models.UpsertCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            rows=rows,
        )
        self._client.upsert_collection_data(request)

    def text_exists(self, id: str) -> bool:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            metrics=self.config.metrics,
            include_values=True,
            vector=None,
            content=None,
            top_k=1,
            filter=f"ref_doc_id='{id}'",
        )
        response = self._client.query_collection_data(request)
        return len(response.body.matches.match) > 0

    def delete_by_ids(self, ids: list[str]) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        ids_str = ",".join(f"'{id}'" for id in ids)
        ids_str = f"({ids_str})"
        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"ref_doc_id IN {ids_str}",
        )
        self._client.delete_collection_data(request)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
        )
        self._client.delete_collection_data(request)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        score_threshold = kwargs.get("score_threshold") or 0.0
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=query_vector,
            content=None,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    vector=match.values.value,
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=None,
            content=query,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    vector=match.values.value,
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        return documents

    def delete(self) -> None:
        try:
            from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

            request = gpdb_20160503_models.DeleteCollectionRequest(
                collection=self._collection_name,
                dbinstance_id=self.config.instance_id,
                namespace=self.config.namespace,
                namespace_password=self.config.namespace_password,
                region_id=self.config.region_id,
            )
            self._client.delete_collection(request)
        except Exception as e:
            raise e
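A quick sketch of the row-metadata round trip this backend relies on; the field names come from the code above, while the document values are invented:

import json

doc_metadata = {"doc_id": "d1", "document_id": "doc-42"}   # hypothetical Dify chunk metadata
row = {
    "ref_doc_id": doc_metadata["doc_id"],   # queried via filter="ref_doc_id='...'" in text_exists/delete_by_ids
    "page_content": "hello world",          # echoed back by query_collection_data
    "metadata_": json.dumps(doc_metadata),  # stored in the jsonb field declared at collection creation
}
# search_by_* reverses it: json.loads(match.metadata["metadata_"]) plus a "score" key
assert json.loads(row["metadata_"])["doc_id"] == "d1"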
245 api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py Normal file
@@ -0,0 +1,245 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any

import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator

from core.rag.models.document import Document
from extensions.ext_redis import redis_client


class AnalyticdbVectorBySqlConfig(BaseModel):
    host: str
    port: int
    account: str
    account_password: str
    min_connection: int
    max_connection: int
    namespace: str = "dify"
    metrics: str = "cosine"

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        if not values["host"]:
            raise ValueError("config ANALYTICDB_HOST is required")
        if not values["port"]:
            raise ValueError("config ANALYTICDB_PORT is required")
        if not values["account"]:
            raise ValueError("config ANALYTICDB_ACCOUNT is required")
        if not values["account_password"]:
            raise ValueError("config ANALYTICDB_PASSWORD is required")
        if not values["min_connection"]:
            raise ValueError("config ANALYTICDB_MIN_CONNECTION is required")
        if not values["max_connection"]:
            raise ValueError("config ANALYTICDB_MAX_CONNECTION is required")
        if values["min_connection"] > values["max_connection"]:
            raise ValueError("config ANALYTICDB_MIN_CONNECTION should be less than ANALYTICDB_MAX_CONNECTION")
        return values


class AnalyticdbVectorBySql:
    def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig):
        self._collection_name = collection_name.lower()
        self.databaseName = "knowledgebase"
        self.config = config
        self.table_name = f"{self.config.namespace}.{self._collection_name}"
        self.pool = None
        self._initialize()
        if not self.pool:
            self.pool = self._create_connection_pool()

    def _initialize(self) -> None:
        cache_key = f"vector_initialize_{self.config.host}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            database_exist_cache_key = f"vector_initialize_{self.config.host}"
            if redis_client.get(database_exist_cache_key):
                return
            self._initialize_vector_database()
            redis_client.set(database_exist_cache_key, 1, ex=3600)

    def _create_connection_pool(self):
        return psycopg2.pool.SimpleConnectionPool(
            self.config.min_connection,
            self.config.max_connection,
            host=self.config.host,
            port=self.config.port,
            user=self.config.account,
            password=self.config.account_password,
            database=self.databaseName,
        )

    @contextmanager
    def _get_cursor(self):
        conn = self.pool.getconn()
        cur = conn.cursor()
        try:
            yield cur
        finally:
            cur.close()
            conn.commit()
            self.pool.putconn(conn)

    def _initialize_vector_database(self) -> None:
        conn = psycopg2.connect(
            host=self.config.host,
            port=self.config.port,
            user=self.config.account,
            password=self.config.account_password,
            database="postgres",
        )
        conn.autocommit = True
        cur = conn.cursor()
        try:
            cur.execute(f"CREATE DATABASE {self.databaseName}")
        except Exception as e:
            if "already exists" in str(e):
                return
            raise e
        finally:
            cur.close()
            conn.close()
        self.pool = self._create_connection_pool()
        with self._get_cursor() as cur:
            try:
                cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
                cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
            except Exception as e:
                if "already exists" not in str(e):
                    raise e
            cur.execute(
                "CREATE OR REPLACE FUNCTION "
                "public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) "
                "RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ "
                "SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) "
                "FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) "
                "AS words_only;$function$"
            )
            cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")

    def _create_collection_if_not_exists(self, embedding_dimension: int):
        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return
            with self._get_cursor() as cur:
                cur.execute(
                    f"CREATE TABLE IF NOT EXISTS {self.table_name}("
                    f"id text PRIMARY KEY,"
                    f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
                    f"to_tsvector TSVECTOR"
                    f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
                )
                if embedding_dimension is not None:
                    index_name = f"{self._collection_name}_embedding_idx"
                    cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
                    cur.execute(
                        f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
                        f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
                        f"pq_enable=0, external_storage=0)"
                    )
                    cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        values = []
        id_prefix = str(uuid.uuid4()) + "_"
        sql = f"""
        INSERT INTO {self.table_name}
        (id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
        VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
        """
        for i, doc in enumerate(documents):
            values.append(
                (
                    id_prefix + str(i),
                    doc.metadata.get("doc_id", str(uuid.uuid4())),
                    embeddings[i],
                    doc.page_content,
                    json.dumps(doc.metadata),
                    doc.page_content,
                )
            )
        with self._get_cursor() as cur:
            psycopg2.extras.execute_batch(cur, sql, values)

    def text_exists(self, id: str) -> bool:
        with self._get_cursor() as cur:
            cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,))
            return cur.fetchone() is not None

    def delete_by_ids(self, ids: list[str]) -> None:
        with self._get_cursor() as cur:
            try:
                cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),))
            except Exception as e:
                if "does not exist" not in str(e):
                    raise e

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        with self._get_cursor() as cur:
            try:
                cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value))
            except Exception as e:
                if "does not exist" not in str(e):
                    raise e

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        with self._get_cursor() as cur:
            query_vector_str = json.dumps(query_vector)
            query_vector_str = "{" + query_vector_str[1:-1] + "}"
            cur.execute(
                f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
                f"t.page_content as page_content, t.metadata_ AS metadata_ "
                f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
                f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
                (query_vector_str,),
            )
            documents = []
            for record in cur:
                id, vector, score, page_content, metadata = record
                if score > score_threshold:
                    metadata["score"] = score
                    doc = Document(
                        page_content=page_content,
                        vector=vector,
                        metadata=metadata,
                    )
                    documents.append(doc)
        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
        with self._get_cursor() as cur:
            cur.execute(
                f"""SELECT id, vector, page_content, metadata_,
                ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
                FROM {self.table_name}
                WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
                ORDER BY score DESC
                LIMIT {top_k}""",
                (f"'{query}'", f"'{query}'"),
            )
            documents = []
            for record in cur:
                id, vector, page_content, metadata, score = record
                metadata["score"] = score
                doc = Document(
                    page_content=page_content,
                    vector=vector,
                    metadata=metadata,
                )
                documents.append(doc)
        return documents

    def delete(self) -> None:
        with self._get_cursor() as cur:
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
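To make the full-text helper above concrete, a small pure-Python rendering of its string transformation; the tsvector text below is what PostgreSQL produces for an English sample, and this function is an illustration only, not part of the module:

def to_tsquery_from_text_py(tsvector_text: str) -> str:
    # "'hello':1 'world':2"  ->  "'hello' | 'world'"
    words = tsvector_text.split(" ")
    return " | ".join(w.split(":", 1)[0] for w in words)

assert to_tsquery_from_text_py("'hello':1 'world':2") == "'hello' | 'world'"
# Every lexeme of the query is OR-ed, so a row matches if it contains any term;
# ts_rank(..., 32) then normalizes the rank into (0, 1) for the score column.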
@@ -11,6 +11,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset

@@ -43,11 +44,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
                    document_node.metadata["doc_id"] = doc_id
                    document_node.metadata["doc_hash"] = hash
                    # delete Splitter character
                    page_content = document_node.page_content
                    if page_content.startswith(".") or page_content.startswith("。"):
                        page_content = page_content[1:].strip()
                    else:
                        page_content = page_content
                    page_content = remove_leading_symbols(document_node.page_content).strip()
                    if len(page_content) > 0:
                        document_node.page_content = page_content
                        split_documents.append(document_node)

@@ -18,6 +18,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset

@@ -53,11 +54,7 @@ class QAIndexProcessor(BaseIndexProcessor):
                    document_node.metadata["doc_hash"] = hash
                    # delete Splitter character
                    page_content = document_node.page_content
                    if page_content.startswith(".") or page_content.startswith("。"):
                        page_content = page_content[1:]
                    else:
                        page_content = page_content
                    document_node.page_content = page_content
                    document_node.page_content = remove_leading_symbols(page_content)
                    split_documents.append(document_node)
            all_documents.extend(split_documents)
        for i in range(0, len(all_documents), 10):

@@ -27,11 +27,11 @@ class RerankModelRunner(BaseRerankRunner):
        :return:
        """
        docs = []
        doc_id = set()
        doc_ids = set()
        unique_documents = []
        for document in documents:
            if document.provider == "dify" and document.metadata["doc_id"] not in doc_id:
                doc_id.add(document.metadata["doc_id"])
            if document.provider == "dify" and document.metadata["doc_id"] not in doc_ids:
                doc_ids.add(document.metadata["doc_id"])
                docs.append(document.page_content)
                unique_documents.append(document)
            elif document.provider == "external":

@@ -36,23 +36,20 @@ class WeightRerankRunner(BaseRerankRunner):

        :return:
        """
        docs = []
        doc_id = []
        unique_documents = []
        doc_ids = set()
        for document in documents:
            if document.metadata["doc_id"] not in doc_id:
                doc_id.append(document.metadata["doc_id"])
                docs.append(document.page_content)
            if document.metadata["doc_id"] not in doc_ids:
                doc_ids.add(document.metadata["doc_id"])
                unique_documents.append(document)

        documents = unique_documents

        rerank_documents = []
        query_scores = self._calculate_keyword_score(query, documents)

        query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)

        rerank_documents = []
        for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
            # format document
            score = (
                self.weights.vector_setting.vector_weight * query_vector_score
                + self.weights.keyword_setting.keyword_weight * query_score
@@ -61,7 +58,8 @@ class WeightRerankRunner(BaseRerankRunner):
                continue
            document.metadata["score"] = score
            rerank_documents.append(document)
        rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True)

        rerank_documents.sort(key=lambda x: x.metadata["score"], reverse=True)
        return rerank_documents[:top_n] if top_n else rerank_documents

    def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:

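The hybrid score computed above is a plain weighted sum; a worked one-liner with invented weights and per-document scores, just to pin down the arithmetic:

vector_weight, keyword_weight = 0.7, 0.3        # hypothetical VectorSetting/KeywordSetting weights
query_vector_score, query_score = 0.82, 0.40    # cosine similarity and keyword score for one document
score = vector_weight * query_vector_score + keyword_weight * query_score
assert abs(score - 0.694) < 1e-9                # kept only if it clears score_threshold, then sorted desc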
@@ -66,6 +66,41 @@ class BingSearchTool(BuiltinTool):
                results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}'))

            return results
        elif result_type == "json":
            result = {}
            if search_results:
                result["organic"] = [
                    {
                        "title": item.get("name", ""),
                        "snippet": item.get("snippet", ""),
                        "url": item.get("url", ""),
                        "siteName": item.get("siteName", ""),
                    }
                    for item in search_results
                ]

            if computation and "expression" in computation and "value" in computation:
                result["computation"] = {"expression": computation["expression"], "value": computation["value"]}

            if entities:
                result["entities"] = [
                    {
                        "name": item.get("name", ""),
                        "url": item.get("url", ""),
                        "description": item.get("description", ""),
                    }
                    for item in entities
                ]

            if news:
                result["news"] = [{"name": item.get("name", ""), "url": item.get("url", "")} for item in news]

            if related_searches:
                result["related searches"] = [
                    {"displayText": item.get("displayText", ""), "url": item.get("webSearchUrl", "")}
                    for item in related_searches
                ]

            return self.create_json_message(result)
        else:
            # construct text
            text = ""

@@ -113,9 +113,9 @@ parameters:
      zh_Hans: 结果类型
      pt_BR: result type
    human_description:
      en_US: return a list of links or texts
      zh_Hans: 返回一个连接列表还是纯文本内容
      pt_BR: return a list of links or texts
      en_US: return a list of links, json or texts
      zh_Hans: 返回一个列表,内容是链接、json还是纯文本
      pt_BR: return a list of links, json or texts
    default: text
    options:
      - value: link
@@ -123,6 +123,11 @@ parameters:
          en_US: Link
          zh_Hans: 链接
          pt_BR: Link
      - value: json
        label:
          en_US: JSON
          zh_Hans: JSON
          pt_BR: JSON
      - value: text
        label:
          en_US: Text

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, ClassVar

from duckduckgo_search import DDGS

@@ -11,6 +11,17 @@ class DuckDuckGoVideoSearchTool(BuiltinTool):
    Tool for performing a video search using DuckDuckGo search engine.
    """

    IFRAME_TEMPLATE: ClassVar[str] = """
        <div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden; \
            max-width: 100%; border-radius: 8px;">
            <iframe
                style="position: absolute; top: 0; left: 0; width: 100%; height: 100%;"
                src="{src}"
                frameborder="0"
                allowfullscreen>
            </iframe>
        </div>"""

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
        query_dict = {
            "keywords": tool_parameters.get("query"),
@@ -26,6 +37,9 @@ class DuckDuckGoVideoSearchTool(BuiltinTool):
        # Remove None values to use API defaults
        query_dict = {k: v for k, v in query_dict.items() if v is not None}

        # Get proxy URL from parameters
        proxy_url = tool_parameters.get("proxy_url", "").strip()

        response = DDGS().videos(**query_dict)

        # Create HTML result with embedded iframes
@@ -36,20 +50,21 @@ class DuckDuckGoVideoSearchTool(BuiltinTool):
            title = res.get("title", "")
            embed_html = res.get("embed_html", "")
            description = res.get("description", "")
            content_url = res.get("content", "")

            # Modify iframe to be responsive
            if embed_html:
                # Replace fixed dimensions with responsive wrapper and iframe
                embed_html = """
                <div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden; \
                    max-width: 100%; border-radius: 8px;">
                    <iframe
                        style="position: absolute; top: 0; left: 0; width: 100%; height: 100%;"
                        src="{src}"
                        frameborder="0"
                        allowfullscreen>
                    </iframe>
                </div>""".format(src=res.get("embed_url", ""))
            # Handle TED.com videos
            if not embed_html and "ted.com/talks" in content_url:
                embed_url = content_url.replace("www.ted.com", "embed.ted.com")
                if proxy_url:
                    embed_url = f"{proxy_url}{embed_url}"
                embed_html = self.IFRAME_TEMPLATE.format(src=embed_url)

            # Original YouTube/other platform handling
            elif embed_html:
                embed_url = res.get("embed_url", "")
                if proxy_url and embed_url:
                    embed_url = f"{proxy_url}{embed_url}"
                embed_html = self.IFRAME_TEMPLATE.format(src=embed_url)

            markdown_result += f"{title}\n\n"
            markdown_result += f"{embed_html}\n\n"

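A quick illustration of the TED rewrite path added above; the URLs and the proxy value are invented:

content_url = "https://www.ted.com/talks/sample_talk"             # hypothetical DDG result URL
embed_url = content_url.replace("www.ted.com", "embed.ted.com")   # TED's embeddable host
proxy_url = "https://proxy.example.com/"                          # optional, hypothetical
assert f"{proxy_url}{embed_url}" == "https://proxy.example.com/https://embed.ted.com/talks/sample_talk"
# The result is then dropped into IFRAME_TEMPLATE via .format(src=...)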
@@ -1,40 +1,43 @@
identity:
  name: ddgo_video
  author: Assistant
  author: Tao Wang
  label:
    en_US: DuckDuckGo Video Search
    zh_Hans: DuckDuckGo 视频搜索
description:
  human:
    en_US: Perform video searches on DuckDuckGo and get results with embedded videos.
    zh_Hans: 在 DuckDuckGo 上进行视频搜索并获取可嵌入的视频结果。
  llm: Perform video searches on DuckDuckGo and get results with embedded videos.
    en_US: Search and embedded videos.
    zh_Hans: 搜索并嵌入视频
  llm: Search videos on duckduckgo and embed videos in iframe
parameters:
  - name: query
    type: string
    required: true
    label:
      en_US: Query String
      zh_Hans: 查询语句
    type: string
    required: true
    human_description:
      en_US: Search Query
      zh_Hans: 搜索查询语句。
      zh_Hans: 搜索查询语句
    llm_description: Key words for searching
    form: llm
  - name: max_results
    label:
      en_US: Max Results
      zh_Hans: 最大结果数量
    type: number
    required: true
    default: 3
    minimum: 1
    maximum: 10
    label:
      en_US: Max Results
      zh_Hans: 最大结果数量
    human_description:
      en_US: The max results (1-10).
      zh_Hans: 最大结果数量(1-10)。
      en_US: The max results (1-10)
      zh_Hans: 最大结果数量(1-10)
    form: form
  - name: timelimit
    label:
      en_US: Result Time Limit
      zh_Hans: 结果时间限制
    type: select
    required: false
    options:
@@ -54,14 +57,14 @@ parameters:
        label:
          en_US: Current Year
          zh_Hans: 今年
    label:
      en_US: Result Time Limit
      zh_Hans: 结果时间限制
    human_description:
      en_US: Use when querying results within a specific time range only.
      en_US: Query results within a specific time range only
      zh_Hans: 只查询一定时间范围内的结果时使用
    form: form
  - name: duration
    label:
      en_US: Video Duration
      zh_Hans: 视频时长
    type: select
    required: false
    options:
@@ -77,10 +80,18 @@ parameters:
        label:
          en_US: Long (>20 minutes)
          zh_Hans: 长视频(>20分钟)
    label:
      en_US: Video Duration
      zh_Hans: 视频时长
    human_description:
      en_US: Filter videos by duration
      zh_Hans: 按时长筛选视频
    form: form
  - name: proxy_url
    label:
      en_US: Proxy URL
      zh_Hans: 视频代理地址
    type: string
    required: false
    default: ""
    human_description:
      en_US: Proxy URL
      zh_Hans: 视频代理地址
    form: form

@@ -17,7 +17,7 @@ class SendMailTool(BuiltinTool):
        invoke tools
        """
        sender = self.runtime.credentials.get("email_account", "")
        email_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
        email_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
        password = self.runtime.credentials.get("email_password", "")
        smtp_server = self.runtime.credentials.get("smtp_server", "")
        if not smtp_server:

@@ -18,7 +18,7 @@ class SendMailTool(BuiltinTool):
        invoke tools
        """
        sender = self.runtime.credentials.get("email_account", "")
        email_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
        email_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
        password = self.runtime.credentials.get("email_password", "")
        smtp_server = self.runtime.credentials.get("smtp_server", "")
        if not smtp_server:

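The one-character change above (adding `.` to the local-part character class) is what lets dotted addresses through; a self-contained check:

import re

old_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
new_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")

assert old_rgx.match("john.doe@example.com") is None       # rejected before the fix
assert new_rgx.match("john.doe@example.com") is not None   # accepted after the fix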
BIN api/core/tools/provider/builtin/file_extractor/_assets/icon.png Normal file
Binary file not shown. After Width: | Height: | Size: 4.3 KiB
@@ -0,0 +1,8 @@
from typing import Any

from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class FileExtractorProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        pass
@@ -0,0 +1,15 @@
identity:
  author: Jyong
  name: file_extractor
  label:
    en_US: File Extractor
    zh_Hans: 文件提取
    pt_BR: File Extractor
  description:
    en_US: Extract text from file
    zh_Hans: 从文件中提取文本
    pt_BR: Extract text from file
  icon: icon.png
  tags:
    - utilities
    - productivity
@@ -0,0 +1,45 @@
import tempfile
from typing import Any, Union

from core.file.enums import FileType
from core.file.file_manager import download_to_target_path
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.splitter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolParameterValidationError
from core.tools.tool.builtin_tool import BuiltinTool


class FileExtractorTool(BuiltinTool):
    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools
        """
        # text file for workflow mode
        file = tool_parameters.get("text_file")
        if file and file.type != FileType.DOCUMENT:
            raise ToolParameterValidationError("Not a valid document")

        if file:
            with tempfile.TemporaryDirectory() as temp_dir:
                file_path = download_to_target_path(file, temp_dir)
                extractor = TextExtractor(file_path, autodetect_encoding=True)
                documents = extractor.extract()
                character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
                    chunk_size=tool_parameters.get("max_token", 500),
                    chunk_overlap=0,
                    fixed_separator=tool_parameters.get("separator", "\n\n"),
                    separators=["\n\n", "。", ". ", " ", ""],
                    embedding_model_instance=None,
                )
                chunks = character_splitter.split_documents(documents)

                content = "\n".join([chunk.page_content for chunk in chunks])
                return self.create_text_message(content)

        else:
            raise ToolParameterValidationError("Please provide a file")
@@ -0,0 +1,49 @@
identity:
  name: text extractor
  author: Jyong
  label:
    en_US: Text extractor
    zh_Hans: Text 文本解析
  description:
    en_US: Extract content from text file and support split to chunks by split characters and token length
    zh_Hans: 支持从文本文件中提取内容并支持通过分割字符和令牌长度分割成块
    pt_BR: Extract content from text file and support split to chunks by split characters and token length
description:
  human:
    en_US: Text extractor is a text extract tool
    zh_Hans: Text extractor 是一个文本提取工具
    pt_BR: Text extractor is a text extract tool
  llm: Text extractor is a tool used to extract text file
parameters:
  - name: text_file
    type: file
    label:
      en_US: Text file
    human_description:
      en_US: The text file to be extracted.
      zh_Hans: 要提取的 text 文档。
    llm_description: you should not input this parameter. just input the text_file.
    form: llm
  - name: separator
    type: string
    required: false
    label:
      en_US: split character
      zh_Hans: 分隔符号
    human_description:
      en_US: Text content split character
      zh_Hans: 用于文档分隔的符号
    llm_description: it is used for split content to chunks
    form: form
  - name: max_token
    type: number
    required: false
    label:
      en_US: Maximum chunk length
      zh_Hans: 最大分段长度
    human_description:
      en_US: Maximum chunk length
      zh_Hans: 最大分段长度
    llm_description: it is used for limit chunk's max length
    form: form
25 api/core/tools/provider/builtin/gitee_ai/tools/embedding.py Normal file
@@ -0,0 +1,25 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class GiteeAIToolEmbedding(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {self.runtime.credentials['api_key']}",
        }

        payload = {"inputs": tool_parameters.get("inputs")}
        model = tool_parameters.get("model", "bge-m3")
        url = f"https://ai.gitee.com/api/serverless/{model}/embeddings"
        response = requests.post(url, json=payload, headers=headers)
        if response.status_code != 200:
            return self.create_text_message(f"Got Error Response:{response.text}")

        return [self.create_text_message(response.content.decode("utf-8"))]
@@ -0,0 +1,37 @@
identity:
  name: embedding
  author: gitee_ai
  label:
    en_US: embedding
  icon: icon.svg
description:
  human:
    en_US: Generate word embeddings using Serverless-supported models (compatible with OpenAI)
  llm: This tool is used to generate word embeddings from text input.
parameters:
  - name: model
    type: string
    required: true
    in: path
    description:
      en_US: Supported Embedding (compatible with OpenAI) interface models
    enum:
      - bge-m3
      - bge-large-zh-v1.5
      - bge-small-zh-v1.5
    label:
      en_US: Service Model
      zh_Hans: 服务模型
    default: bge-m3
    form: form
  - name: inputs
    type: string
    required: true
    label:
      en_US: Input Text
      zh_Hans: 输入文本
    human_description:
      en_US: The text input used to generate embeddings.
      zh_Hans: 用于生成词向量的输入文本。
    llm_description: This text input will be used to generate embeddings.
    form: llm
@@ -6,7 +6,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class GiteeAITool(BuiltinTool):
class GiteeAIToolText2Image(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:

@@ -40,6 +40,9 @@ class JSONParseTool(BuiltinTool):
            expr = parse(json_filter)
            result = [match.value for match in expr.find(input_data)]

            if not result:
                return ""

            if len(result) == 1:
                result = result[0]

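The new early return above guards the empty-match case; a tiny check of the situation it handles, assuming jsonpath_ng, which the `parse`/`find` calls appear to come from:

from jsonpath_ng import parse  # assumption: the tool's parse() is jsonpath_ng's

expr = parse("$.items[*].name")
result = [m.value for m in expr.find({"items": []})]
assert result == []  # with the guard, the tool now returns "" instead of serializing an empty list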
@@ -12,7 +12,7 @@ class NovitaAiToolBase:
        if not loras_str:
            return []

        loras_ori_list = lora_str.strip().split(";")
        loras_ori_list = loras_str.strip().split(";")
        result_list = []
        for lora_str in loras_ori_list:
            lora_info = lora_str.strip().split(",")

BIN api/core/tools/provider/builtin/rapidapi/_assets/rapidapi.png Normal file
Binary file not shown. After Width: | Height: | Size: 62 KiB
Some files were not shown because too many files have changed in this diff.