Compare commits

73 Commits

Author SHA1 Message Date
a023315f61 add download file method 2024-11-23 16:14:15 +09:00
4dd32eb089 fork for fta 2024-11-23 10:23:57 +09:00
d3051eed48 chore (dep): bump gevent from v23 to v24 for better support for Python 3.11 and 3.12 (#10387) 2024-11-23 00:07:07 +08:00
ed55de888a fix: rules should not be None for in (#10977)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-11-22 23:04:20 +08:00
da601f0bef chore: update base image to Python 3.12 in Dockerfile (#10358) 2024-11-22 19:43:19 +08:00
08ac36812b feat: support LLM process document file (#10966)
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-11-22 19:32:44 +08:00
556de444e8 chore(app_dsl_service): Downgrade DSL Version (#10979) 2024-11-22 16:36:16 +08:00
3750200c5e feat: add a meta(mac) ctrl(windows) key (#10978) 2024-11-22 16:30:34 +08:00
c5f7d650b5 feat: Allow using file variables directly in the LLM node and support more file types. (#10679)
Co-authored-by: Joel <iamjoel007@gmail.com>
2024-11-22 16:30:22 +08:00
535c72cad7 fix(model): make sure AppModelConfig.model_dict returns a dict. (#10972) 2024-11-22 15:48:50 +08:00
8a83edc1b5 Feat: update icon and Divider components (#10975) 2024-11-22 15:44:42 +08:00
5b415a6227 chore: translate i18n files (#10970)
Co-authored-by: laipz8200 <16485841+laipz8200@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-11-22 15:24:11 +08:00
5172f0bf39 feat: Check and compare the DSL version before import an app (#10969)
Co-authored-by: Yi <yxiaoisme@gmail.com>
2024-11-22 15:05:04 +08:00
d9579f418d chore: Added the new gemini exp-1121 and learnlm-1.5 models (#10963) 2024-11-22 13:14:20 +08:00
3579bbd1c4 refactor: Split linear-gradient and color (#10961) 2024-11-22 10:55:42 +08:00
817b85001f feat: slidespeak slides generation (#10955) 2024-11-22 10:30:21 +08:00
e8868a7fb9 feat: add gpt-4o-2024-11-20 (#10951)
Co-authored-by: akubesti <agung.besti@insignia.co.id>
2024-11-22 10:29:20 +08:00
2cd9ac60f1 fix: unstructured io credential environment variables missing (#10953) 2024-11-22 10:15:17 +08:00
464f384cea fix: tiny lora bug found by mypy (#10959)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-11-22 10:01:44 +08:00
8b16f07eb0 feat: add cURL import for http request node (#8656) 2024-11-21 22:25:18 +08:00
fefda40acf fix: fix bugs of frontend-workflow panel operator (#10945)
Co-authored-by: marvin <sea-son@foxmail.com>
2024-11-21 19:07:02 +08:00
8c2f62fb92 Feat: support json output for bing-search (#10904) 2024-11-21 18:32:54 +08:00
1a6b961b5f Resolve 8475 support rerank model from infinity (#10939)
Co-authored-by: linyanxu <linyanxu2@qq.com>
2024-11-21 18:03:49 +08:00
01014a6a84 fix: external dataset missing score_threshold_enabled (#10943) 2024-11-21 18:01:47 +08:00
cb0c55daa7 fix weight rerank of knowledge retrieval (#10931) 2024-11-21 17:53:20 +08:00
82575a7aea fix(gpt-4o-audio-preview): Remove the vision feature (#10932) 2024-11-21 16:42:48 +08:00
80da0c5830 fix: default max_chunks set to 1 as other providers (#10937)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-11-21 16:36:05 +08:00
83b6abf4ad Update parse.py to handle empty list result (#10915)
Co-authored-by: crazywoola <427733928@qq.com>
2024-11-21 14:14:07 +08:00
ea0ebc020c fix: chat history might be empty in log detail view (#10905) 2024-11-21 14:12:01 +08:00
f358db9f02 feat : Add Japanese translations for API documentation: chat, advanced-chat, completion, and workflow (#10927) 2024-11-21 14:02:46 +08:00
94c9cadbd8 fix image files not deleted on indexing_estimate #9541 (#10798)
Co-authored-by: root <root@localhost.localdomain>
2024-11-21 13:03:16 +08:00
2ae6460f46 Add googlenews tools from rapidapi (#10877)
Co-authored-by: steven <sunzwj@digitalchina.com>
2024-11-21 10:39:49 +08:00
0067b16d1e fix: refactor all 'or []' and 'or {}' logic to make code more clear (#10883)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-11-21 10:34:43 +08:00
ec9f6220c9 doc: fix better doc for api develop, droping dead hint (#10906)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-11-21 10:34:23 +08:00
af53e2b6b0 Fix : Add a process to fetch the mime type from the file name for signed url in remote_url #10872 version2 (#10908) 2024-11-20 22:57:49 +08:00
b42b333a72 fix: handle redis authentication for healthcheck command (#10907) 2024-11-20 20:10:51 +08:00
99b0369f1b Gitee AI embedding tool (#10903) 2024-11-20 17:40:34 +08:00
d6ea1e2f12 fix: explicitly use new token when retrying ssePost after refresh (#10864)
Co-authored-by: liusurong.lsr <liusurong.lsr@alibaba-inc.com>
2024-11-20 16:11:33 +08:00
4d6b45427c Support streaming output for OpenAI o1-preview and o1-mini (#10890) 2024-11-20 15:10:41 +08:00
1be8365684 Fix/input-value-type-in-moderation (#10893) 2024-11-20 15:10:12 +08:00
c3d11c8ff6 fix: aws presign url is not workable remote url (#10884)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
2024-11-20 14:24:41 +08:00
8ff65abbc6 ext_redis.py support redis clusters --- Fixes #9538 (#9789)
Signed-off-by: root <root@localhost.localdomain>
Co-authored-by: root <root@localhost.localdomain>
Co-authored-by: Bowen Liang <bowenliang@apache.org>
2024-11-20 13:44:35 +08:00
bf4b6e5f80 feat: support custom tool upload file (#10796) 2024-11-20 13:26:42 +08:00
25fda7adc5 fix(http_request): allow content type application/x-javascript (#10862) 2024-11-20 12:55:06 +08:00
f3af7b5f35 fix: tool's file input display string (#10887) 2024-11-20 12:54:24 +08:00
33cfc56ad0 fix: update email validation regex to allow periods in local part (#10868) 2024-11-20 12:33:02 +08:00
464cc26ccf Fix : Add a process to fetch the mime type from the file name for signed url in remote_url (#10872) 2024-11-20 12:30:25 +08:00
d18754afdd feat: admin can also change member role (#10651) 2024-11-20 11:29:49 +08:00
beb7953d38 feat: enhance the custom note (#8885) 2024-11-20 11:24:45 +08:00
fbfc811a44 feat: support function call for ollama block chat api (#10784) 2024-11-20 11:15:19 +08:00
7e66e5a713 feat: make toc panel can collapse (#10875) 2024-11-20 10:07:30 +08:00
07b5bbae06 feat: add a minimal separator between pinned apps and unpinned apps in the explore page (#10871) 2024-11-20 09:32:59 +08:00
3087913b74 Fix the situation where output_tokens/input_tokens may be None in response.usage (#10728) 2024-11-19 21:19:13 +08:00
904ea05bf6 fix: download some remote files raise error (#10781) 2024-11-19 21:18:53 +08:00
6f4885d86d Encode invitee email in the invitation link (#10842) 2024-11-19 21:08:37 +08:00
Joe
2dc29cfee3 Feat/add langsmith dotted order (#10856) 2024-11-19 21:08:23 +08:00
bd05df5cc5 fix tongyi embedding endpoint return None output (#10857) 2024-11-19 21:04:17 +08:00
ee1f14621a fix httpx doesn't support stream parameter (#10859) 2024-11-19 21:03:01 +08:00
58a9d9eb9a fix: better WeightRerankRunner run logic use O(1) and delete unused code (#10849)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-11-19 20:12:13 +08:00
bc1013dacf feat: support json schema for gemini models (#10835) 2024-11-19 17:49:58 +08:00
9f195df103 Support Video Proxy and TED Embedding (#10819) 2024-11-19 17:49:14 +08:00
1cc7dc6360 style: refactor fetch and context (#10795) 2024-11-19 17:16:06 +08:00
328965ed7c Fix: crash of workflow file upload (#10831)
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2024-11-19 14:15:18 +08:00
133de9a087 fix: upload file component support multiple (#10817) 2024-11-19 14:00:54 +08:00
7261384655 fix: close child modal on log drawer close (#10839) 2024-11-19 12:09:55 +08:00
4718071cbb feat: Knowledge-base-api-get-post-method-text-error-#10836 (#10837) 2024-11-19 12:08:10 +08:00
22be0816aa feat: add TOC to app develop doc (#10799) 2024-11-19 09:06:12 +08:00
49e88322de doc: add clarification for length limit of init password (#10824) 2024-11-19 09:05:05 +08:00
14f3d44c37 refactor: improve handling of leading punctuation removal (#10761) 2024-11-18 21:32:33 +08:00
0ba17ec116 fix: correct typo in ETL type comment in .env.example (#10822) 2024-11-18 20:58:43 +08:00
79d59c004b chore: update .gitignore to include mise.toml (#10778) 2024-11-18 19:35:12 +08:00
873e9720e9 feat: AnalyticDB vector store supports invocation via SQL. (#10802)
Co-authored-by: 璟义 <yangshangpo.ysp@alibaba-inc.com>
2024-11-18 19:29:54 +08:00
de6d3e493c fix: script rendering in message (#10807)
Co-authored-by: crazywoola <427733928@qq.com>
2024-11-18 19:19:10 +08:00
290 changed files with 9469 additions and 2043 deletions

View File

@ -1,6 +1,8 @@
# CONTRIBUTING
So you're looking to contribute to Dify - that's awesome, we can't wait to see what you do. As a startup with limited headcount and funding, we have grand ambitions to design the most intuitive workflow for building and managing LLM applications. Any help from the community counts, truly.
We need to be nimble and ship fast given where we are, but we also want to make sure that contributors like you get as smooth an experience at contributing as possible. We've assembled this contribution guide for that purpose, aiming at getting you familiarized with the codebase & how we work with contributors, so you could quickly jump to the fun part.
We need to be nimble and ship fast given where we are, but we also want to make sure that contributors like you get as smooth an experience at contributing as possible. We've assembled this contribution guide for that purpose, aiming at getting you familiarized with the codebase & how we work with contributors, so you could quickly jump to the fun part.
This guide, like Dify itself, is a constant work in progress. We highly appreciate your understanding if at times it lags behind the actual project, and welcome any feedback for us to improve.
@ -10,14 +12,12 @@ In terms of licensing, please take a minute to read our short [License and Contr
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
### Feature requests:
### Feature requests
* If you're opening a new feature request, we'd like you to explain what the proposed feature achieves, and include as much context as possible. [@perzeusss](https://github.com/perzeuss) has made a solid [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) that helps you draft out your needs. Feel free to give it a try.
* If you want to pick one up from the existing issues, simply drop a comment below it saying so.
A team member working in the related direction will be looped in. If all looks good, they will give the go-ahead for you to start coding. We ask that you hold off working on the feature until then, so none of your work goes to waste should we propose changes.
Depending on which area the proposed feature falls under, you might talk to different team members. Here's a rundown of the areas each of our team members is working on at the moment:
@ -40,7 +40,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
| Non-core features and minor enhancements | Low Priority |
| Valuable but not immediate | Future-Feature |
### Anything else (e.g. bug report, performance optimization, typo correction):
### Anything else (e.g. bug report, performance optimization, typo correction)
* Start coding right away.
@ -52,7 +52,6 @@ In terms of licensing, please take a minute to read our short [License and Contr
| Non-critical bugs, performance boosts | Medium Priority |
| Minor fixes (typos, confusing but working UI) | Low Priority |
## Installing
Here are the steps to set up Dify for development:
@ -63,7 +62,7 @@ Here are the steps to set up Dify for development:
Clone the forked repository from your terminal:
```
```shell
git clone git@github.com:<github_username>/dify.git
```
@ -71,11 +70,11 @@ git clone git@github.com:<github_username>/dify.git
Dify requires the following dependencies to build. Make sure they're installed on your system:
- [Docker](https://www.docker.com/)
- [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) version 3.10.x
* [Docker](https://www.docker.com/)
* [Docker Compose](https://docs.docker.com/compose/install/)
* [Node.js v18.x (LTS)](http://nodejs.org)
* [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
* [Python](https://www.python.org/) version 3.11.x or 3.12.x
### 4. Installations
@ -85,7 +84,7 @@ Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) fo
### 5. Visit dify in your browser
To validate your setup, head over to [http://localhost:3000](http://localhost:3000) (the default, or your self-configured URL and port) in your browser. You should now see Dify up and running.
To validate your setup, head over to [http://localhost:3000](http://localhost:3000) (the default, or your self-configured URL and port) in your browser. You should now see Dify up and running.
## Developing
@ -97,9 +96,9 @@ To help you quickly navigate where your contribution fits, a brief, annotated ou
### Backend
Dify's backend is written in Python using [Flask](https://flask.palletsprojects.com/en/3.0.x/). It uses [SQLAlchemy](https://www.sqlalchemy.org/) for ORM and [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) for task queueing. Authorization logic goes via Flask-login.
Dify's backend is written in Python using [Flask](https://flask.palletsprojects.com/en/3.0.x/). It uses [SQLAlchemy](https://www.sqlalchemy.org/) for ORM and [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) for task queueing. Authorization logic goes via Flask-login.
```
```text
[api/]
├── constants // Constant settings used throughout code base.
├── controllers // API route definitions and request handling logic.
@ -121,7 +120,7 @@ Difys backend is written in Python using [Flask](https://flask.palletsproject
The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typescript and uses [Tailwind CSS](https://tailwindcss.com/) for styling. [React-i18next](https://react.i18next.com/) is used for internationalization.
```
```text
[web/]
├── app // layouts, pages, and components
│ ├── (commonLayout) // common layout used throughout the app
@ -149,10 +148,10 @@ The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typ
## Submitting your PR
At last, time to open a pull request (PR) to our repo. For major features, we first merge them into the `deploy/dev` branch for testing, before they go into the `main` branch. If you run into issues like merge conflicts or don't know how to open a pull request, check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests).
At last, time to open a pull request (PR) to our repo. For major features, we first merge them into the `deploy/dev` branch for testing, before they go into the `main` branch. If you run into issues like merge conflicts or don't know how to open a pull request, check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests).
And that's it! Once your PR is merged, you will be featured as a contributor in our [README](https://github.com/langgenius/dify/blob/main/README.md).
## Getting Help
If you ever get stuck or have a burning question while contributing, simply send your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
If you ever get stuck or have a burning question while contributing, simply send your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.

View File

@ -42,6 +42,11 @@ REDIS_SENTINEL_USERNAME=
REDIS_SENTINEL_PASSWORD=
REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
# redis Cluster configuration.
REDIS_USE_CLUSTERS=false
REDIS_CLUSTERS=
REDIS_CLUSTERS_PASSWORD=
# PostgreSQL database configuration
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
@ -234,6 +239,10 @@ ANALYTICDB_ACCOUNT=testaccount
ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
ANALYTICDB_HOST=gp-test.aliyuncs.com
ANALYTICDB_PORT=5432
ANALYTICDB_MIN_CONNECTION=1
ANALYTICDB_MAX_CONNECTION=5
# OpenSearch configuration
OPENSEARCH_HOST=127.0.0.1

View File

@ -1,5 +1,5 @@
# base image
FROM python:3.10-slim-bookworm AS base
FROM python:3.12-slim-bookworm AS base
WORKDIR /app/api

View File

@ -18,12 +18,17 @@
```
2. Copy `.env.example` to `.env`
```cli
cp .env.example .env
```
3. Generate a `SECRET_KEY` in the `.env` file.
bash for Linux
```bash for Linux
sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
```
bash for Mac
```bash for Mac
secret_key=$(openssl rand -base64 42)
sed -i '' "/^SECRET_KEY=/c\\
@ -37,18 +42,10 @@
5. Install dependencies
```bash
poetry env use 3.10
poetry env use 3.12
poetry install
```
In case of contributors missing to update dependencies for `pyproject.toml`, you can perform the following shell instead.
```bash
poetry shell # activate current environment
poetry add $(cat requirements.txt) # install dependencies of production and update pyproject.toml
poetry add $(cat requirements-dev.txt) --group dev # install dependencies of development and update pyproject.toml
```
6. Run migrate
Before the first launch, migrate the database to the latest version.
@ -84,5 +81,3 @@
```bash
poetry run -C api bash dev/pytest/pytest_all_tests.sh
```

View File

@ -27,7 +27,6 @@ class DifyConfig(
# read from dotenv format config file
env_file=".env",
env_file_encoding="utf-8",
frozen=True,
# ignore extra attributes
extra="ignore",
)

View File

@ -68,3 +68,18 @@ class RedisConfig(BaseSettings):
description="Socket timeout in seconds for Redis Sentinel connections",
default=0.1,
)
REDIS_USE_CLUSTERS: bool = Field(
description="Enable Redis Clusters mode for high availability",
default=False,
)
REDIS_CLUSTERS: Optional[str] = Field(
description="Comma-separated list of Redis Clusters nodes (host:port)",
default=None,
)
REDIS_CLUSTERS_PASSWORD: Optional[str] = Field(
description="Password for Redis Clusters authentication (if required)",
default=None,
)
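
The `REDIS_CLUSTERS` field above is described as a comma-separated list of `host:port` nodes. The diff does not show how the value is consumed, so the following is only a minimal sketch of one way to parse it; the helper name `parse_cluster_nodes` is hypothetical and not part of the change.

```python
def parse_cluster_nodes(raw: str) -> list[tuple[str, int]]:
    """Split a 'host:port,host:port' string into (host, port) pairs.

    Hypothetical helper for illustration; not taken from the diff.
    """
    nodes = []
    for item in raw.split(","):
        item = item.strip()
        if not item:
            # Tolerate trailing commas and empty segments.
            continue
        # Split on the last colon so the port is always the final segment.
        host, sep, port = item.rpartition(":")
        if not sep or not host:
            raise ValueError(f"expected host:port, got {item!r}")
        nodes.append((host, int(port)))
    return nodes
```

The resulting pairs could then be handed to whichever Redis cluster client the extension uses; that wiring is not shown here.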

View File

@ -1,6 +1,6 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PositiveInt
class AnalyticdbConfig(BaseModel):
@ -40,3 +40,11 @@ class AnalyticdbConfig(BaseModel):
description="The password for accessing the specified namespace within the AnalyticDB instance"
" (if namespace feature is enabled).",
)
ANALYTICDB_HOST: Optional[str] = Field(
default=None, description="The host of the AnalyticDB instance you want to connect to."
)
ANALYTICDB_PORT: PositiveInt = Field(
default=5432, description="The port of the AnalyticDB instance you want to connect to."
)
ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1, description="Min connection of the AnalyticDB database.")
ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5, description="Max connection of the AnalyticDB database.")
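
The new AnalyticDB fields rely on pydantic's `PositiveInt` to reject zero or negative ports and pool sizes. As a rough, standard-library-only sketch of the same idea, the class below mirrors the four new settings with their defaults. The class name `AnalyticdbPoolSettings` and the extra `min <= max` check are assumptions added for illustration; the diff itself does not enforce that relationship.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class AnalyticdbPoolSettings:
    """Hypothetical stand-in for the four new ANALYTICDB_* settings."""

    host: Optional[str] = None
    port: int = 5432
    min_connection: int = 1
    max_connection: int = 5

    def __post_init__(self) -> None:
        # Mimic PositiveInt: every numeric field must be an integer > 0.
        for name in ("port", "min_connection", "max_connection"):
            value = getattr(self, name)
            if not isinstance(value, int) or value <= 0:
                raise ValueError(f"{name} must be a positive integer, got {value!r}")
        # Assumption: not in the diff, but a pool cannot hold fewer than its minimum.
        if self.min_connection > self.max_connection:
            raise ValueError("min_connection must not exceed max_connection")
```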

View File

@ -2,6 +2,7 @@ from flask import Blueprint
from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportConfirmApi
from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
@ -17,6 +18,10 @@ api.add_resource(FileSupportTypeApi, "/files/support-type")
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
# Import App
api.add_resource(AppImportApi, "/apps/imports")
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
# Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version
@ -57,6 +62,7 @@ from .datasets import (
external,
hit_testing,
website,
fta_test,
)
# Import explore controllers

View File

@ -1,7 +1,10 @@
import uuid
from typing import cast
from flask_login import current_user
from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort
from controllers.console import api
@ -13,13 +16,15 @@ from controllers.console.wraps import (
setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db
from fields.app_fields import (
app_detail_fields,
app_detail_fields_with_site,
app_pagination_fields,
)
from libs.login import login_required
from services.app_dsl_service import AppDslService
from models import Account, App
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
@ -92,61 +97,6 @@ class AppListApi(Resource):
return app, 201
class AppImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_detail_fields_with_site)
@cloud_edition_billing_resource_check("apps")
def post(self):
"""Import app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
app = AppDslService.import_and_create_new_app(
tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user
)
return app, 201
class AppImportFromUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_detail_fields_with_site)
@cloud_edition_billing_resource_check("apps")
def post(self):
"""Import app from url"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
app = AppDslService.import_and_create_new_app_from_url(
tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user
)
return app, 201
class AppApi(Resource):
@setup_required
@login_required
@ -224,10 +174,24 @@ class AppCopyApi(Resource):
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
app = AppDslService.import_and_create_new_app(
tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user
)
with Session(db.engine) as session:
import_service = AppDslService(session)
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
account = cast(Account, current_user)
result = import_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=yaml_content,
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
)
session.commit()
stmt = select(App).where(App.id == result.app.id)
app = session.scalar(stmt)
return app, 201
@ -368,8 +332,6 @@ class AppTraceApi(Resource):
api.add_resource(AppListApi, "/apps")
api.add_resource(AppImportApi, "/apps/import")
api.add_resource(AppImportFromUrlApi, "/apps/import/url")
api.add_resource(AppApi, "/apps/<uuid:app_id>")
api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")

View File

@ -0,0 +1,90 @@
from typing import cast
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_import_fields
from libs.login import login_required
from models import Account
from services.app_dsl_service import AppDslService, ImportStatus
class AppImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
def post(self):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_content", type=str, location="json")
parser.add_argument("yaml_url", type=str, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("app_id", type=str, location="json")
args = parser.parse_args()
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Import app
account = cast(Account, current_user)
result = import_service.import_app(
account=account,
import_mode=args["mode"],
yaml_content=args.get("yaml_content"),
yaml_url=args.get("yaml_url"),
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
app_id=args.get("app_id"),
)
session.commit()
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
class AppImportConfirmApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
def post(self, import_id):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Confirm import
account = cast(Account, current_user)
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
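
Both new endpoints translate the import result into an HTTP status code: a failed import returns 400, a pending one (on `/apps/imports`) returns 202, and anything else returns 200. The sketch below isolates that mapping. The enum is a stand-in: only `FAILED` and `PENDING` appear in the diff, so the string values and the `COMPLETED` member are assumptions.

```python
from enum import Enum


class ImportStatus(str, Enum):
    # Stand-in for services.app_dsl_service.ImportStatus.
    # Only FAILED and PENDING are visible in the diff; the values
    # and the COMPLETED member are assumed for illustration.
    COMPLETED = "completed"
    PENDING = "pending"
    FAILED = "failed"


def http_status_for(status: str) -> int:
    """Map an import status string to the HTTP code the endpoint returns."""
    if status == ImportStatus.FAILED.value:
        return 400  # the import could not be applied
    if status == ImportStatus.PENDING.value:
        return 202  # accepted, awaiting POST /apps/imports/<import_id>/confirm
    return 200  # any other status is treated as success
```

A client would read the 202 as a signal to call the confirm endpoint with the returned import identifier before treating the app as imported.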

View File

@ -20,7 +20,6 @@ from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models import App
from models.model import AppMode
from services.app_dsl_service import AppDslService
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.workflow_service import WorkflowService
@ -126,31 +125,6 @@ class DraftWorkflowApi(Resource):
}
class DraftWorkflowImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
def post(self, app_model: App):
"""
Import draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
workflow = AppDslService.import_and_overwrite_workflow(
app_model=app_model, data=args["data"], account=current_user
)
return workflow
class AdvancedChatDraftWorkflowRunApi(Resource):
@setup_required
@login_required
@ -453,7 +427,6 @@ class ConvertToWorkflowApi(Resource):
api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import")
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")

View File

@ -0,0 +1,145 @@
import json
import requests
from flask import Response
from flask_restful import Resource, reqparse
from sqlalchemy import text
from controllers.console import api
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.fta import ComponentFailure, ComponentFailureStats
class FATTestApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("log_process_data", nullable=False, required=True, type=str, location="args")
args = parser.parse_args()
print(args["log_process_data"])
# Extract the JSON string from the text field
json_str = args["log_process_data"].strip("```json\\n").strip("```").strip().replace("\\n", "")
log_data = json.loads(json_str)
db.session.query(ComponentFailure).delete()
for data in log_data:
if not isinstance(data, dict):
raise TypeError("Data must be a dictionary.")
required_keys = {"Date", "Component", "FailureMode", "Cause", "RepairAction", "Technician"}
if not required_keys.issubset(data.keys()):
raise ValueError(f"Data dictionary must contain the following keys: {required_keys}")
try:
# Build and stage one ComponentFailure row for this log entry
component_failure = ComponentFailure(
Date=data["Date"],
Component=data["Component"],
FailureMode=data["FailureMode"],
Cause=data["Cause"],
RepairAction=data["RepairAction"],
Technician=data["Technician"],
)
db.session.add(component_failure)
db.session.commit()
except Exception as e:
print(e)
# Clear existing stats
db.session.query(ComponentFailureStats).delete()
# Insert calculated statistics
try:
db.session.execute(
text("""
INSERT INTO component_failure_stats ("Component", "FailureMode", "Cause", "PossibleAction", "Probability", "MTBF")
SELECT
cf."Component",
cf."FailureMode",
cf."Cause",
cf."RepairAction" as "PossibleAction",
COUNT(*) * 1.0 / (SELECT COUNT(*) FROM component_failure WHERE "Component" = cf."Component") AS "Probability",
COALESCE(AVG(EXTRACT(EPOCH FROM (next_failure_date::timestamp - cf."Date"::timestamp)) / 86400.0),0)AS "MTBF"
FROM (
SELECT
"Component",
"FailureMode",
"Cause",
"RepairAction",
"Date",
LEAD("Date") OVER (PARTITION BY "Component", "FailureMode", "Cause" ORDER BY "Date") AS next_failure_date
FROM
component_failure
) cf
GROUP BY
cf."Component", cf."FailureMode", cf."Cause", cf."RepairAction";
""")
)
db.session.commit()
except Exception as e:
db.session.rollback()
print(f"Error during stats calculation: {e}")
# output format
# [
# (17, 'Hydraulic system', 'Leak', 'Hose rupture', 'Replaced hydraulic hose', 0.3333333333333333, None),
# (18, 'Hydraulic system', 'Leak', 'Seal Wear', 'Replaced the faulty seal', 0.3333333333333333, None),
# (19, 'Hydraulic system', 'Pressure drop', 'Fluid leak', 'Replaced hydraulic fluid and seals', 0.3333333333333333, None)
# ]
component_failure_stats = db.session.query(ComponentFailureStats).all()
# Convert stats to list of tuples format
stats_list = []
for stat in component_failure_stats:
stats_list.append(
(
stat.StatID,
stat.Component,
stat.FailureMode,
stat.Cause,
stat.PossibleAction,
stat.Probability,
stat.MTBF,
)
)
return {"data": stats_list}, 200
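
The SQL statement in `FATTestApi` computes two figures per (component, failure mode, cause, repair action) group: the group's share of all failures recorded for that component, and the mean number of days until the next failure with the same component, mode, and cause, defaulting to 0 when there is no later failure. The pure-Python sketch below is meant only to make that arithmetic easier to follow. The function name and the tuple layout are assumptions, and it takes `datetime.date` values rather than the string dates the table appears to store.

```python
from collections import Counter, defaultdict
from datetime import date


def failure_stats(records):
    """records: iterable of (day, component, failure_mode, cause, repair_action).

    Returns {(component, failure_mode, cause, repair_action): (probability, mtbf_days)}.
    Illustrative re-statement of the SQL; not code from the diff.
    """
    records = sorted(records, key=lambda r: r[0])
    per_component = Counter(r[1] for r in records)
    per_group = Counter(r[1:] for r in records)

    # LEAD("Date") OVER (PARTITION BY component, mode, cause ORDER BY "Date"):
    # each row is paired with the next failure in the same partition, and the
    # gap is credited to the earlier row's group.
    previous = {}
    gaps = defaultdict(list)
    for day, component, mode, cause, action in records:
        key = (component, mode, cause)
        if key in previous:
            prev_day, prev_action = previous[key]
            gaps[(component, mode, cause, prev_action)].append((day - prev_day).days)
        previous[key] = (day, action)

    stats = {}
    for group, count in per_group.items():
        group_gaps = gaps.get(group, [])
        # COALESCE(AVG(...), 0): groups with no following failure get 0.
        mtbf = sum(group_gaps) / len(group_gaps) if group_gaps else 0.0
        stats[group] = (count / per_component[group[0]], mtbf)
    return stats
```

For example, three "Leak" records out of four "Pump" failures give a probability of 0.75 for that group, and gaps of 10 and 20 days between them give an MTBF of 15 days.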
# generate-fault-tree
class GenerateFaultTreeApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("llm_text", nullable=False, required=True, type=str, location="args")
args = parser.parse_args()
entities = args["llm_text"].replace("```", "").replace("\\n", "\n")
print(entities)
request_data = {"fault_tree_text": entities}
url = "https://fta.cognitech-dev.live/generate-fault-tree"
headers = {"accept": "application/json", "Content-Type": "application/json"}
response = requests.post(url, json=request_data, headers=headers)
print(response.json())
return {"data": response.json()}, 200
class ExtractSVGApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("svg_text", nullable=False, required=True, type=str, location="args")
args = parser.parse_args()
# svg_text = ''.join(args["svg_text"].splitlines())
svg_text = args["svg_text"].replace("\n", "")
svg_text = svg_text.replace('"', '"')
print(svg_text)
svg_text_json = json.loads(svg_text)
svg_content = svg_text_json.get("data").get("svg_content")[0]
svg_content = svg_content.replace("\n", "").replace('"', '"')
file_key = "fta_svg/" + "fat.svg"
if storage.exists(file_key):
storage.delete(file_key)
storage.save(file_key, svg_content.encode("utf-8"))
generator = storage.load(file_key, stream=True)
return Response(generator, mimetype="image/svg+xml")
api.add_resource(FATTestApi, "/fta/db-handler")
api.add_resource(GenerateFaultTreeApi, "/fta/generate-fault-tree")
api.add_resource(ExtractSVGApi, "/fta/extract-svg")
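
Review note: the `INSERT ... SELECT` in the handler above computes, per (Component, FailureMode, Cause, RepairAction) group, a probability relative to all failures of that component and a mean gap in days to the next failure within the same (Component, FailureMode, Cause) partition. A minimal plain-Python sketch of the same arithmetic follows. It assumes rows are tuples carrying `datetime.date` values, and `failure_stats` is a hypothetical helper name that is not part of this change.

```python
from collections import defaultdict


def failure_stats(rows):
    """rows: iterable of (component, failure_mode, cause, repair_action, date)."""
    rows = [tuple(r) for r in rows]

    # Denominator of "Probability": all failures recorded for the component.
    per_component = defaultdict(int)
    for component, *_ in rows:
        per_component[component] += 1

    # Mirrors LEAD("Date") OVER (PARTITION BY component, mode, cause ORDER BY date).
    partitions = defaultdict(list)
    for row in rows:
        partitions[row[:3]].append(row)

    counts = defaultdict(int)
    gaps = defaultdict(list)
    for part in partitions.values():
        part.sort(key=lambda r: r[4])
        for i, row in enumerate(part):
            key = row[:4]  # GROUP BY also includes the repair action
            counts[key] += 1
            if i + 1 < len(part):  # the last row has no next failure (NULL in SQL)
                gaps[key].append((part[i + 1][4] - row[4]).days)

    stats = {}
    for key, n in counts.items():
        probability = n / per_component[key[0]]
        # COALESCE(AVG(...), 0): groups without any gap fall back to 0.
        mtbf = sum(gaps[key]) / len(gaps[key]) if gaps[key] else 0.0
        stats[key] = (probability, mtbf)
    return stats
```

The sketch uses whole days, whereas the SQL divides epoch seconds by 86400.0 and can therefore return fractional days for timestamp columns.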

View File

@ -45,7 +45,7 @@ class RemoteFileUploadApi(Resource):
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3)
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
resp.raise_for_status()
file_info = helpers.guess_file_info_from_response(resp)

View File

@ -1,3 +1,5 @@
from urllib import parse
from flask_login import current_user
from flask_restful import Resource, abort, marshal_with, reqparse
@ -57,11 +59,12 @@ class MemberInviteEmailApi(Resource):
token = RegisterService.invite_new_member(
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
)
encoded_invitee_email = parse.quote(invitee_email)
invitation_results.append(
{
"status": "success",
"email": invitee_email,
"url": f"{console_web_url}/activate?email={invitee_email}&token={token}",
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
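
Review note: the hunk above percent-encodes the invitee email before placing it in the activation URL. A small sketch of why that matters, using only `urllib.parse`; `build_activation_url` and the example host are hypothetical and stand in for the inline f-string.

```python
from urllib import parse


def build_activation_url(console_web_url: str, invitee_email: str, token: str) -> str:
    # quote() escapes characters such as "+" and "@" that are unsafe in a query value.
    encoded_email = parse.quote(invitee_email)
    return f"{console_web_url}/activate?email={encoded_email}&token={token}"


url = build_activation_url("https://console.example.com", "dev+test@example.com", "abc")
query = parse.urlsplit(url).query
print(parse.parse_qs(query)["email"])
```

Without the encoding step, a query parser that follows form-decoding rules would read the `+` in the address as a space, so the activation page would receive a different email than the one that was invited.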

View File

@ -114,16 +114,9 @@ class BaseAgentRunner(AppRunner):
# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
self.stream_tool_call = True
else:
self.stream_tool_call = False
# check if model supports vision
if model_schema and ModelFeature.VISION in (model_schema.features or []):
self.files = application_generate_entity.files
else:
self.files = []
features = model_schema.features if model_schema and model_schema.features else []
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
self.query = None
self._current_thoughts: list[PromptMessage] = []
@ -250,7 +243,7 @@ class BaseAgentRunner(AppRunner):
update prompt message tool
"""
# try to get tool runtime parameters
tool_runtime_parameters = tool.get_runtime_parameters() or []
tool_runtime_parameters = tool.get_runtime_parameters()
for parameter in tool_runtime_parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
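
Review note: the first hunk of this file collapses two if/else blocks into a single normalized `features` list followed by membership tests. A standalone sketch of that pattern, where the `ModelFeature` enum below is a reduced stand-in and `resolve_flags` is a hypothetical helper:

```python
from enum import Enum
from typing import Optional


class ModelFeature(Enum):  # reduced stand-in for the real enum
    VISION = "vision"
    STREAM_TOOL_CALL = "stream-tool-call"


def resolve_flags(features: Optional[list], files: list) -> tuple[bool, list]:
    # Normalize a missing schema or missing feature list to an empty list once.
    features = features or []
    stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
    # Files are only forwarded when the model advertises vision support.
    visible_files = files if ModelFeature.VISION in features else []
    return stream_tool_call, visible_files
```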

View File

@ -11,7 +11,7 @@ from core.provider_manager import ProviderManager
class ModelConfigConverter:
@classmethod
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
"""
Convert app model config dict to entity.
:param app_config: app config
@ -38,27 +38,23 @@ class ModelConfigConverter:
)
if model_credentials is None:
if not skip_check:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
else:
model_credentials = {}
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
if not skip_check:
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_config.model, model_type=ModelType.LLM
)
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_config.model, model_type=ModelType.LLM
)
if provider_model is None:
model_name = model_config.model
raise ValueError(f"Model {model_name} not exist.")
if provider_model is None:
model_name = model_config.model
raise ValueError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config
completion_params = model_config.parameters
@ -76,7 +72,7 @@ class ModelConfigConverter:
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
if not skip_check and not model_schema:
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")
return ModelConfigWithCredentialsEntity(

View File

@ -16,9 +16,7 @@ class FileUploadConfigManager:
file_upload_dict = config.get("file_upload")
if file_upload_dict:
if file_upload_dict.get("enabled"):
transform_methods = file_upload_dict.get("allowed_file_upload_methods") or file_upload_dict.get(
"allowed_upload_methods", []
)
transform_methods = file_upload_dict.get("allowed_file_upload_methods", [])
data = {
"image_config": {
"number_limits": file_upload_dict["number_limits"],

View File

@ -33,8 +33,8 @@ class BaseAppGenerator:
tenant_id=app_config.tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
),
)
for k, v in user_inputs.items()
@ -47,8 +47,8 @@ class BaseAppGenerator:
tenant_id=app_config.tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
),
)
for k, v in user_inputs.items()

View File

@ -217,9 +217,12 @@ class WorkflowCycleManage:
).total_seconds()
db.session.commit()
db.session.refresh(workflow_run)
db.session.close()
with Session(db.engine, expire_on_commit=False) as session:
session.add(workflow_run)
session.refresh(workflow_run)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
@ -381,7 +384,7 @@ class WorkflowCycleManage:
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
inputs=workflow_run.inputs_dict or {},
inputs=workflow_run.inputs_dict,
created_at=int(workflow_run.created_at.timestamp()),
),
)
@ -428,7 +431,7 @@ class WorkflowCycleManage:
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
),
)

View File

@ -1,9 +1,16 @@
import base64
import tempfile
from pathlib import Path
from configs import dify_config
from core.file import file_repository
from core.helper import ssrf_proxy
from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
from core.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
VideoPromptMessageContent,
)
from extensions.ext_database import db
from extensions.ext_storage import storage
@ -13,6 +20,38 @@ from .models import File, FileTransferMethod, FileType
from .tool_file_parser import ToolFileParser
def download_to_target_path(f: File, temp_dir: str, /):
if f.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
suffix = Path(tool_file.file_key).suffix
target_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
_download_file_to_target_path(tool_file.file_key, target_path)
return target_path
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
suffix = Path(upload_file.key).suffix
target_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
_download_file_to_target_path(upload_file.key, target_path)
return target_path
else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
def _download_file_to_target_path(path: str, target_path: str, /):
"""
Download a file from storage to a local path.
This function copies the file at `path` in storage to `target_path` and returns nothing.
Args:
path (str): The path to the file in storage.
target_path (str): The local path to write the file to.
"""
storage.download(path, target_path)
def get_attr(*, file: File, attr: FileAttribute):
match attr:
case FileAttribute.TYPE:
@ -29,35 +68,17 @@ def get_attr(*, file: File, attr: FileAttribute):
return file.remote_url
case FileAttribute.EXTENSION:
return file.extension
case _:
raise ValueError(f"Invalid file attribute: {attr}")
def to_prompt_message_content(
f: File,
/,
*,
image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
):
"""
Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
This function takes a File object and converts it to an appropriate PromptMessageContent
object, which can be used as a prompt for image or audio-based AI models.
Args:
f (File): The File object to convert.
detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
Returns:
Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
Raises:
ValueError: If the file type is not supported or if required data is missing.
"""
match f.type:
case FileType.IMAGE:
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
data = _to_url(f)
else:
@ -65,7 +86,7 @@ def to_prompt_message_content(
return ImagePromptMessageContent(data=data, detail=image_detail_config)
case FileType.AUDIO:
encoded_string = _file_to_encoded_string(f)
encoded_string = _get_encoded_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
@ -74,9 +95,20 @@ def to_prompt_message_content(
data = _to_url(f)
else:
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.DOCUMENT:
data = _get_encoded_string(f)
if f.mime_type is None:
raise ValueError("Missing file mime_type")
return DocumentPromptMessageContent(
encode_format="base64",
mime_type=f.mime_type,
data=data,
)
case _:
raise ValueError("file type f.type is not supported")
raise ValueError(f"file type {f.type} is not supported")
def download(f: File, /):
@ -118,21 +150,16 @@ def _get_encoded_string(f: File, /):
case FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
content = response.content
encoded_string = base64.b64encode(content).decode("utf-8")
return encoded_string
data = response.content
case FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
data = _download_file_content(upload_file.key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string
case FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
data = _download_file_content(tool_file.file_key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string
case _:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string
def _to_base64_data_string(f: File, /):
@ -140,18 +167,6 @@ def _to_base64_data_string(f: File, /):
return f"data:{f.mime_type};base64,{encoded_string}"
def _file_to_encoded_string(f: File, /):
match f.type:
case FileType.IMAGE:
return _to_base64_data_string(f)
case FileType.VIDEO:
return _to_base64_data_string(f)
case FileType.AUDIO:
return _get_encoded_string(f)
case _:
raise ValueError(f"file type {f.type} is not supported")
def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
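
Review note: after this change every transfer method yields raw bytes and the base64 step runs once at the end; `_to_base64_data_string` then wraps the result in a data URI. A minimal sketch of that final step, operating on plain bytes rather than a `File` object; `to_base64_data_string` here is a simplified, hypothetical variant of the private helper.

```python
import base64


def to_base64_data_string(data: bytes, mime_type: str) -> str:
    # Encode once, then prefix with the data-URI header.
    encoded = base64.b64encode(data).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"
```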

View File

@ -28,8 +28,8 @@ class FileUploadConfig(BaseModel):
image_config: Optional[ImageConfig] = None
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_extensions: Sequence[str] = Field(default_factory=list)
allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
number_limits: int = 0

View File

@ -39,6 +39,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
)
retries = 0
stream = kwargs.pop("stream", False)
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:
@ -52,6 +53,8 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
response = client.request(method=method, url=url, **kwargs)
if response.status_code not in STATUS_FORCELIST:
if stream:
return response.iter_bytes()
return response
else:
logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list")

View File

@ -29,6 +29,8 @@ from core.rag.splitter.fixed_text_splitter import (
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@ -278,6 +280,19 @@ class IndexingRunner:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
# delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
try:
storage.delete(image_file.key)
except Exception:
logging.exception(
"Delete image_files failed while indexing_estimate, \
image_upload_file_id: {}".format(upload_file_id)
)
db.session.delete(image_file)
if doc_form and doc_form == "qa_model":
if len(preview_texts) > 0:
# qa model document
@ -500,11 +515,7 @@ class IndexingRunner:
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith("。"):
page_content = page_content[1:]
else:
page_content = page_content
document_node.page_content = page_content
document_node.page_content = remove_leading_symbols(page_content)
if document_node.page_content:
split_documents.append(document_node)

View File

@ -1,8 +1,8 @@
from collections.abc import Sequence
from typing import Optional
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
from core.file.models import FileType
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
@ -27,7 +27,7 @@ class TokenBufferMemory:
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> list[PromptMessage]:
) -> Sequence[PromptMessage]:
"""
Get history prompt messages.
:param max_token_limit: max token limit
@ -102,12 +102,11 @@ class TokenBufferMemory:
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file in file_objs:
if file.type in {FileType.IMAGE, FileType.AUDIO}:
prompt_message = file_manager.to_prompt_message_content(
file,
image_detail_config=detail,
)
prompt_message_contents.append(prompt_message)
prompt_message = file_manager.to_prompt_message_content(
file,
image_detail_config=detail,
)
prompt_message_contents.append(prompt_message)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))

View File

@ -100,10 +100,10 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
prompt_messages: Sequence[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Optional
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@ -31,7 +32,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@ -60,7 +61,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
):
@ -90,7 +91,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@ -120,7 +121,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:

View File

@ -2,6 +2,7 @@ from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsa
from .message_entities import (
AssistantPromptMessage,
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
@ -37,4 +38,5 @@ __all__ = [
"LLMResultChunk",
"LLMResultChunkDelta",
"AudioPromptMessageContent",
"DocumentPromptMessageContent",
]

View File

@ -1,6 +1,7 @@
from abc import ABC
from collections.abc import Sequence
from enum import Enum
from typing import Optional
from typing import Literal, Optional
from pydantic import BaseModel, Field, field_validator
@ -48,7 +49,7 @@ class PromptMessageFunction(BaseModel):
function: PromptMessageTool
class PromptMessageContentType(Enum):
class PromptMessageContentType(str, Enum):
"""
Enum class for prompt message content type.
"""
@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):
IMAGE = "image"
AUDIO = "audio"
VIDEO = "video"
DOCUMENT = "document"
class PromptMessageContent(BaseModel):
@ -101,13 +103,20 @@ class ImagePromptMessageContent(PromptMessageContent):
detail: DETAIL = DETAIL.LOW
class DocumentPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
encode_format: Literal["base64"]
mime_type: str
data: str
class PromptMessage(ABC, BaseModel):
"""
Model class for prompt message.
"""
role: PromptMessageRole
content: Optional[str | list[PromptMessageContent]] = None
content: Optional[str | Sequence[PromptMessageContent]] = None
name: Optional[str] = None
def is_empty(self) -> bool:
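
Review note: `PromptMessageContentType` now mixes in `str`. A short sketch of what that changes in practice, using two hypothetical stand-in enums rather than the real class:

```python
import json
from enum import Enum


class PlainType(Enum):
    DOCUMENT = "document"


class StrType(str, Enum):
    DOCUMENT = "document"


# Members of a str-mixin enum compare equal to their raw value.
print(StrType.DOCUMENT == "document")
print(PlainType.DOCUMENT == "document")

# They also serialize with the standard json module without a custom encoder.
print(json.dumps({"type": StrType.DOCUMENT}))
```

A plain `Enum` member raises `TypeError` when passed to `json.dumps`, which is one likely motivation for the mixin.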

View File

@ -87,6 +87,9 @@ class ModelFeature(Enum):
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
class DefaultParameterName(str, Enum):

View File

@ -2,7 +2,7 @@ import logging
import re
import time
from abc import abstractmethod
from collections.abc import Generator, Mapping
from collections.abc import Generator, Mapping, Sequence
from typing import Optional, Union
from pydantic import ConfigDict
@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -169,7 +169,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -212,7 +212,7 @@ if you are not sure about the structure.
)
model_parameters.pop("response_format")
stop = stop or []
stop = list(stop) if stop is not None else []
stop.extend(["\n```", "```\n"])
block_prompts = block_prompts.replace("{{block}}", code_block)
@ -408,7 +408,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -479,7 +479,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -601,7 +601,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -647,7 +647,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -694,7 +694,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -742,7 +742,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
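
Review note: widening `stop` to `Sequence[str]` means callers may pass a tuple, which has no `extend`. The hunk at line 212 therefore copies it into a list first. A minimal sketch of that normalization; `with_code_block_stops` is a hypothetical name, and `FENCE` only exists to keep literal backtick runs out of this example.

```python
from collections.abc import Sequence
from typing import Optional

FENCE = "`" * 3  # three backticks


def with_code_block_stops(stop: Optional[Sequence[str]] = None) -> list[str]:
    # list(stop) accepts any sequence and never mutates the caller's object.
    stop = list(stop) if stop is not None else []
    stop.extend(["\n" + FENCE, FENCE + "\n"])
    return stop
```

Compared with the earlier `stop or []`, the copy also avoids appending the extra stop words to a list the caller still holds.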

View File

@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 200000

View File

@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 200000

View File

@ -1,7 +1,7 @@
import base64
import io
import json
from collections.abc import Generator
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast
import anthropic
@ -21,9 +21,9 @@ from httpx import Timeout
from PIL import Image
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities import (
AssistantPromptMessage,
DocumentPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
@ -33,6 +33,7 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@ -86,10 +87,10 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -130,9 +131,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
# Add the new header for claude-3-5-sonnet-20240620 model
extra_headers = {}
if model == "claude-3-5-sonnet-20240620":
if model_parameters.get("max_tokens") > 4096:
if model_parameters.get("max_tokens", 0) > 4096:
extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
if any(
isinstance(content, DocumentPromptMessageContent)
for prompt_message in prompt_messages
if isinstance(prompt_message.content, list)
for content in prompt_message.content
):
extra_headers["anthropic-beta"] = "pdfs-2024-09-25"
if tools:
extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools]
response = client.beta.tools.messages.create(
@ -325,14 +334,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
assistant_prompt_message.tool_calls.append(tool_call)
# calculate num tokens
if response.usage:
# transform usage
prompt_tokens = response.usage.input_tokens
completion_tokens = response.usage.output_tokens
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
prompt_tokens = (response.usage and response.usage.input_tokens) or self.get_num_tokens(
model, credentials, prompt_messages
)
completion_tokens = (response.usage and response.usage.output_tokens) or self.get_num_tokens(
model, credentials, [assistant_prompt_message]
)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@ -505,6 +513,21 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
"source": {"type": "base64", "media_type": mime_type, "data": base64_data},
}
sub_messages.append(sub_message_dict)
elif isinstance(message_content, DocumentPromptMessageContent):
if message_content.mime_type != "application/pdf":
raise ValueError(
f"Unsupported document type {message_content.mime_type}, "
"only support application/pdf"
)
sub_message_dict = {
"type": "document",
"source": {
"type": message_content.encode_format,
"media_type": message_content.mime_type,
"data": message_content.data,
},
}
sub_messages.append(sub_message_dict)
prompt_message_dicts.append({"role": "user", "content": sub_messages})
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
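
Review note: the new branch above turns a `DocumentPromptMessageContent` into an Anthropic `document` content block and rejects anything other than PDF. The same mapping as a standalone sketch over plain values; `to_anthropic_document_block` is a hypothetical helper and not part of this change.

```python
def to_anthropic_document_block(mime_type: str, data: str, encode_format: str = "base64") -> dict:
    # Only PDF documents are accepted, matching the check in the hunk above.
    if mime_type != "application/pdf":
        raise ValueError(f"Unsupported document type {mime_type}, only support application/pdf")
    return {
        "type": "document",
        "source": {"type": encode_format, "media_type": mime_type, "data": data},
    }
```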

View File

@ -2,13 +2,11 @@
import base64
import json
import logging
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast
# 3rd import
import boto3
import requests
from botocore.config import Config
from botocore.exceptions import (
ClientError,
@ -439,22 +437,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
url = message_content.data
image_content = requests.get(url).content
if "?" in url:
url = url.split("?")[0]
mime_type, _ = mimetypes.guess_type(url)
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
image_content = base64.b64decode(base64_data)
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
image_content = base64.b64decode(base64_data)
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(
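
Review note: with the URL-fetching branch removed, the Bedrock image path assumes `message_content.data` is always a base64 data URI. A minimal sketch of the remaining parsing step; `split_data_uri` is a hypothetical helper name.

```python
import base64


def split_data_uri(data: str) -> tuple[str, bytes]:
    # "data:image/png;base64,AAAA" -> ("image/png", b"...")
    header, base64_data = data.split(";base64,")
    mime_type = header.replace("data:", "")
    return mime_type, base64.b64decode(base64_data)
```

Under this assumption a plain URL no longer works: the sketch raises `ValueError` on unpacking, and the code in the hunk would fail on `data_split[1]` instead.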

View File

@ -691,8 +691,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
base_model_schema = cast(AIModelEntity, base_model_schema)
base_model_schema_features = base_model_schema.features or []
base_model_schema_model_properties = base_model_schema.model_properties or {}
base_model_schema_parameters_rules = base_model_schema.parameter_rules or []
base_model_schema_model_properties = base_model_schema.model_properties
base_model_schema_parameters_rules = base_model_schema.parameter_rules
entity = AIModelEntity(
model=model,

View File

@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 2097152
@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 2097152
@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 2097152
@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 2097152
@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 2097152
@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 2097152
@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -7,9 +7,10 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 2097152
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
@ -24,14 +25,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'

View File

@ -0,0 +1,38 @@
model: gemini-exp-1121
label:
en_US: Gemini exp 1121
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@ -32,3 +32,4 @@ pricing:
output: '0.00'
unit: '0.000001'
currency: USD
deprecated: true

View File

@ -36,3 +36,4 @@ pricing:
output: '0.00'
unit: '0.000001'
currency: USD
deprecated: true

View File

@ -0,0 +1,38 @@
model: learnlm-1.5-pro-experimental
label:
en_US: LearnLM 1.5 Pro Experimental
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@ -1,7 +1,6 @@
import base64
import io
import json
import logging
from collections.abc import Generator
from typing import Optional, Union, cast
@ -17,6 +16,7 @@ from PIL import Image
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
DocumentPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
@ -36,16 +36,20 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
""" # noqa: E501
GOOGLE_AVAILABLE_MIMETYPE = [
"application/pdf",
"application/x-javascript",
"text/javascript",
"application/x-python",
"text/x-python",
"text/plain",
"text/html",
"text/css",
"text/md",
"text/csv",
"text/xml",
"text/rtf",
]
class GoogleLargeLanguageModel(LargeLanguageModel):
@ -155,7 +159,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
try:
ping_message = SystemPromptMessage(content="ping")
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -184,7 +188,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
config_kwargs = model_parameters.copy()
config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
if schema := config_kwargs.pop("json_schema", None):
try:
schema = json.loads(schema)
except:
raise exceptions.InvalidArgument("Invalid JSON Schema")
if tools:
raise exceptions.InvalidArgument("gemini not support use Tools and JSON Schema at same time")
config_kwargs["response_schema"] = schema
config_kwargs["response_mime_type"] = "application/json"
if stop:
config_kwargs["stop_sequences"] = stop
@ -374,6 +386,12 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
glm_content["parts"].append(blob)
elif c.type == PromptMessageContentType.DOCUMENT:
message_content = cast(DocumentPromptMessageContent, c)
if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
raise ValueError(f"Unsupported mime type {message_content.mime_type}")
blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
glm_content["parts"].append(blob)
return glm_content
elif isinstance(message, AssistantPromptMessage):
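
Review note: the `json_schema` handling added to `_generate` above can be read in isolation as the sketch below. It is a minimal illustration, not the provider code: `build_generation_config` is a hypothetical helper name, `ValueError` stands in for `exceptions.InvalidArgument`, and the schema is assumed to arrive as a JSON string. The sketch catches `json.JSONDecodeError` rather than using a bare `except`, which is a choice made here and not what the hunk does.

```python
import json
from typing import Any, Optional


def build_generation_config(model_parameters: dict[str, Any], tools: Optional[list] = None) -> dict[str, Any]:
    """Hypothetical helper mirroring the json_schema branch of the hunk above."""
    config = dict(model_parameters)  # copy so the caller's dict is left untouched
    raw_schema = config.pop("json_schema", None)
    if raw_schema:
        try:
            schema = json.loads(raw_schema)
        except json.JSONDecodeError as exc:
            # Stand-in for exceptions.InvalidArgument("Invalid JSON Schema")
            raise ValueError("Invalid JSON Schema") from exc
        if tools:
            # The hunk rejects tools and a JSON schema in the same request
            raise ValueError("tools and a JSON schema cannot be used together")
        config["response_schema"] = schema
        config["response_mime_type"] = "application/json"
    return config


config = build_generation_config({"max_output_tokens": 5, "json_schema": '{"type": "object"}'})
```

After the call, `config` no longer carries `json_schema`; it carries the parsed schema under `response_schema` together with the JSON response MIME type.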

View File

@ -22,6 +22,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
@ -86,6 +87,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
@ -153,6 +155,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
@ -196,6 +199,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
if completion_type is LLMMode.CHAT:
endpoint_url = urljoin(endpoint_url, "api/chat")
data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
if tools:
data["tools"] = [self._convert_prompt_message_tool_to_dict(tool) for tool in tools]
else:
endpoint_url = urljoin(endpoint_url, "api/generate")
first_prompt_message = prompt_messages[0]
@ -232,7 +237,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
if stream:
return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages, tools)
def _handle_generate_response(
self,
@ -241,6 +246,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
completion_type: LLMMode,
response: requests.Response,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]],
) -> LLMResult:
"""
Handle llm completion response
@ -253,14 +259,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
:return: llm result
"""
response_json = response.json()
tool_calls = []
if completion_type is LLMMode.CHAT:
message = response_json.get("message", {})
response_content = message.get("content", "")
response_tool_calls = message.get("tool_calls", [])
tool_calls = [self._extract_response_tool_call(tool_call) for tool_call in response_tool_calls]
else:
response_content = response_json["response"]
assistant_message = AssistantPromptMessage(content=response_content)
assistant_message = AssistantPromptMessage(content=response_content, tool_calls=tool_calls)
if "prompt_eval_count" in response_json and "eval_count" in response_json:
# transform usage
@ -405,9 +413,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
chunk_index += 1
def _convert_prompt_message_tool_to_dict(self, tool: PromptMessageTool) -> dict:
"""
Convert PromptMessageTool to dict for Ollama API
:param tool: tool
:return: tool dict
"""
return {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
},
}
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict for Ollama API
:param message: prompt message
:return: message dict
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
@ -432,6 +459,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"role": "tool", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
@ -452,6 +482,29 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
return num_tokens
def _extract_response_tool_call(self, response_tool_call: dict) -> AssistantPromptMessage.ToolCall:
"""
Extract response tool call
"""
tool_call = None
if response_tool_call and "function" in response_tool_call:
# Convert arguments to JSON string if it's a dict
arguments = response_tool_call.get("function").get("arguments")
if isinstance(arguments, dict):
arguments = json.dumps(arguments)
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.get("function").get("name"),
arguments=arguments,
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.get("function").get("name"),
type="function",
function=function,
)
return tool_call
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
Get customizable model schema.
@ -461,10 +514,15 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
:return: model schema
"""
extras = {}
extras = {
"features": [],
}
if "vision_support" in credentials and credentials["vision_support"] == "true":
extras["features"] = [ModelFeature.VISION]
extras["features"].append(ModelFeature.VISION)
if "function_call_support" in credentials and credentials["function_call_support"] == "true":
extras["features"].append(ModelFeature.TOOL_CALL)
extras["features"].append(ModelFeature.MULTI_TOOL_CALL)
entity = AIModelEntity(
model=model,
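
Review note: the tool-call extraction added in `_extract_response_tool_call` can be sketched without the Dify entity classes as follows. The `ToolCall` dataclass and `extract_tool_call` are hypothetical stand-ins for `AssistantPromptMessage.ToolCall` and the method in the hunk. As in the hunk, the function name doubles as the call id, and dict arguments are serialized to a JSON string.

```python
import json
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToolCall:
    """Hypothetical stand-in for AssistantPromptMessage.ToolCall."""

    id: str
    name: str
    arguments: str


def extract_tool_call(raw: Optional[dict[str, Any]]) -> Optional[ToolCall]:
    """Return a ToolCall for a response entry, or None if it has no function."""
    function = raw.get("function") if raw else None
    if not function:
        return None
    arguments = function.get("arguments")
    if isinstance(arguments, dict):
        # The hunk converts dict arguments to a JSON string in the same way
        arguments = json.dumps(arguments)
    name = function.get("name")
    return ToolCall(id=name, name=name, arguments=arguments)


call = extract_tool_call({"function": {"name": "get_weather", "arguments": {"city": "Paris"}}})
```

Note that the method in the hunk also returns `None` for an entry without a `function` key, and the list comprehension in `_handle_generate_response` keeps that `None` in `tool_calls`. Filtering such entries out before building the assistant message may be worth considering.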

View File

@ -96,3 +96,22 @@ model_credential_schema:
label:
en_US: 'No'
zh_Hans:
- variable: function_call_support
label:
zh_Hans: 是否支持函数调用
en_US: Function call support
show_on:
- variable: __model_type
value: llm
default: 'false'
type: radio
required: false
options:
- value: 'true'
label:
en_US: 'Yes'
zh_Hans:
- value: 'false'
label:
en_US: 'No'
zh_Hans:

View File

@ -3,6 +3,7 @@
- gpt-4o
- gpt-4o-2024-05-13
- gpt-4o-2024-08-06
- gpt-4o-2024-11-20
- chatgpt-4o-latest
- gpt-4o-mini
- gpt-4o-mini-2024-07-18

View File

@ -0,0 +1,47 @@
model: gpt-4o-2024-11-20
label:
zh_Hans: gpt-4o-2024-11-20
en_US: gpt-4o-2024-11-20
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
- json_schema
- name: json_schema
use_template: json_schema
pricing:
input: '2.50'
output: '10.00'
unit: '0.000001'
currency: USD

View File

@ -7,7 +7,7 @@ features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
- audio
model_properties:
mode: chat
context_size: 128000

View File

@ -615,19 +615,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
# o1 compatibility
block_as_stream = False
if model.startswith("o1"):
if "max_tokens" in model_parameters:
model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
del model_parameters["max_tokens"]
if stream:
block_as_stream = True
stream = False
if "stream_options" in extra_model_kwargs:
del extra_model_kwargs["stream_options"]
if "stop" in extra_model_kwargs:
del extra_model_kwargs["stop"]
@ -644,47 +636,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
if block_as_stream:
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
return block_result
def _handle_chat_block_as_stream_response(
self,
block_result: LLMResult,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
) -> Generator[LLMResultChunk, None, None]:
"""
Handle llm chat response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:param stop: stop words
:return: llm response chunk generator
"""
text = block_result.message.content
text = cast(str, text)
if stop:
text = self.enforce_stop_tokens(text, stop)
yield LLMResultChunk(
model=block_result.model,
prompt_messages=prompt_messages,
system_fingerprint=block_result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=text),
finish_reason="stop",
usage=block_result.usage,
),
)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
def _handle_chat_generate_response(
self,
@ -1178,8 +1130,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
base_model_schema = model_map[base_model]
base_model_schema_features = base_model_schema.features or []
base_model_schema_model_properties = base_model_schema.model_properties or {}
base_model_schema_parameters_rules = base_model_schema.parameter_rules or []
base_model_schema_model_properties = base_model_schema.model_properties
base_model_schema_parameters_rules = base_model_schema.parameter_rules
entity = AIModelEntity(
model=model,

View File

@ -64,7 +64,7 @@ class OAICompatRerankModel(RerankModel):
# TODO: Do we need truncate docs to avoid llama.cpp return error?
data = {"model": model_name, "query": query, "documents": docs, "top_n": top_n}
data = {"model": model_name, "query": query, "documents": docs, "top_n": top_n, "return_documents": True}
try:
response = post(str(URL(url) / "rerank"), headers=headers, data=dumps(data), timeout=60)
@ -83,7 +83,13 @@ class OAICompatRerankModel(RerankModel):
index = result["index"]
# Retrieve document text (fallback if llama.cpp rerank doesn't return it)
text = result.get("document", {}).get("text", docs[index])
text = docs[index]
document = result.get("document", {})
if document:
if isinstance(document, dict):
text = document.get("text", docs[index])
elif isinstance(document, str):
text = document
# Normalize the score
normalized_score = (result["relevance_score"] - min_score) / score_range
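
Review note: the new fallback logic accepts three shapes for `document` in a rerank result: a dict with a `text` key, a plain string, or nothing at all. A minimal sketch, assuming a hypothetical helper name `resolve_document_text` and the same `result`/`docs` shapes as in the hunk:

```python
from typing import Any


def resolve_document_text(result: dict[str, Any], docs: list[str]) -> str:
    """Pick the document text from a rerank result, falling back to the input docs."""
    text = docs[result["index"]]  # default: the text that was originally submitted
    document = result.get("document")
    if isinstance(document, dict):
        text = document.get("text", text)
    elif isinstance(document, str) and document:
        text = document
    return text


docs = ["first doc", "second doc"]
from_dict = resolve_document_text({"index": 0, "document": {"text": "returned text"}}, docs)
from_str = resolve_document_text({"index": 1, "document": "plain string"}, docs)
fallback = resolve_document_text({"index": 1}, docs)
```

The previous one-liner, `result.get("document", {}).get("text", docs[index])`, raises `AttributeError` when the server returns `document` as a string, which appears to be the case this hunk addresses.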

View File

@ -37,13 +37,14 @@ class OpenLLMGenerateMessage:
class OpenLLMGenerate:
def generate(
self,
*,
server_url: str,
model_name: str,
stream: bool,
model_parameters: dict[str, Any],
stop: list[str],
stop: list[str] | None = None,
prompt_messages: list[OpenLLMGenerateMessage],
user: str,
user: str | None = None,
) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]:
if not server_url:
raise InvalidAuthenticationError("Invalid server URL")

View File

@ -45,19 +45,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._update_credential(model, credentials)
block_as_stream = False
if model.startswith("openai/o1"):
block_as_stream = True
stop = None
# invoke block as stream
if stream and block_as_stream:
return self._generate_block_as_stream(
model, credentials, prompt_messages, model_parameters, tools, stop, user
)
else:
return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _generate_block_as_stream(
self,
@ -69,9 +57,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
stop: Optional[list[str]] = None,
user: Optional[str] = None,
) -> Generator:
resp: LLMResult = super()._generate(
model, credentials, prompt_messages, model_parameters, tools, stop, False, user
)
resp = super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, False, user)
yield LLMResultChunk(
model=model,

View File

@ -6,6 +6,7 @@ model_type: llm
features:
- vision
- agent-thought
- video
model_properties:
mode: chat
context_size: 32000

View File

@ -6,6 +6,7 @@ model_type: llm
features:
- vision
- agent-thought
- video
model_properties:
mode: chat
context_size: 32000

View File

@ -6,6 +6,7 @@ model_type: llm
features:
- vision
- agent-thought
- video
model_properties:
mode: chat
context_size: 32768

View File

@ -6,6 +6,7 @@ model_type: llm
features:
- vision
- agent-thought
- video
model_properties:
mode: chat
context_size: 8000

View File

@ -65,6 +65,8 @@ class GTERerankModel(RerankModel):
)
rerank_documents = []
if not response.output:
return RerankResult(model=model, docs=rerank_documents)
for _, result in enumerate(response.output.results):
# format document
rerank_document = RerankDocument(

View File

@ -22,7 +22,7 @@ def get_model_config(credentials: dict) -> ModelConfig:
return ModelConfig(
properties=ModelProperties(
context_size=int(credentials.get("context_size", 0)),
max_chunks=int(credentials.get("max_chunks", 0)),
max_chunks=int(credentials.get("max_chunks", 1)),
)
)
return model_configs

View File

@ -6,6 +6,7 @@ model_properties:
mode: chat
features:
- vision
- video
parameter_rules:
- name: temperature
use_template: temperature

View File

@ -1,3 +1,6 @@
from collections.abc import Sequence
from typing import Any
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
@ -62,5 +65,5 @@ class KeywordsModeration(Moderation):
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
def _check_keywords_in_value(self, keywords_list, value) -> bool:
return any(keyword.lower() in value.lower() for keyword in keywords_list)
def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:
return any(keyword.lower() in str(value).lower() for keyword in keywords_list)
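
Review note: wrapping `value` in `str()` keeps the check from failing on non-string inputs such as numbers or `None`. The sketch below reproduces the new method as a free function (the standalone name `check_keywords_in_value` is an assumption made for illustration) and exercises a few input types.

```python
from collections.abc import Sequence
from typing import Any


def check_keywords_in_value(keywords_list: Sequence[str], value: Any) -> bool:
    """Case-insensitive substring match of any keyword against str(value)."""
    return any(keyword.lower() in str(value).lower() for keyword in keywords_list)


text_hit = check_keywords_in_value(["Spam"], "no SPAM here")
number_hit = check_keywords_in_value(["42"], 1420)
none_miss = check_keywords_in_value(["spam"], None)
```

One side effect to keep in mind: `str(None)` is `"None"`, so a keyword such as `none` would match an empty input after this change.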

View File

@ -49,6 +49,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run")
input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
@field_validator("inputs", "outputs")
@classmethod

View File

@ -25,7 +25,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunType,
LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values
from core.ops.utils import filter_none_values, generate_dotted_order
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
@ -62,6 +62,16 @@ class LangSmithDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.message_id or trace_info.workflow_app_log_id or trace_info.workflow_run_id
message_dotted_order = (
generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
)
workflow_dotted_order = generate_dotted_order(
trace_info.workflow_app_log_id or trace_info.workflow_run_id,
trace_info.workflow_data.created_at,
message_dotted_order,
)
if trace_info.message_id:
message_run = LangSmithRunModel(
id=trace_info.message_id,
@ -76,6 +86,8 @@ class LangSmithDataTrace(BaseTraceInstance):
},
tags=["message", "workflow"],
error=trace_info.error,
trace_id=trace_id,
dotted_order=message_dotted_order,
)
self.add_run(message_run)
@ -95,6 +107,8 @@ class LangSmithDataTrace(BaseTraceInstance):
error=trace_info.error,
tags=["workflow"],
parent_run_id=trace_info.message_id or None,
trace_id=trace_id,
dotted_order=workflow_dotted_order,
)
self.add_run(langsmith_run)
@ -177,6 +191,7 @@ class LangSmithDataTrace(BaseTraceInstance):
else:
run_type = LangSmithRunType.tool
node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order)
langsmith_run = LangSmithRunModel(
total_tokens=node_total_tokens,
name=node_type,
@ -191,6 +206,9 @@ class LangSmithDataTrace(BaseTraceInstance):
},
parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
tags=["node_execution"],
id=node_execution_id,
trace_id=trace_id,
dotted_order=node_dotted_order,
)
self.add_run(langsmith_run)

View File

@ -1,5 +1,6 @@
from contextlib import contextmanager
from datetime import datetime
from typing import Optional, Union
from extensions.ext_database import db
from models.model import Message
@ -43,3 +44,19 @@ def replace_text_with_content(data):
return [replace_text_with_content(item) for item in data]
else:
return data
def generate_dotted_order(
run_id: str, start_time: Union[str, datetime], parent_dotted_order: Optional[str] = None
) -> str:
"""
generate dotted_order for langsmith
"""
start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"
current_segment = f"{timestamp}{run_id}"
if parent_dotted_order is None:
return current_segment
return f"{parent_dotted_order}.{current_segment}"
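
Review note: to make the nesting concrete, the sketch below repeats `generate_dotted_order` from the hunk above and chains a message run and a workflow run. The run ids `msg-1` and `wf-1` and the timestamps are invented for illustration; real ids would be whatever the trace info carries.

```python
from datetime import datetime
from typing import Optional, Union


def generate_dotted_order(
    run_id: str, start_time: Union[str, datetime], parent_dotted_order: Optional[str] = None
) -> str:
    """Build '<timestamp><run_id>', prefixed by the parent's dotted order if given."""
    start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
    # %f yields microseconds (6 digits); dropping the last 3 leaves milliseconds
    timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"
    current_segment = f"{timestamp}{run_id}"
    if parent_dotted_order is None:
        return current_segment
    return f"{parent_dotted_order}.{current_segment}"


# Hypothetical ids and times, for illustration only
message_order = generate_dotted_order("msg-1", "2024-11-22T10:30:21.123456")
workflow_order = generate_dotted_order("wf-1", datetime(2024, 11, 22, 10, 30, 22), message_order)
```

Each child appends its own segment to the parent's string, so every node's `dotted_order` starts with the dotted order of the run that contains it.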

View File

@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import cast
from core.model_runtime.entities import (
@ -14,7 +15,7 @@ from core.prompt.simple_prompt_transform import ModelMode
class PromptMessageUtil:
@staticmethod
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]:
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]:
"""
Prompt messages to prompt for saving.
:param model_mode: model mode

View File

@ -12,7 +12,7 @@ class CleanProcessor:
# Unicode U+FFFE
text = re.sub("\ufffe", "", text)
rules = process_rule["rules"] if process_rule else None
rules = process_rule["rules"] if process_rule else {}
if "pre_processing_rules" in rules:
pre_processing_rules = rules["pre_processing_rules"]
for pre_processing_rule in pre_processing_rules:
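
Review note: the change from `None` to `{}` matters because the membership test on the next line raises `TypeError` when `rules` is `None`. A minimal sketch of the fixed lookup, using a hypothetical helper name `get_pre_processing_rules`:

```python
from typing import Any, Optional


def get_pre_processing_rules(process_rule: Optional[dict[str, Any]]) -> list:
    """Return the configured pre-processing rules, or an empty list if none exist."""
    rules = process_rule["rules"] if process_rule else {}
    if "pre_processing_rules" in rules:
        return rules["pre_processing_rules"]
    return []


without_rule = get_pre_processing_rules(None)
with_rule = get_pre_processing_rules({"rules": {"pre_processing_rules": [{"id": "remove_extra_spaces"}]}})

# The old default fails on the membership test itself
try:
    "pre_processing_rules" in None  # type: ignore[operator]
    old_default_failed = False
except TypeError:
    old_default_failed = True
```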

View File

@ -1,310 +1,62 @@
import json
from typing import Any
from pydantic import BaseModel
_import_err_msg = (
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from configs import dify_config
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig,
)
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
class AnalyticdbConfig(BaseModel):
access_key_id: str
access_key_secret: str
region_id: str
instance_id: str
account: str
account_password: str
namespace: str = ("dify",)
namespace_password: str = (None,)
metrics: str = ("cosine",)
read_timeout: int = 60000
def to_analyticdb_client_params(self):
return {
"access_key_id": self.access_key_id,
"access_key_secret": self.access_key_secret,
"region_id": self.region_id,
"read_timeout": self.read_timeout,
}
class AnalyticdbVector(BaseVector):
def __init__(self, collection_name: str, config: AnalyticdbConfig):
self._collection_name = collection_name.lower()
try:
from alibabacloud_gpdb20160503.client import Client
from alibabacloud_tea_openapi import models as open_api_models
except:
raise ImportError(_import_err_msg)
self.config = config
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
self._client = Client(self._client_config)
self._initialize()
def _initialize(self) -> None:
cache_key = f"vector_indexing_{self.config.instance_id}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self.config.instance_id}"
if redis_client.get(collection_exist_cache_key):
return
self._initialize_vector_database()
self._create_namespace_if_not_exists()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.init_vector_database(request)
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.describe_namespace(request)
except TeaException as e:
if e.statusCode == 404:
request = gpdb_20160503_models.CreateNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
)
self._client.create_namespace(request)
else:
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
try:
request = gpdb_20160503_models.DescribeCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
)
self._client.describe_collection(request)
except TeaException as e:
if e.statusCode == 404:
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
full_text_retrieval_fields = "page_content"
request = gpdb_20160503_models.CreateCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
collection=self._collection_name,
dimension=embedding_dimension,
metrics=self.config.metrics,
metadata=metadata,
full_text_retrieval_fields=full_text_retrieval_fields,
)
self._client.create_collection(request)
else:
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def __init__(
self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
):
super().__init__(collection_name)
if api_config is not None:
self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
else:
self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
def get_type(self) -> str:
return VectorType.ANALYTICDB
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection_if_not_exists(dimension)
self.add_texts(texts, embeddings)
self.analyticdb_vector._create_collection_if_not_exists(dimension)
self.analyticdb_vector.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
)
)
request = gpdb_20160503_models.UpsertCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
rows=rows,
)
self._client.upsert_collection_data(request)
def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.analyticdb_vector.add_texts(texts, embeddings)
def text_exists(self, id: str) -> bool:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'",
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
return self.analyticdb_vector.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
ids_str = f"({ids_str})"
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
self.analyticdb_vector.delete_by_ids(ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
self.analyticdb_vector.delete_by_metadata_field(key, value)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
return self.analyticdb_vector.search_by_vector(query_vector)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None,
content=query,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.metadata.get("vector"),
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
return self.analyticdb_vector.search_by_full_text(query, **kwargs)
def delete(self) -> None:
try:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
collection=self._collection_name,
dbinstance_id=self.config.instance_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
region_id=self.config.region_id,
)
self._client.delete_collection(request)
except Exception as e:
raise e
self.analyticdb_vector.delete()
class AnalyticdbVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
@ -313,26 +65,9 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))
# handle optional params
if dify_config.ANALYTICDB_KEY_ID is None:
raise ValueError("ANALYTICDB_KEY_ID should not be None")
if dify_config.ANALYTICDB_KEY_SECRET is None:
raise ValueError("ANALYTICDB_KEY_SECRET should not be None")
if dify_config.ANALYTICDB_REGION_ID is None:
raise ValueError("ANALYTICDB_REGION_ID should not be None")
if dify_config.ANALYTICDB_INSTANCE_ID is None:
raise ValueError("ANALYTICDB_INSTANCE_ID should not be None")
if dify_config.ANALYTICDB_ACCOUNT is None:
raise ValueError("ANALYTICDB_ACCOUNT should not be None")
if dify_config.ANALYTICDB_PASSWORD is None:
raise ValueError("ANALYTICDB_PASSWORD should not be None")
if dify_config.ANALYTICDB_NAMESPACE is None:
raise ValueError("ANALYTICDB_NAMESPACE should not be None")
if dify_config.ANALYTICDB_NAMESPACE_PASSWORD is None:
raise ValueError("ANALYTICDB_NAMESPACE_PASSWORD should not be None")
return AnalyticdbVector(
collection_name,
AnalyticdbConfig(
if dify_config.ANALYTICDB_HOST is None:
# implemented through OpenAPI
apiConfig = AnalyticdbVectorOpenAPIConfig(
access_key_id=dify_config.ANALYTICDB_KEY_ID,
access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
region_id=dify_config.ANALYTICDB_REGION_ID,
@ -341,5 +76,22 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
account_password=dify_config.ANALYTICDB_PASSWORD,
namespace=dify_config.ANALYTICDB_NAMESPACE,
namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
),
)
sqlConfig = None
else:
# implemented through sql
sqlConfig = AnalyticdbVectorBySqlConfig(
host=dify_config.ANALYTICDB_HOST,
port=dify_config.ANALYTICDB_PORT,
account=dify_config.ANALYTICDB_ACCOUNT,
account_password=dify_config.ANALYTICDB_PASSWORD,
min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
namespace=dify_config.ANALYTICDB_NAMESPACE,
)
apiConfig = None
return AnalyticdbVector(
collection_name,
apiConfig,
sqlConfig,
)

@ -0,0 +1,309 @@
import json
from typing import Any
from pydantic import BaseModel, model_validator
_import_err_msg = (
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
class AnalyticdbVectorOpenAPIConfig(BaseModel):
access_key_id: str
access_key_secret: str
region_id: str
instance_id: str
account: str
account_password: str
namespace: str = "dify"
namespace_password: str | None = None
metrics: str = "cosine"
read_timeout: int = 60000
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["access_key_id"]:
raise ValueError("config ANALYTICDB_KEY_ID is required")
if not values["access_key_secret"]:
raise ValueError("config ANALYTICDB_KEY_SECRET is required")
if not values["region_id"]:
raise ValueError("config ANALYTICDB_REGION_ID is required")
if not values["instance_id"]:
raise ValueError("config ANALYTICDB_INSTANCE_ID is required")
if not values["account"]:
raise ValueError("config ANALYTICDB_ACCOUNT is required")
if not values["account_password"]:
raise ValueError("config ANALYTICDB_PASSWORD is required")
if not values["namespace_password"]:
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
return values
def to_analyticdb_client_params(self):
return {
"access_key_id": self.access_key_id,
"access_key_secret": self.access_key_secret,
"region_id": self.region_id,
"read_timeout": self.read_timeout,
}
class AnalyticdbVectorOpenAPI:
def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
try:
from alibabacloud_gpdb20160503.client import Client
from alibabacloud_tea_openapi import models as open_api_models
except ImportError:
raise ImportError(_import_err_msg)
self._collection_name = collection_name.lower()
self.config = config
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
self._client = Client(self._client_config)
self._initialize()
def _initialize(self) -> None:
cache_key = f"vector_initialize_{self.config.instance_id}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
database_exist_cache_key = f"vector_initialize_{self.config.instance_id}"
if redis_client.get(database_exist_cache_key):
return
self._initialize_vector_database()
self._create_namespace_if_not_exists()
redis_client.set(database_exist_cache_key, 1, ex=3600)
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.init_vector_database(request)
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.describe_namespace(request)
except TeaException as e:
if e.statusCode == 404:
request = gpdb_20160503_models.CreateNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
)
self._client.create_namespace(request)
else:
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
try:
request = gpdb_20160503_models.DescribeCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
)
self._client.describe_collection(request)
except TeaException as e:
if e.statusCode == 404:
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
full_text_retrieval_fields = "page_content"
request = gpdb_20160503_models.CreateCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
collection=self._collection_name,
dimension=embedding_dimension,
metrics=self.config.metrics,
metadata=metadata,
full_text_retrieval_fields=full_text_retrieval_fields,
)
self._client.create_collection(request)
else:
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
)
)
request = gpdb_20160503_models.UpsertCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
rows=rows,
)
self._client.upsert_collection_data(request)
def text_exists(self, id: str) -> bool:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'",
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
ids_str = f"({ids_str})"
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
def delete_by_metadata_field(self, key: str, value: str) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.values.value,
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None,
content=query,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.values.value,
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
def delete(self) -> None:
try:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
collection=self._collection_name,
dbinstance_id=self.config.instance_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
region_id=self.config.region_id,
)
self._client.delete_collection(request)
except Exception as e:
raise e
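
Review note: `search_by_vector` and `search_by_full_text` above share the same post-processing: drop matches at or below the threshold, attach the score to the metadata, and sort by score in descending order. The following is a minimal sketch of that step in isolation, not the author's code. `filter_and_rank` is a hypothetical helper, and plain dicts stand in for the SDK match objects and for `Document`.

```python
import json


def filter_and_rank(matches, score_threshold=0.0):
    """Keep matches scoring above the threshold, best first.

    Each match is assumed to be a dict shaped like
    {"score": float, "metadata": {"page_content": str, "metadata_": json_str}}.
    """
    documents = []
    for match in matches:
        if match["score"] > score_threshold:
            # metadata_ is stored as a JSON string, as in add_texts above
            metadata = json.loads(match["metadata"]["metadata_"])
            metadata["score"] = match["score"]
            documents.append(
                {
                    "page_content": match["metadata"]["page_content"],
                    "metadata": metadata,
                }
            )
    documents.sort(key=lambda d: d["metadata"]["score"], reverse=True)
    return documents
```

Factoring the step out this way would let both search methods share one code path, though whether that fits the class layout is a design choice for the author.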

@ -0,0 +1,245 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
class AnalyticdbVectorBySqlConfig(BaseModel):
host: str
port: int
account: str
account_password: str
min_connection: int
max_connection: int
namespace: str = "dify"
metrics: str = "cosine"
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config ANALYTICDB_HOST is required")
if not values["port"]:
raise ValueError("config ANALYTICDB_PORT is required")
if not values["account"]:
raise ValueError("config ANALYTICDB_ACCOUNT is required")
if not values["account_password"]:
raise ValueError("config ANALYTICDB_PASSWORD is required")
if not values["min_connection"]:
raise ValueError("config ANALYTICDB_MIN_CONNECTION is required")
if not values["max_connection"]:
raise ValueError("config ANALYTICDB_MAX_CONNECTION is required")
if values["min_connection"] > values["max_connection"]:
raise ValueError("config ANALYTICDB_MIN_CONNECTION should be less than or equal to ANALYTICDB_MAX_CONNECTION")
return values
class AnalyticdbVectorBySql:
def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig):
self._collection_name = collection_name.lower()
self.databaseName = "knowledgebase"
self.config = config
self.table_name = f"{self.config.namespace}.{self._collection_name}"
self.pool = None
self._initialize()
if not self.pool:
self.pool = self._create_connection_pool()
def _initialize(self) -> None:
cache_key = f"vector_initialize_{self.config.host}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
database_exist_cache_key = f"vector_initialize_{self.config.host}"
if redis_client.get(database_exist_cache_key):
return
self._initialize_vector_database()
redis_client.set(database_exist_cache_key, 1, ex=3600)
def _create_connection_pool(self):
return psycopg2.pool.SimpleConnectionPool(
self.config.min_connection,
self.config.max_connection,
host=self.config.host,
port=self.config.port,
user=self.config.account,
password=self.config.account_password,
database=self.databaseName,
)
@contextmanager
def _get_cursor(self):
conn = self.pool.getconn()
cur = conn.cursor()
try:
yield cur
finally:
cur.close()
conn.commit()
self.pool.putconn(conn)
def _initialize_vector_database(self) -> None:
conn = psycopg2.connect(
host=self.config.host,
port=self.config.port,
user=self.config.account,
password=self.config.account_password,
database="postgres",
)
conn.autocommit = True
cur = conn.cursor()
try:
cur.execute(f"CREATE DATABASE {self.databaseName}")
except Exception as e:
if "already exists" in str(e):
return
raise e
finally:
cur.close()
conn.close()
self.pool = self._create_connection_pool()
with self._get_cursor() as cur:
try:
cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
except Exception as e:
if "already exists" not in str(e):
raise e
cur.execute(
"CREATE OR REPLACE FUNCTION "
"public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) "
"RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ "
"SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) "
"FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) "
"AS words_only;$function$"
)
cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
with self._get_cursor() as cur:
cur.execute(
f"CREATE TABLE IF NOT EXISTS {self.table_name}("
f"id text PRIMARY KEY,"
f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
f"to_tsvector TSVECTOR"
f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
)
if embedding_dimension is not None:
index_name = f"{self._collection_name}_embedding_idx"
cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
cur.execute(
f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
f"pq_enable=0, external_storage=0)"
)
cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
id_prefix = str(uuid.uuid4()) + "_"
sql = f"""
INSERT INTO {self.table_name}
(id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
"""
for i, doc in enumerate(documents):
values.append(
(
id_prefix + str(i),
doc.metadata.get("doc_id", str(uuid.uuid4())),
embeddings[i],
doc.page_content,
json.dumps(doc.metadata),
doc.page_content,
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_batch(cur, sql, values)
def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,))
return cur.fetchone() is not None
def delete_by_ids(self, ids: list[str]) -> None:
with self._get_cursor() as cur:
try:
cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),))
except Exception as e:
if "does not exist" not in str(e):
raise e
def delete_by_metadata_field(self, key: str, value: str) -> None:
with self._get_cursor() as cur:
try:
cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value))
except Exception as e:
if "does not exist" not in str(e):
raise e
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
with self._get_cursor() as cur:
query_vector_str = json.dumps(query_vector)
query_vector_str = "{" + query_vector_str[1:-1] + "}"
cur.execute(
f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
f"t.page_content as page_content, t.metadata_ AS metadata_ "
f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
# pass top_k as a bound parameter instead of interpolating it into the SQL text
f"FROM {self.table_name} ORDER BY score LIMIT %s ) t",
(query_vector_str, top_k),
)
documents = []
for record in cur:
id, vector, score, page_content, metadata = record
if score > score_threshold:
metadata["score"] = score
doc = Document(
page_content=page_content,
vector=vector,
metadata=metadata,
)
documents.append(doc)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
with self._get_cursor() as cur:
cur.execute(
f"""SELECT id, vector, page_content, metadata_,
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
FROM {self.table_name}
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
ORDER BY score DESC
LIMIT %s""",
(f"'{query}'", f"'{query}'", top_k),
)
documents = []
for record in cur:
id, vector, page_content, metadata, score = record
metadata["score"] = score
doc = Document(
page_content=page_content,
vector=vector,
metadata=metadata,
)
documents.append(doc)
return documents
def delete(self) -> None:
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

@ -11,6 +11,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
@ -43,11 +44,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith("。"):
page_content = page_content[1:].strip()
else:
page_content = page_content
page_content = remove_leading_symbols(document_node.page_content).strip()
if len(page_content) > 0:
document_node.page_content = page_content
split_documents.append(document_node)

@ -18,6 +18,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
@ -53,11 +54,7 @@ class QAIndexProcessor(BaseIndexProcessor):
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith("。"):
page_content = page_content[1:]
else:
page_content = page_content
document_node.page_content = page_content
document_node.page_content = remove_leading_symbols(page_content)
split_documents.append(document_node)
all_documents.extend(split_documents)
for i in range(0, len(all_documents), 10):

@ -27,11 +27,11 @@ class RerankModelRunner(BaseRerankRunner):
:return:
"""
docs = []
doc_id = set()
doc_ids = set()
unique_documents = []
for document in documents:
if document.provider == "dify" and document.metadata["doc_id"] not in doc_id:
doc_id.add(document.metadata["doc_id"])
if document.provider == "dify" and document.metadata["doc_id"] not in doc_ids:
doc_ids.add(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
elif document.provider == "external":

@ -36,23 +36,20 @@ class WeightRerankRunner(BaseRerankRunner):
:return:
"""
docs = []
doc_id = []
unique_documents = []
doc_ids = set()
for document in documents:
if document.metadata["doc_id"] not in doc_id:
doc_id.append(document.metadata["doc_id"])
docs.append(document.page_content)
if document.metadata["doc_id"] not in doc_ids:
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
documents = unique_documents
rerank_documents = []
query_scores = self._calculate_keyword_score(query, documents)
query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)
rerank_documents = []
for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
# format document
score = (
self.weights.vector_setting.vector_weight * query_vector_score
+ self.weights.keyword_setting.keyword_weight * query_score
@ -61,7 +58,8 @@ class WeightRerankRunner(BaseRerankRunner):
continue
document.metadata["score"] = score
rerank_documents.append(document)
rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True)
rerank_documents.sort(key=lambda x: x.metadata["score"], reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents
def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
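
Review note: the scoring loop in this hunk combines a keyword score and a vector score per document, then sorts in place and truncates to `top_n`. The sketch below illustrates that flow with plain dicts. The function name, the dict-based documents, and the `score_threshold` check are assumptions, since the condition guarding `continue` is not visible in the hunk.

```python
from typing import Optional


def weighted_rerank(
    documents: list[dict],
    keyword_scores: list[float],
    vector_scores: list[float],
    keyword_weight: float,
    vector_weight: float,
    score_threshold: Optional[float] = None,
    top_n: Optional[int] = None,
) -> list[dict]:
    ranked = []
    for doc, kw_score, vec_score in zip(documents, keyword_scores, vector_scores):
        score = vector_weight * vec_score + keyword_weight * kw_score
        # Assumed filter: skip documents that fall below the threshold.
        if score_threshold is not None and score < score_threshold:
            continue
        ranked.append({**doc, "score": score})
    # In-place sort avoids building a second list, as in the hunk above.
    ranked.sort(key=lambda d: d["score"], reverse=True)
    return ranked[:top_n] if top_n else ranked


result = weighted_rerank(
    [{"id": "a"}, {"id": "b"}, {"id": "c"}],
    keyword_scores=[0.2, 0.8, 0.5],
    vector_scores=[1.0, 0.5, 0.0],
    keyword_weight=0.5,
    vector_weight=0.5,
    score_threshold=0.3,
    top_n=2,
)
```

With equal weights, document "c" falls below the assumed threshold and "b" ranks ahead of "a".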

View File

@ -66,6 +66,41 @@ class BingSearchTool(BuiltinTool):
results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}'))
return results
elif result_type == "json":
result = {}
if search_results:
result["organic"] = [
{
"title": item.get("name", ""),
"snippet": item.get("snippet", ""),
"url": item.get("url", ""),
"siteName": item.get("siteName", ""),
}
for item in search_results
]
if computation and "expression" in computation and "value" in computation:
result["computation"] = {"expression": computation["expression"], "value": computation["value"]}
if entities:
result["entities"] = [
{
"name": item.get("name", ""),
"url": item.get("url", ""),
"description": item.get("description", ""),
}
for item in entities
]
if news:
result["news"] = [{"name": item.get("name", ""), "url": item.get("url", "")} for item in news]
if related_searches:
result["related searches"] = [
{"displayText": item.get("displayText", ""), "url": item.get("webSearchUrl", "")}
for item in related_searches
]
return self.create_json_message(result)
else:
# construct text
text = ""

View File

@ -113,9 +113,9 @@ parameters:
zh_Hans: 结果类型
pt_BR: result type
human_description:
en_US: return a list of links or texts
zh_Hans: 返回一个连接列表还是纯文本内容
pt_BR: return a list of links or texts
en_US: return a list of links, json or texts
zh_Hans: 返回一个列表内容是链接、json还是纯文本
pt_BR: return a list of links, json or texts
default: text
options:
- value: link
@ -123,6 +123,11 @@ parameters:
en_US: Link
zh_Hans: 链接
pt_BR: Link
- value: json
label:
en_US: JSON
zh_Hans: JSON
pt_BR: JSON
- value: text
label:
en_US: Text

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, ClassVar
from duckduckgo_search import DDGS
@ -11,6 +11,17 @@ class DuckDuckGoVideoSearchTool(BuiltinTool):
Tool for performing a video search using DuckDuckGo search engine.
"""
IFRAME_TEMPLATE: ClassVar[str] = """
<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden; \
max-width: 100%; border-radius: 8px;">
<iframe
style="position: absolute; top: 0; left: 0; width: 100%; height: 100%;"
src="{src}"
frameborder="0"
allowfullscreen>
</iframe>
</div>"""
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
query_dict = {
"keywords": tool_parameters.get("query"),
@ -26,6 +37,9 @@ class DuckDuckGoVideoSearchTool(BuiltinTool):
# Remove None values to use API defaults
query_dict = {k: v for k, v in query_dict.items() if v is not None}
# Get proxy URL from parameters
proxy_url = tool_parameters.get("proxy_url", "").strip()
response = DDGS().videos(**query_dict)
# Create HTML result with embedded iframes
@ -36,20 +50,21 @@ class DuckDuckGoVideoSearchTool(BuiltinTool):
title = res.get("title", "")
embed_html = res.get("embed_html", "")
description = res.get("description", "")
content_url = res.get("content", "")
# Modify iframe to be responsive
if embed_html:
# Replace fixed dimensions with responsive wrapper and iframe
embed_html = """
<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden; \
max-width: 100%; border-radius: 8px;">
<iframe
style="position: absolute; top: 0; left: 0; width: 100%; height: 100%;"
src="{src}"
frameborder="0"
allowfullscreen>
</iframe>
</div>""".format(src=res.get("embed_url", ""))
# Handle TED.com videos
if not embed_html and "ted.com/talks" in content_url:
embed_url = content_url.replace("www.ted.com", "embed.ted.com")
if proxy_url:
embed_url = f"{proxy_url}{embed_url}"
embed_html = self.IFRAME_TEMPLATE.format(src=embed_url)
# Original YouTube/other platform handling
elif embed_html:
embed_url = res.get("embed_url", "")
if proxy_url and embed_url:
embed_url = f"{proxy_url}{embed_url}"
embed_html = self.IFRAME_TEMPLATE.format(src=embed_url)
markdown_result += f"{title}\n\n"
markdown_result += f"{embed_html}\n\n"
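
Review note: the branch logic above can be read as a small pure function. The sketch below mirrors it under a simplified one-line template; `build_embed_html` and the proxy prefix in the usage lines are hypothetical. Note that the proxy URL is prepended as a plain string, so the target URL is not percent-encoded.

```python
IFRAME_TEMPLATE = '<iframe src="{src}" frameborder="0" allowfullscreen></iframe>'


def build_embed_html(res: dict, proxy_url: str = "") -> str:
    embed_html = res.get("embed_html", "")
    content_url = res.get("content", "")
    if not embed_html and "ted.com/talks" in content_url:
        # TED results carry no embed_html, so derive the embed URL from the page URL.
        embed_url = content_url.replace("www.ted.com", "embed.ted.com")
    elif embed_html:
        # YouTube and other platforms already provide an embed URL.
        embed_url = res.get("embed_url", "")
    else:
        return embed_html  # nothing to embed
    if proxy_url and embed_url:
        embed_url = f"{proxy_url}{embed_url}"
    return IFRAME_TEMPLATE.format(src=embed_url)


ted = build_embed_html({"content": "https://www.ted.com/talks/example"})
proxied = build_embed_html(
    {"embed_html": "<iframe></iframe>", "embed_url": "https://video.example/embed/1"},
    proxy_url="https://proxy.example/?url=",
)
```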

View File

@ -1,40 +1,43 @@
identity:
name: ddgo_video
author: Assistant
author: Tao Wang
label:
en_US: DuckDuckGo Video Search
zh_Hans: DuckDuckGo 视频搜索
description:
human:
en_US: Perform video searches on DuckDuckGo and get results with embedded videos.
zh_Hans: 在 DuckDuckGo 上进行视频搜索并获取可嵌入视频结果。
llm: Perform video searches on DuckDuckGo and get results with embedded videos.
en_US: Search and embed videos.
zh_Hans: 搜索并嵌入视频
llm: Search videos on duckduckgo and embed videos in iframe
parameters:
- name: query
type: string
required: true
label:
en_US: Query String
zh_Hans: 查询语句
type: string
required: true
human_description:
en_US: Search Query
zh_Hans: 搜索查询语句
zh_Hans: 搜索查询语句
llm_description: Key words for searching
form: llm
- name: max_results
label:
en_US: Max Results
zh_Hans: 最大结果数量
type: number
required: true
default: 3
minimum: 1
maximum: 10
label:
en_US: Max Results
zh_Hans: 最大结果数量
human_description:
en_US: The max results (1-10).
zh_Hans: 最大结果数量1-10
en_US: The max results (1-10)
zh_Hans: 最大结果数量1-10
form: form
- name: timelimit
label:
en_US: Result Time Limit
zh_Hans: 结果时间限制
type: select
required: false
options:
@ -54,14 +57,14 @@ parameters:
label:
en_US: Current Year
zh_Hans: 今年
label:
en_US: Result Time Limit
zh_Hans: 结果时间限制
human_description:
en_US: Use when querying results within a specific time range only.
en_US: Query results within a specific time range only
zh_Hans: 只查询一定时间范围内的结果时使用
form: form
- name: duration
label:
en_US: Video Duration
zh_Hans: 视频时长
type: select
required: false
options:
@ -77,10 +80,18 @@ parameters:
label:
en_US: Long (>20 minutes)
zh_Hans: 长视频(>20分钟)
label:
en_US: Video Duration
zh_Hans: 视频时长
human_description:
en_US: Filter videos by duration
zh_Hans: 按时长筛选视频
form: form
- name: proxy_url
label:
en_US: Proxy URL
zh_Hans: 视频代理地址
type: string
required: false
default: ""
human_description:
en_US: Proxy URL
zh_Hans: 视频代理地址
form: form

View File

@ -17,7 +17,7 @@ class SendMailTool(BuiltinTool):
invoke tools
"""
sender = self.runtime.credentials.get("email_account", "")
email_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
email_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
password = self.runtime.credentials.get("email_password", "")
smtp_server = self.runtime.credentials.get("smtp_server", "")
if not smtp_server:

View File

@ -18,7 +18,7 @@ class SendMailTool(BuiltinTool):
invoke tools
"""
sender = self.runtime.credentials.get("email_account", "")
email_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
email_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
password = self.runtime.credentials.get("email_password", "")
smtp_server = self.runtime.credentials.get("smtp_server", "")
if not smtp_server:
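
Review note: the only change to the pattern is the added `.` in the local-part character class. A quick check of the old and new patterns, copied from the hunks above, against two sample addresses (the addresses and constant names are illustrative):

```python
import re

OLD_EMAIL_RGX = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
NEW_EMAIL_RGX = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")

dotted = "first.last@example.com"
plus_tagged = "user+tag@example.com"

# A dotted local part is rejected by the old pattern and accepted by the new one.
old_accepts_dotted = OLD_EMAIL_RGX.match(dotted) is not None
new_accepts_dotted = NEW_EMAIL_RGX.match(dotted) is not None

# Plus-tagged addresses are still rejected, because "+" is not in the class.
new_accepts_plus = NEW_EMAIL_RGX.match(plus_tagged) is not None
```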

Binary file not shown (size: 4.3 KiB).

View File

@ -0,0 +1,8 @@
from typing import Any
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class FileExtractorProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
pass

View File

@ -0,0 +1,15 @@
identity:
author: Jyong
name: file_extractor
label:
en_US: File Extractor
zh_Hans: 文件提取
pt_BR: File Extractor
description:
en_US: Extract text from file
zh_Hans: 从文件中提取文本
pt_BR: Extract text from file
icon: icon.png
tags:
- utilities
- productivity

View File

@ -0,0 +1,45 @@
import tempfile
from typing import Any, Union
from core.file.enums import FileType
from core.file.file_manager import download_to_target_path
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.splitter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolParameterValidationError
from core.tools.tool.builtin_tool import BuiltinTool
class FileExtractorTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
# document file for workflow mode
file = tool_parameters.get("text_file")
if file and file.type != FileType.DOCUMENT:
raise ToolParameterValidationError("Not a valid document")
if file:
with tempfile.TemporaryDirectory() as temp_dir:
file_path = download_to_target_path(file, temp_dir)
extractor = TextExtractor(file_path, autodetect_encoding=True)
documents = extractor.extract()
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=tool_parameters.get("max_token", 500),
chunk_overlap=0,
fixed_separator=tool_parameters.get("separator", "\n\n"),
separators=["\n\n", "。", ". ", " ", ""],
embedding_model_instance=None,
)
chunks = character_splitter.split_documents(documents)
content = "\n".join([chunk.page_content for chunk in chunks])
return self.create_text_message(content)
else:
raise ToolParameterValidationError("Please provide a file")

View File

@ -0,0 +1,49 @@
identity:
name: text extractor
author: Jyong
label:
en_US: Text extractor
zh_Hans: Text 文本解析
description:
en_US: Extract content from text file and support split to chunks by split characters and token length
zh_Hans: 支持从文本文件中提取内容并支持通过分割字符和令牌长度分割成块
pt_BR: Extract content from text file and support split to chunks by split characters and token length
description:
human:
en_US: Text extractor is a text extract tool
zh_Hans: Text extractor 是一个文本提取工具
pt_BR: Text extractor is a text extract tool
llm: Text extractor is a tool used to extract text file
parameters:
- name: text_file
type: file
label:
en_US: Text file
human_description:
en_US: The text file to be extracted.
zh_Hans: 要提取的 text 文档。
llm_description: you should not input this parameter. just input the image_id.
form: llm
- name: separator
type: string
required: false
label:
en_US: split character
zh_Hans: 分隔符号
human_description:
en_US: Text content split character
zh_Hans: 用于文档分隔的符号
llm_description: it is used for split content to chunks
form: form
- name: max_token
type: number
required: false
label:
en_US: Maximum chunk length
zh_Hans: 最大分段长度
human_description:
en_US: Maximum chunk length
zh_Hans: 最大分段长度
llm_description: it is used for limit chunk's max length
form: form

View File

@ -0,0 +1,25 @@
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class GiteeAIToolEmbedding(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
headers = {
"content-type": "application/json",
"authorization": f"Bearer {self.runtime.credentials['api_key']}",
}
payload = {"inputs": tool_parameters.get("inputs")}
model = tool_parameters.get("model", "bge-m3")
url = f"https://ai.gitee.com/api/serverless/{model}/embeddings"
response = requests.post(url, json=payload, headers=headers)
if response.status_code != 200:
return self.create_text_message(f"Got Error Response:{response.text}")
return [self.create_text_message(response.content.decode("utf-8"))]

View File

@ -0,0 +1,37 @@
identity:
name: embedding
author: gitee_ai
label:
en_US: embedding
icon: icon.svg
description:
human:
en_US: Generate word embeddings using Serverless-supported models (compatible with OpenAI)
llm: This tool is used to generate word embeddings from text input.
parameters:
- name: model
type: string
required: true
in: path
description:
en_US: Supported Embedding (compatible with OpenAI) interface models
enum:
- bge-m3
- bge-large-zh-v1.5
- bge-small-zh-v1.5
label:
en_US: Service Model
zh_Hans: 服务模型
default: bge-m3
form: form
- name: inputs
type: string
required: true
label:
en_US: Input Text
zh_Hans: 输入文本
human_description:
en_US: The text input used to generate embeddings.
zh_Hans: 用于生成词向量的输入文本。
llm_description: This text input will be used to generate embeddings.
form: llm

View File

@ -6,7 +6,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class GiteeAITool(BuiltinTool):
class GiteeAIToolText2Image(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:

View File

@ -40,6 +40,9 @@ class JSONParseTool(BuiltinTool):
expr = parse(json_filter)
result = [match.value for match in expr.find(input_data)]
if not result:
return ""
if len(result) == 1:
result = result[0]
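
Review note: the added guard returns an empty string when the filter matches nothing, before the single-match unwrap runs. A minimal sketch of that handling in isolation; `unwrap_matches` is a hypothetical helper, and returning the list unchanged for multiple matches is an assumption because that part of the method is outside the hunk.

```python
def unwrap_matches(matches: list):
    """Normalise a list of JSONPath match values."""
    if not matches:
        return ""  # no match: empty string instead of an empty list
    if len(matches) == 1:
        return matches[0]  # single match: unwrap the value
    return matches  # assumption: multiple matches are passed through as a list


empty = unwrap_matches([])
single = unwrap_matches([{"name": "dify"}])
many = unwrap_matches([1, 2, 3])
```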

View File

@ -12,7 +12,7 @@ class NovitaAiToolBase:
if not loras_str:
return []
loras_ori_list = lora_str.strip().split(";")
loras_ori_list = loras_str.strip().split(";")
result_list = []
for lora_str in loras_ori_list:
lora_info = lora_str.strip().split(",")
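
Review note: before this fix the first split read `lora_str`, which is only bound later as the loop variable, so Python raises `UnboundLocalError` on any non-empty input. The sketch below shows the corrected flow for a `name,strength;name,strength` string. The dict output and the skipping of malformed entries are assumptions; the real method builds its own objects outside this hunk.

```python
def parse_loras(loras_str: str) -> list[dict]:
    if not loras_str:
        return []
    result_list = []
    # Split the whole input first; each piece is then bound to lora_str.
    for lora_str in loras_str.strip().split(";"):
        lora_info = lora_str.strip().split(",")
        if len(lora_info) != 2:
            continue  # assumption: ignore malformed entries
        result_list.append(
            {"model_name": lora_info[0].strip(), "strength": float(lora_info[1])}
        )
    return result_list


parsed = parse_loras("style_a,0.8; style_b,0.5")
```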

Binary file not shown (size: 62 KiB).

Some files were not shown because too many files have changed in this diff.