Mirror of https://github.com/langgenius/dify.git (synced 2026-01-31 00:56:36 +08:00)
Compare commits
43 Commits
| SHA1 |
|---|
| eee95190cc |
| e8311357ff |
| 0f14fdd4c9 |
| ece0f08a2b |
| 5edb3d55e5 |
| 63382f758e |
| bbef964eb5 |
| e6db7ad1d5 |
| 8cc492721b |
| a80fe20456 |
| f7986805c6 |
| aa5ca90f00 |
| 4af00e4a45 |
| c01c95d77f |
| 20a9037d5b |
| 34d3998566 |
| 198d6c00d6 |
| 1663df8a05 |
| d8926a2571 |
| 4796f9d914 |
| a588df4371 |
| 2c1c660c6e |
| 13f4ed6e0e |
| 1e451991db |
| 749b236d3d |
| 00ce372b71 |
| 370e1c1a17 |
| 28495273b4 |
| 36a9c5cc6b |
| 228de1f12a |
| 01555463d2 |
| 6b99075dc8 |
| 8578ee0864 |
| 897e07f639 |
| 875249eb00 |
| 4d5a4e4cef |
| 86a6e6bd04 |
| 8f3042e5b3 |
| a1ab87107b |
| f49c99937c |
| 9b24f12bf5 |
| 487ce7c82a |
| cc835d523c |

(The Author and Date columns were empty in the mirror dump and are omitted.)
@@ -1,4 +1,4 @@
-# Devlopment with devcontainer
+# Development with devcontainer
 
 This project includes a devcontainer configuration that allows you to open the project in a container with a fully configured development environment.
 
 Both frontend and backend environments are initialized when the container is started.
 
 ## GitHub Codespaces
@@ -33,5 +33,5 @@ Performance Impact: While usually minimal, programs running inside a devcontaine
 if you see such error message when you open this project in codespaces:
 
 
 
-a simple workaround is change `/signin` endpoint into another one, then login with github account and close the tab, then change it back to `/signin` endpoint. Then all things will be fine.
+a simple workaround is change `/signin` endpoint into another one, then login with GitHub account and close the tab, then change it back to `/signin` endpoint. Then all things will be fine.
 
 The reason is `signin` endpoint is not allowed in codespaces, details can be found [here](https://github.com/orgs/community/discussions/5204)
.github/ISSUE_TEMPLATE/bug_report.yml (vendored, 4 changes)
@@ -8,13 +8,13 @@ body:
       label: Self Checks
       description: "To make sure we get to you in time, please check the following :)"
       options:
-        - label: This is only for bug report, if you would like to ask a quesion, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
+        - label: This is only for bug report, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
           required: true
         - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
           required: true
         - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
           required: true
-        - label: "Pleas do not modify this template :) and fill in all the required fields."
+        - label: "Please do not modify this template :) and fill in all the required fields."
           required: true
 
   - type: input
.github/ISSUE_TEMPLATE/document_issue.yml (vendored, 4 changes)
@@ -1,7 +1,7 @@
 name: "📚 Documentation Issue"
 description: Report issues in our documentation
 labels:
-  - ducumentation
+  - documentation
 body:
   - type: checkboxes
     attributes:
@@ -12,7 +12,7 @@ body:
           required: true
         - label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
           required: true
-        - label: "Pleas do not modify this template :) and fill in all the required fields."
+        - label: "Please do not modify this template :) and fill in all the required fields."
           required: true
   - type: textarea
     attributes:
.github/ISSUE_TEMPLATE/feature_request.yml (vendored, 2 changes)
@@ -12,7 +12,7 @@ body:
          required: true
        - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
          required: true
-       - label: "Pleas do not modify this template :) and fill in all the required fields."
+       - label: "Please do not modify this template :) and fill in all the required fields."
          required: true
   - type: textarea
     attributes:
.github/ISSUE_TEMPLATE/translation_issue.yml (vendored, 2 changes)
@@ -12,7 +12,7 @@ body:
          required: true
        - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
          required: true
-       - label: "Pleas do not modify this template :) and fill in all the required fields."
+       - label: "Please do not modify this template :) and fill in all the required fields."
          required: true
   - type: input
     attributes:
.github/workflows/api-tests.yml (vendored, 5 changes)
@@ -46,11 +46,12 @@ jobs:
           docker/docker-compose.middleware.yaml
         services: |
           sandbox
+          ssrf_proxy
 
     - name: Run Workflow
       run: dev/pytest/pytest_workflow.sh
 
-    - name: Set up Vector Stores (Weaviate, Qdrant, Milvus, PgVecto-RS)
+    - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS)
       uses: hoverkraft-tech/compose-action@v2.0.0
       with:
         compose-file: |
@@ -58,6 +59,7 @@ jobs:
           docker/docker-compose.qdrant.yaml
           docker/docker-compose.milvus.yaml
           docker/docker-compose.pgvecto-rs.yaml
+          docker/docker-compose.pgvector.yaml
         services: |
           weaviate
           qdrant
@@ -65,6 +67,7 @@ jobs:
           minio
           milvus-standalone
           pgvecto-rs
+          pgvector
 
     - name: Test Vector Stores
       run: dev/pytest/pytest_vdb.sh
README.md
@@ -37,11 +37,7 @@
 <a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
 </p>
 
-#
 
 <p align="center">
   <a href="https://trendshift.io/repositories/2152" target="_blank"><img src="https://trendshift.io/api/badge/repositories/2152" alt="langgenius%2Fdify | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
 </p>
 Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features:
 </br> </br>
@@ -109,7 +105,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
 <td align="center">Agent</td>
 <td align="center">✅</td>
 <td align="center">✅</td>
-<td align="center">✅</td>
+<td align="center">❌</td>
 <td align="center">✅</td>
 </tr>
 <tr>
@@ -127,7 +123,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
 <td align="center">❌</td>
 </tr>
 <tr>
-<td align="center">Enterprise Feature (SSO/Access control)</td>
+<td align="center">Enterprise Features (SSO/Access control)</td>
 <td align="center">✅</td>
 <td align="center">❌</td>
 <td align="center">❌</td>
README_CN.md
@@ -111,7 +111,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI
 <td align="center">Agent</td>
 <td align="center">✅</td>
 <td align="center">✅</td>
-<td align="center">✅</td>
+<td align="center">❌</td>
 <td align="center">✅</td>
 </tr>
 <tr>
README_ES.md
@@ -111,7 +111,7 @@ es basados en LLM Function Calling o ReAct, y agregar herramientas preconstruida
 <td align="center">Agente</td>
 <td align="center">✅</td>
 <td align="center">✅</td>
-<td align="center">✅</td>
+<td align="center">❌</td>
 <td align="center">✅</td>
 </tr>
 <tr>
README_FR.md
@@ -111,7 +111,7 @@ ités d'agent**:
 <td align="center">Agent</td>
 <td align="center">✅</td>
 <td align="center">✅</td>
-<td align="center">✅</td>
+<td align="center">❌</td>
 <td align="center">✅</td>
 </tr>
 <tr>
README_JA.md
@@ -110,7 +110,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
 <td align="center">エージェント</td>
 <td align="center">✅</td>
 <td align="center">✅</td>
-<td align="center">✅</td>
+<td align="center">❌</td>
 <td align="center">✅</td>
 </tr>
 <tr>
@@ -111,7 +111,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
 <td align="center">Agent</td>
 <td align="center">✅</td>
 <td align="center">✅</td>
-<td align="center">✅</td>
+<td align="center">❌</td>
 <td align="center">✅</td>
 </tr>
 <tr>
.env.example
@@ -65,7 +65,7 @@ GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON=your-google-service-account-json-base64-stri
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 
-# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs
+# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs, pgvector
 VECTOR_STORE=weaviate
 
 # Weaviate configuration
@@ -102,6 +102,13 @@ PGVECTO_RS_USER=postgres
 PGVECTO_RS_PASSWORD=difyai123456
 PGVECTO_RS_DATABASE=postgres
 
+# PGVector configuration
+PGVECTOR_HOST=127.0.0.1
+PGVECTOR_PORT=5433
+PGVECTOR_USER=postgres
+PGVECTOR_PASSWORD=postgres
+PGVECTOR_DATABASE=postgres
+
 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_BATCH_LIMIT=5
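Note: to actually route Dify at PGVector, `VECTOR_STORE=pgvector` must be set alongside the block above. A minimal smoke test of those settings, assuming psycopg2 is installed and the pgvector extension is expected on the target database (this is a sketch, not part of the diff):

```python
# Hypothetical check that the PGVECTOR_* values in .env point at a usable database.
import os

import psycopg2

conn = psycopg2.connect(
    host=os.getenv("PGVECTOR_HOST", "127.0.0.1"),
    port=int(os.getenv("PGVECTOR_PORT", "5433")),
    user=os.getenv("PGVECTOR_USER", "postgres"),
    password=os.getenv("PGVECTOR_PASSWORD", "postgres"),
    dbname=os.getenv("PGVECTOR_DATABASE", "postgres"),
)
with conn.cursor() as cur:
    # The pgvector extension must be installed for vector columns to work.
    cur.execute("SELECT extname FROM pg_extension WHERE extname = 'vector';")
    print("pgvector extension installed:", cur.fetchone() is not None)
conn.close()
```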
@@ -305,6 +305,14 @@ def migrate_knowledge_vector_database():
                 "vector_store": {"class_prefix": collection_name}
             }
             dataset.index_struct = json.dumps(index_struct_dict)
+        elif vector_type == "pgvector":
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+            index_struct_dict = {
+                "type": 'pgvector',
+                "vector_store": {"class_prefix": collection_name}
+            }
+            dataset.index_struct = json.dumps(index_struct_dict)
         else:
             raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
api/config.py
@@ -107,7 +107,7 @@ class Config:
         # ------------------------
         # General Configurations.
         # ------------------------
-        self.CURRENT_VERSION = "0.6.7"
+        self.CURRENT_VERSION = "0.6.8"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = get_env('EDITION')
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -222,7 +222,7 @@ class Config:
 
         # ------------------------
         # Vector Store Configurations.
-        # Currently, only support: qdrant, milvus, zilliz, weaviate, relyt
+        # Currently, only support: qdrant, milvus, zilliz, weaviate, relyt, pgvector
         # ------------------------
         self.VECTOR_STORE = get_env('VECTOR_STORE')
         self.KEYWORD_STORE = get_env('KEYWORD_STORE')
@@ -261,6 +261,13 @@ class Config:
         self.PGVECTO_RS_PASSWORD = get_env('PGVECTO_RS_PASSWORD')
         self.PGVECTO_RS_DATABASE = get_env('PGVECTO_RS_DATABASE')
 
+        # pgvector settings
+        self.PGVECTOR_HOST = get_env('PGVECTOR_HOST')
+        self.PGVECTOR_PORT = get_env('PGVECTOR_PORT')
+        self.PGVECTOR_USER = get_env('PGVECTOR_USER')
+        self.PGVECTOR_PASSWORD = get_env('PGVECTOR_PASSWORD')
+        self.PGVECTOR_DATABASE = get_env('PGVECTOR_DATABASE')
+
         # ------------------------
         # Mail Configurations.
         # ------------------------
@@ -91,3 +91,9 @@ class DraftWorkflowNotExist(BaseHTTPException):
     error_code = 'draft_workflow_not_exist'
     description = "Draft workflow need to be initialized."
     code = 400
+
+
+class DraftWorkflowNotSync(BaseHTTPException):
+    error_code = 'draft_workflow_not_sync'
+    description = "Workflow graph might have been modified, please refresh and resubmit."
+    code = 400
@@ -7,7 +7,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
 
 import services
 from controllers.console import api
-from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist
+from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
 from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
@@ -20,6 +20,7 @@ from libs.helper import TimestampField, uuid_value
 from libs.login import current_user, login_required
 from models.model import App, AppMode
 from services.app_generate_service import AppGenerateService
+from services.errors.app import WorkflowHashNotEqualError
 from services.workflow_service import WorkflowService
 
 logger = logging.getLogger(__name__)
@@ -59,6 +60,7 @@ class DraftWorkflowApi(Resource):
             parser = reqparse.RequestParser()
             parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
             parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
+            parser.add_argument('hash', type=str, required=False, location='json')
             args = parser.parse_args()
         elif 'text/plain' in content_type:
             try:
@@ -71,7 +73,8 @@ class DraftWorkflowApi(Resource):
 
                 args = {
                     'graph': data.get('graph'),
-                    'features': data.get('features')
+                    'features': data.get('features'),
+                    'hash': data.get('hash')
                 }
             except json.JSONDecodeError:
                 return {'message': 'Invalid JSON data'}, 400
@@ -79,15 +82,21 @@ class DraftWorkflowApi(Resource):
             abort(415)
 
         workflow_service = WorkflowService()
-        workflow = workflow_service.sync_draft_workflow(
-            app_model=app_model,
-            graph=args.get('graph'),
-            features=args.get('features'),
-            account=current_user
-        )
 
+        try:
+            workflow = workflow_service.sync_draft_workflow(
+                app_model=app_model,
+                graph=args.get('graph'),
+                features=args.get('features'),
+                unique_hash=args.get('hash'),
+                account=current_user
+            )
+        except WorkflowHashNotEqualError:
+            raise DraftWorkflowNotSync()
 
         return {
             "result": "success",
+            "hash": workflow.unique_hash,
             "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at)
         }
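Note: the `hash` round-trip added above is optimistic locking for draft edits. The client echoes back the `unique_hash` it last received; if the draft moved on in the meantime, the server answers with `draft_workflow_not_sync` instead of clobbering it. A hypothetical client-side save loop (the base URL, app id, and auth header are placeholders, not taken from this diff):

```python
# Sketch of the optimistic-lock handshake against the draft-workflow endpoint.
import requests

BASE = "https://dify.example.com/console/api"       # placeholder
HEADERS = {"Authorization": "Bearer <console-token>"}  # placeholder

def save_draft(app_id: str, graph: dict, features: dict, last_hash: str | None) -> str:
    resp = requests.post(
        f"{BASE}/apps/{app_id}/workflows/draft",
        headers=HEADERS,
        json={"graph": graph, "features": features, "hash": last_hash},
    )
    if resp.status_code == 400 and resp.json().get("code") == "draft_workflow_not_sync":
        # Someone else saved first: refetch the draft, merge, and resubmit.
        raise RuntimeError("Draft changed on the server; refresh before saving again.")
    resp.raise_for_status()
    return resp.json()["hash"]  # store this for the next save
```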
@@ -476,13 +476,13 @@ class DatasetRetrievalSettingApi(Resource):
     @account_initialization_required
     def get(self):
         vector_type = current_app.config['VECTOR_STORE']
-        if vector_type == 'milvus' or vector_type == 'pgvecto_rs' or vector_type == 'relyt':
+        if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs"}:
             return {
                 'retrieval_method': [
                     'semantic_search'
                 ]
             }
-        elif vector_type == 'qdrant' or vector_type == 'weaviate':
+        elif vector_type in {"qdrant", "weaviate"}:
             return {
                 'retrieval_method': [
                     'semantic_search', 'full_text_search', 'hybrid_search'
@@ -497,14 +497,13 @@ class DatasetRetrievalSettingMockApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, vector_type):
-
-        if vector_type == 'milvus' or vector_type == 'relyt':
+        if vector_type in {'milvus', 'relyt', 'pgvector'}:
             return {
                 'retrieval_method': [
                     'semantic_search'
                 ]
             }
-        elif vector_type == 'qdrant' or vector_type == 'weaviate':
+        elif vector_type in {'qdrant', 'weaviate'}:
             return {
                 'retrieval_method': [
                     'semantic_search', 'full_text_search', 'hybrid_search'
@@ -8,6 +8,8 @@ from core.app.entities.task_entities import (
     ChatbotAppStreamResponse,
     ErrorStreamResponse,
     MessageEndStreamResponse,
+    NodeFinishStreamResponse,
+    NodeStartStreamResponse,
     PingStreamResponse,
 )
@@ -111,6 +113,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             if isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
+            elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
+                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
             else:
                 response_chunk.update(sub_stream_response.to_dict())
@@ -5,6 +5,8 @@ from typing import cast
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
     ErrorStreamResponse,
+    NodeFinishStreamResponse,
+    NodeStartStreamResponse,
     PingStreamResponse,
     WorkflowAppBlockingResponse,
     WorkflowAppStreamResponse,
@@ -68,4 +70,24 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
         :param stream_response: stream response
         :return:
         """
-        return cls.convert_stream_full_response(stream_response)
+        for chunk in stream_response:
+            chunk = cast(WorkflowAppStreamResponse, chunk)
+            sub_stream_response = chunk.stream_response
+
+            if isinstance(sub_stream_response, PingStreamResponse):
+                yield 'ping'
+                continue
+
+            response_chunk = {
+                'event': sub_stream_response.event.value,
+                'workflow_run_id': chunk.workflow_run_id,
+            }
+
+            if isinstance(sub_stream_response, ErrorStreamResponse):
+                data = cls._error_to_stream_response(sub_stream_response.err)
+                response_chunk.update(data)
+            elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
+                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
+            else:
+                response_chunk.update(sub_stream_response.to_dict())
+            yield json.dumps(response_chunk)
@@ -246,6 +246,24 @@ class NodeStartStreamResponse(StreamResponse):
     workflow_run_id: str
     data: Data
 
+    def to_ignore_detail_dict(self):
+        return {
+            "event": self.event.value,
+            "task_id": self.task_id,
+            "workflow_run_id": self.workflow_run_id,
+            "data": {
+                "id": self.data.id,
+                "node_id": self.data.node_id,
+                "node_type": self.data.node_type,
+                "title": self.data.title,
+                "index": self.data.index,
+                "predecessor_node_id": self.data.predecessor_node_id,
+                "inputs": None,
+                "created_at": self.data.created_at,
+                "extras": {}
+            }
+        }
+
 
 class NodeFinishStreamResponse(StreamResponse):
     """
@@ -276,6 +294,31 @@ class NodeFinishStreamResponse(StreamResponse):
     workflow_run_id: str
     data: Data
 
+    def to_ignore_detail_dict(self):
+        return {
+            "event": self.event.value,
+            "task_id": self.task_id,
+            "workflow_run_id": self.workflow_run_id,
+            "data": {
+                "id": self.data.id,
+                "node_id": self.data.node_id,
+                "node_type": self.data.node_type,
+                "title": self.data.title,
+                "index": self.data.index,
+                "predecessor_node_id": self.data.predecessor_node_id,
+                "inputs": None,
+                "process_data": None,
+                "outputs": None,
+                "status": self.data.status,
+                "error": None,
+                "elapsed_time": self.data.elapsed_time,
+                "execution_metadata": None,
+                "created_at": self.data.created_at,
+                "finished_at": self.data.finished_at,
+                "files": []
+            }
+        }
+
 
 class TextChunkStreamResponse(StreamResponse):
     """
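Note: in the simplified ("ignore detail") stream these methods produce, fields that can carry large or sensitive values (`inputs`, `process_data`, `outputs`, `error`, `execution_metadata`) are nulled while identity and timing fields survive. A `node_finished` chunk on that stream would look roughly like this (all values illustrative):

```python
# Illustrative shape of a simplified node_finished chunk; values are made up.
simplified_chunk = {
    "event": "node_finished",
    "task_id": "task-123",
    "workflow_run_id": "run-456",
    "data": {
        "id": "exec-789",
        "node_id": "llm-1",
        "node_type": "llm",
        "title": "LLM",
        "index": 2,
        "predecessor_node_id": "start-1",
        "inputs": None,            # stripped in the simplified stream
        "process_data": None,      # stripped
        "outputs": None,           # stripped
        "status": "succeeded",
        "error": None,
        "elapsed_time": 1.42,
        "execution_metadata": None,
        "created_at": 1715000000,
        "finished_at": 1715000002,
        "files": [],
    },
}
```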
api/core/helper/code_executor/code_executor.py
@@ -1,14 +1,20 @@
 import logging
+import time
 from enum import Enum
+from threading import Lock
 from typing import Literal, Optional
 
-from httpx import post
+from httpx import get, post
 from pydantic import BaseModel
 from yarl import URL
 
 from config import get_env
+from core.helper.code_executor.entities import CodeDependency
 from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer
 from core.helper.code_executor.jinja2_transformer import Jinja2TemplateTransformer
-from core.helper.code_executor.python_transformer import PythonTemplateTransformer
+from core.helper.code_executor.python_transformer import PYTHON_STANDARD_PACKAGES, PythonTemplateTransformer
+
+logger = logging.getLogger(__name__)
 
 # Code Executor
 CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
@@ -28,7 +34,6 @@ class CodeExecutionResponse(BaseModel):
     message: str
     data: Data
 
-
 class CodeLanguage(str, Enum):
     PYTHON3 = 'python3'
     JINJA2 = 'jinja2'
@@ -36,6 +41,9 @@ class CodeLanguage(str, Enum):
 
 
 class CodeExecutor:
+    dependencies_cache = {}
+    dependencies_cache_lock = Lock()
+
     code_template_transformers = {
         CodeLanguage.PYTHON3: PythonTemplateTransformer,
         CodeLanguage.JINJA2: Jinja2TemplateTransformer,
@@ -49,7 +57,11 @@ class CodeExecutor:
     }
 
     @classmethod
-    def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], preload: str, code: str) -> str:
+    def execute_code(cls,
+                     language: Literal['python3', 'javascript', 'jinja2'],
+                     preload: str,
+                     code: str,
+                     dependencies: Optional[list[CodeDependency]] = None) -> str:
         """
         Execute code
         :param language: code language
@@ -65,9 +77,13 @@ class CodeExecutor:
         data = {
             'language': cls.code_language_to_running_language.get(language),
             'code': code,
-            'preload': preload
+            'preload': preload,
+            'enable_network': True
         }
 
+        if dependencies:
+            data['dependencies'] = [dependency.dict() for dependency in dependencies]
+
         try:
             response = post(str(url), json=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT)
             if response.status_code == 503:
@@ -95,7 +111,7 @@ class CodeExecutor:
         return response.data.stdout
 
     @classmethod
-    def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict:
+    def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
         """
         Execute code
         :param language: code language
@@ -107,11 +123,63 @@ class CodeExecutor:
         if not template_transformer:
             raise CodeExecutionException(f'Unsupported language {language}')
 
-        runner, preload = template_transformer.transform_caller(code, inputs)
+        runner, preload, dependencies = template_transformer.transform_caller(code, inputs, dependencies)
 
         try:
-            response = cls.execute_code(language, preload, runner)
+            response = cls.execute_code(language, preload, runner, dependencies)
         except CodeExecutionException as e:
             raise e
 
         return template_transformer.transform_response(response)
+
+    @classmethod
+    def list_dependencies(cls, language: Literal['python3']) -> list[CodeDependency]:
+        with cls.dependencies_cache_lock:
+            if language in cls.dependencies_cache:
+                # check expiration
+                dependencies = cls.dependencies_cache[language]
+                if dependencies['expiration'] > time.time():
+                    return dependencies['data']
+                # remove expired cache
+                del cls.dependencies_cache[language]
+
+        dependencies = cls._get_dependencies(language)
+        with cls.dependencies_cache_lock:
+            cls.dependencies_cache[language] = {
+                'data': dependencies,
+                'expiration': time.time() + 60
+            }
+
+        return dependencies
+
+    @classmethod
+    def _get_dependencies(cls, language: Literal['python3']) -> list[CodeDependency]:
+        """
+        List dependencies
+        """
+        url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'dependencies'
+
+        headers = {
+            'X-Api-Key': CODE_EXECUTION_API_KEY
+        }
+
+        running_language = cls.code_language_to_running_language.get(language)
+        if isinstance(running_language, Enum):
+            running_language = running_language.value
+
+        data = {
+            'language': running_language,
+        }
+
+        try:
+            response = get(str(url), params=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT)
+            if response.status_code != 200:
+                raise Exception(f'Failed to list dependencies, got status code {response.status_code}, please check if the sandbox service is running')
+            response = response.json()
+            dependencies = response.get('data', {}).get('dependencies', [])
+            return [
+                CodeDependency(**dependency) for dependency in dependencies if dependency.get('name') not in PYTHON_STANDARD_PACKAGES
+            ]
+        except Exception as e:
+            logger.exception(f'Failed to list dependencies: {e}')
+            return []
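Note: `list_dependencies` is a read-through cache with a 60-second TTL, checked and refilled under a lock so concurrent workers don't stampede the sandbox. A plausible call site, assuming a running dify-sandbox with `CODE_EXECUTION_ENDPOINT` and the API key configured via env:

```python
# Hypothetical usage of the new dependency listing.
from core.helper.code_executor.code_executor import CodeExecutor

deps = CodeExecutor.list_dependencies('python3')  # hits the sandbox at most once per minute
for dep in deps:
    print(dep.name, dep.version or '(unpinned)')
```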
api/core/helper/code_executor/entities.py (new file, 6 lines)
from pydantic import BaseModel


class CodeDependency(BaseModel):
    name: str
    version: str
api/core/helper/code_executor/javascript_transformer.py
@@ -1,6 +1,8 @@
 import json
 import re
+from typing import Optional
 
+from core.helper.code_executor.entities import CodeDependency
 from core.helper.code_executor.template_transformer import TemplateTransformer
 
 NODEJS_RUNNER = """// declare main function here
@@ -22,7 +24,8 @@ NODEJS_PRELOAD = """"""
 
 class NodeJsTemplateTransformer(TemplateTransformer):
     @classmethod
-    def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
+    def transform_caller(cls, code: str, inputs: dict,
+                         dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]:
         """
         Transform code to python runner
         :param code: code
@@ -37,7 +40,7 @@ class NodeJsTemplateTransformer(TemplateTransformer):
         runner = NODEJS_RUNNER.replace('{{code}}', code)
         runner = runner.replace('{{inputs}}', inputs_str)
 
-        return runner, NODEJS_PRELOAD
+        return runner, NODEJS_PRELOAD, []
 
     @classmethod
     def transform_response(cls, response: str) -> dict:
api/core/helper/code_executor/jinja2_formatter.py (new file, 17 lines)
from core.helper.code_executor.code_executor import CodeExecutor


class Jinja2Formatter:
    @classmethod
    def format(cls, template: str, inputs: str) -> str:
        """
        Format template
        :param template: template
        :param inputs: inputs
        :return:
        """
        result = CodeExecutor.execute_workflow_code_template(
            language='jinja2', code=template, inputs=inputs
        )

        return result['result']
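Note: `Jinja2Formatter` delegates template rendering to the sandbox instead of evaluating Jinja2 in-process, which keeps user-supplied templates out of the API worker. A sketch of how it might be called (template and inputs are illustrative; the `inputs: str` annotation in the new file looks loose, since the executor serializes a mapping of template variables):

```python
# Illustrative call; the template and inputs are made up.
from core.helper.code_executor.jinja2_formatter import Jinja2Formatter

rendered = Jinja2Formatter.format(
    template="Hello {{ name }}, you have {{ count }} new messages.",
    inputs={"name": "Ada", "count": 3},
)
print(rendered)  # expected: "Hello Ada, you have 3 new messages."
```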
api/core/helper/code_executor/jinja2_transformer.py
@@ -1,7 +1,10 @@
 import json
 import re
 from base64 import b64encode
+from typing import Optional
 
+from core.helper.code_executor.entities import CodeDependency
+from core.helper.code_executor.python_transformer import PYTHON_STANDARD_PACKAGES
 from core.helper.code_executor.template_transformer import TemplateTransformer
 
 PYTHON_RUNNER = """
@@ -58,7 +61,8 @@ if __name__ == '__main__':
 
 class Jinja2TemplateTransformer(TemplateTransformer):
     @classmethod
-    def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
+    def transform_caller(cls, code: str, inputs: dict,
+                         dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]:
         """
         Transform code to python runner
         :param code: code
@@ -72,7 +76,19 @@ class Jinja2TemplateTransformer(TemplateTransformer):
         runner = PYTHON_RUNNER.replace('{{code}}', code)
         runner = runner.replace('{{inputs}}', inputs_str)
 
-        return runner, JINJA2_PRELOAD
+        if not dependencies:
+            dependencies = []
+
+        # add native packages and jinja2
+        for package in PYTHON_STANDARD_PACKAGES.union(['jinja2']):
+            dependencies.append(CodeDependency(name=package, version=''))
+
+        # deduplicate
+        dependencies = list({
+            dep.name: dep for dep in dependencies if dep.name
+        }.values())
+
+        return runner, JINJA2_PRELOAD, dependencies
 
     @classmethod
     def transform_response(cls, response: str) -> dict:
api/core/helper/code_executor/python_transformer.py
@@ -1,7 +1,9 @@
 import json
 import re
 from base64 import b64encode
+from typing import Optional
 
+from core.helper.code_executor.entities import CodeDependency
 from core.helper.code_executor.template_transformer import TemplateTransformer
 
 PYTHON_RUNNER = """# declare main function here
@@ -25,32 +27,17 @@ result = f'''<<RESULT>>
 print(result)
 """
 
-PYTHON_PRELOAD = """
-# prepare general imports
-import json
-import datetime
-import math
-import random
-import re
-import string
-import sys
-import time
-import traceback
-import uuid
-import os
-import base64
-import hashlib
-import hmac
-import binascii
-import collections
-import functools
-import operator
-import itertools
-"""
+PYTHON_PRELOAD = """"""
 
+PYTHON_STANDARD_PACKAGES = set([
+    'json', 'datetime', 'math', 'random', 're', 'string', 'sys', 'time', 'traceback', 'uuid', 'os', 'base64',
+    'hashlib', 'hmac', 'binascii', 'collections', 'functools', 'operator', 'itertools', 'uuid',
+])
 
 class PythonTemplateTransformer(TemplateTransformer):
     @classmethod
-    def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
+    def transform_caller(cls, code: str, inputs: dict,
+                         dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]:
         """
         Transform code to python runner
         :param code: code
@@ -65,7 +52,18 @@ class PythonTemplateTransformer(TemplateTransformer):
         runner = PYTHON_RUNNER.replace('{{code}}', code)
         runner = runner.replace('{{inputs}}', inputs_str)
 
-        return runner, PYTHON_PRELOAD
+        # add standard packages
+        if dependencies is None:
+            dependencies = []
+
+        for package in PYTHON_STANDARD_PACKAGES:
+            if package not in dependencies:
+                dependencies.append(CodeDependency(name=package, version=''))
+
+        # deduplicate
+        dependencies = list({dep.name: dep for dep in dependencies if dep.name}.values())
+
+        return runner, PYTHON_PRELOAD, dependencies
 
     @classmethod
     def transform_response(cls, response: str) -> dict:
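Note: the dict-comprehension dedup keeps the last `CodeDependency` seen for each name, since later dict insertions overwrite earlier keys. (The `package not in dependencies` test above compares a str against `CodeDependency` objects and so never matches; the dedup pass is what actually prevents duplicates.) The idiom in isolation:

```python
# The last-one-wins dedup idiom used by the transformers above.
from pydantic import BaseModel

class CodeDependency(BaseModel):
    name: str
    version: str

deps = [
    CodeDependency(name='requests', version='2.31.0'),
    CodeDependency(name='json', version=''),
    CodeDependency(name='requests', version='2.32.0'),  # overwrites the first entry
]
deduped = list({dep.name: dep for dep in deps if dep.name}.values())
print([(d.name, d.version) for d in deduped])
# -> [('requests', '2.32.0'), ('json', '')]
```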
api/core/helper/code_executor/template_transformer.py
@@ -1,10 +1,14 @@
 from abc import ABC, abstractmethod
+from typing import Optional
+
+from core.helper.code_executor.entities import CodeDependency
 
 
 class TemplateTransformer(ABC):
     @classmethod
     @abstractmethod
-    def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
+    def transform_caller(cls, code: str, inputs: dict,
+                         dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]:
         """
         Transform code to python runner
         :param code: code
@@ -482,6 +482,82 @@ LLM_BASE_MODELS = [
             )
         )
     ),
+    AzureBaseModel(
+        base_model_name='gpt-4-turbo',
+        entity=AIModelEntity(
+            model='fake-deployment-name',
+            label=I18nObject(
+                en_US='fake-deployment-name-label',
+            ),
+            model_type=ModelType.LLM,
+            features=[
+                ModelFeature.AGENT_THOUGHT,
+                ModelFeature.VISION,
+                ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
+            ],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.MODE: LLMMode.CHAT.value,
+                ModelPropertyKey.CONTEXT_SIZE: 128000,
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name='temperature',
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
+                ),
+                ParameterRule(
+                    name='top_p',
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
+                ),
+                ParameterRule(
+                    name='presence_penalty',
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
+                ),
+                ParameterRule(
+                    name='frequency_penalty',
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
+                ),
+                _get_max_tokens(default=512, min_val=1, max_val=4096),
+                ParameterRule(
+                    name='seed',
+                    label=I18nObject(
+                        zh_Hans='种子',
+                        en_US='Seed'
+                    ),
+                    type='int',
+                    help=I18nObject(
+                        zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
+                        en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
+                    ),
+                    required=False,
+                    precision=2,
+                    min=0,
+                    max=1,
+                ),
+                ParameterRule(
+                    name='response_format',
+                    label=I18nObject(
+                        zh_Hans='回复格式',
+                        en_US='response_format'
+                    ),
+                    type='string',
+                    help=I18nObject(
+                        zh_Hans='指定模型必须输出的格式',
+                        en_US='specifying the format that the model must output'
+                    ),
+                    required=False,
+                    options=['text', 'json_object']
+                ),
+            ],
+            pricing=PriceConfig(
+                input=0.001,
+                output=0.003,
+                unit=0.001,
+                currency='USD',
+            )
+        )
+    ),
     AzureBaseModel(
         base_model_name='gpt-4-turbo-2024-04-09',
         entity=AIModelEntity(
@@ -99,6 +99,12 @@ model_credential_schema:
         show_on:
           - variable: __model_type
             value: llm
+      - label:
+          en_US: gpt-4-turbo
+        value: gpt-4-turbo
+        show_on:
+          - variable: __model_type
+            value: llm
       - label:
           en_US: gpt-4-turbo-2024-04-09
         value: gpt-4-turbo-2024-04-09
@@ -1,6 +1,5 @@
 import logging
 from collections.abc import Generator
-from os.path import join
 from typing import Optional, cast
 
 from httpx import Timeout
@@ -19,6 +18,7 @@ from openai import (
 )
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from openai.types.chat.chat_completion_message import FunctionCall
+from yarl import URL
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
@@ -265,7 +265,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
         client_kwargs = {
             "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
             "api_key": "1",
-            "base_url": join(credentials['api_base'], 'v1')
+            "base_url": str(URL(credentials['api_base']) / 'v1')
         }
 
         return client_kwargs
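Note on why the yarl swap matters: `os.path.join` is a filesystem helper, not a URL builder. It discards the base when the second segment is absolute, and on Windows it joins with backslashes. A quick comparison (behavior as I understand it; worth verifying locally):

```python
# os.path.join vs yarl URL division for building an API base URL.
from os.path import join
from yarl import URL

base = 'http://localhost:8000/api'

print(join(base, '/v1'))             # '/v1' on POSIX: the absolute segment discards the base
print(str(URL(base) / 'v1'))         # 'http://localhost:8000/api/v1'
print(str(URL(base + '/') / 'v1'))   # trailing slash handled: same result
```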
@@ -15,6 +15,7 @@ help:
 supported_model_types:
   - llm
   - text-embedding
+  - rerank
   - speech2text
 configurate_methods:
   - customizable-model
api/core/model_runtime/model_providers/localai/rerank/rerank.py (new file, 120 lines)
from json import dumps
from typing import Optional

import httpx
from requests import post
from yarl import URL

from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel


class LocalaiRerankModel(RerankModel):
    """
    LocalAI rerank model API is compatible with Jina rerank model API. So just copy the JinaRerankModel class code here.
    """

    def _invoke(self, model: str, credentials: dict,
                query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
                user: Optional[str] = None) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n documents to return
        :param user: unique user id
        :return: rerank result
        """
        if len(docs) == 0:
            return RerankResult(model=model, docs=[])

        server_url = credentials['server_url']
        model_name = model

        if not server_url:
            raise CredentialsValidateFailedError('server_url is required')
        if not model_name:
            raise CredentialsValidateFailedError('model_name is required')

        url = server_url
        headers = {
            'Authorization': f"Bearer {credentials.get('api_key')}",
            'Content-Type': 'application/json'
        }

        data = {
            "model": model_name,
            "query": query,
            "documents": docs,
            "top_n": top_n
        }

        try:
            response = post(str(URL(url) / 'rerank'), headers=headers, data=dumps(data), timeout=10)
            response.raise_for_status()
            results = response.json()

            rerank_documents = []
            for result in results['results']:
                rerank_document = RerankDocument(
                    index=result['index'],
                    text=result['document']['text'],
                    score=result['relevance_score'],
                )
                if score_threshold is None or result['relevance_score'] >= score_threshold:
                    rerank_documents.append(rerank_document)

            return RerankResult(model=model, docs=rerank_documents)
        except httpx.HTTPStatusError as e:
            raise InvokeServerUnavailableError(str(e))

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        """
        return {
            InvokeConnectionError: [httpx.ConnectError],
            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [httpx.HTTPStatusError],
            InvokeBadRequestError: [httpx.RequestError]
        }
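Note: since the endpoint mirrors the Jina rerank API, a LocalAI server can be probed with a plain HTTP call; the payload below matches the `data` dict the new model builds (the host and model name are placeholders):

```python
# Hypothetical manual probe of a LocalAI /rerank endpoint (Jina-compatible).
import json

import requests

resp = requests.post(
    "http://localhost:8080/rerank",  # placeholder server_url + /rerank
    headers={"Authorization": "Bearer <api-key>", "Content-Type": "application/json"},
    data=json.dumps({
        "model": "my-reranker",      # placeholder model name
        "query": "What is the capital of the United States?",
        "documents": [
            "Carson City is the capital of Nevada.",
            "Saipan is the capital of the CNMI.",
        ],
        "top_n": 2,
    }),
    timeout=10,
)
for item in resp.json()["results"]:
    print(item["index"], item["relevance_score"], item["document"]["text"])
```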
@@ -20,16 +20,16 @@ class MinimaxChatCompletionPro:
     Minimax Chat Completion Pro API, supports function calling
     however, we do not have enough time and energy to implement it, but the parameters are reserved
     """
     def generate(self, model: str, api_key: str, group_id: str,
                  prompt_messages: list[MinimaxMessage], model_parameters: dict,
                  tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
         generate chat completion
         """
         if not api_key or not group_id:
             raise InvalidAPIKeyError('Invalid API key or group ID')
 
         url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}'
 
         extra_kwargs = {}
@@ -42,7 +42,7 @@ class MinimaxChatCompletionPro:
         if 'top_p' in model_parameters and type(model_parameters['top_p']) == float:
             extra_kwargs['top_p'] = model_parameters['top_p']
 
         if 'plugin_web_search' in model_parameters and model_parameters['plugin_web_search']:
             extra_kwargs['plugins'] = [
                 'plugin_web_search'
@@ -61,7 +61,7 @@ class MinimaxChatCompletionPro:
         # check if there is a system message
         if len(prompt_messages) == 0:
             raise BadRequestError('At least one message is required')
 
         if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value:
             if prompt_messages[0].content:
                 bot_setting['content'] = prompt_messages[0].content
@@ -70,7 +70,7 @@ class MinimaxChatCompletionPro:
         # check if there is a user message
         if len(prompt_messages) == 0:
             raise BadRequestError('At least one user message is required')
 
         messages = [message.to_dict() for message in prompt_messages]
 
         headers = {
@@ -89,21 +89,21 @@ class MinimaxChatCompletionPro:
 
         if tools:
             body['functions'] = tools
-            body['function_call'] = { 'type': 'auto' }
+            body['function_call'] = {'type': 'auto'}
 
         try:
             response = post(
                 url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
         except Exception as e:
             raise InternalServerError(e)
 
         if response.status_code != 200:
             raise InternalServerError(response.text)
 
         if stream:
             return self._handle_stream_chat_generate_response(response)
         return self._handle_chat_generate_response(response)
 
     def _handle_error(self, code: int, msg: str):
         if code == 1000 or code == 1001 or code == 1013 or code == 1027:
             raise InternalServerError(msg)
@@ -127,7 +127,7 @@ class MinimaxChatCompletionPro:
         code = response['base_resp']['status_code']
         msg = response['base_resp']['status_msg']
         self._handle_error(code, msg)
 
         message = MinimaxMessage(
             content=response['reply'],
             role=MinimaxMessage.Role.ASSISTANT.value
@@ -144,7 +144,6 @@ class MinimaxChatCompletionPro:
         """
         handle stream chat generate response
         """
-        function_call_storage = None
         for line in response.iter_lines():
             if not line:
                 continue
@@ -158,54 +157,41 @@ class MinimaxChatCompletionPro:
                 msg = data['base_resp']['status_msg']
                 self._handle_error(code, msg)
 
             # final chunk
             if data['reply'] or 'usage' in data and data['usage']:
                 total_tokens = data['usage']['total_tokens']
-                message = MinimaxMessage(
+                minimax_message = MinimaxMessage(
                     role=MinimaxMessage.Role.ASSISTANT.value,
                     content=''
                 )
-                message.usage = {
+                minimax_message.usage = {
                     'prompt_tokens': 0,
                     'completion_tokens': total_tokens,
                     'total_tokens': total_tokens
                 }
-                message.stop_reason = data['choices'][0]['finish_reason']
+                minimax_message.stop_reason = data['choices'][0]['finish_reason']
 
-                if function_call_storage:
-                    function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
-                    function_call_message.function_call = function_call_storage
-                    yield function_call_message
+                choices = data.get('choices', [])
+                if len(choices) > 0:
+                    for choice in choices:
+                        message = choice['messages'][0]
+                        # append function_call message
+                        if 'function_call' in message:
+                            function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+                            function_call_message.function_call = message['function_call']
+                            yield function_call_message
 
-                yield message
+                yield minimax_message
                 return
 
             # partial chunk
             choices = data.get('choices', [])
             if len(choices) == 0:
                 continue
 
             for choice in choices:
                 message = choice['messages'][0]
 
-                if 'function_call' in message:
-                    if not function_call_storage:
-                        function_call_storage = message['function_call']
-                        if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
-                            function_call_storage['arguments'] = ''
-                            continue
-                    else:
-                        function_call_storage['arguments'] += message['function_call']['arguments']
-                        continue
-                else:
-                    if function_call_storage:
-                        message['function_call'] = function_call_storage
-                        function_call_storage = None
+                minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+
+                if 'function_call' in message:
+                    minimax_message.function_call = message['function_call']
 
                 # append text message
                 if 'text' in message:
-                    minimax_message = MinimaxMessage(content=message['text'], role=MinimaxMessage.Role.ASSISTANT.value)
-                    yield minimax_message
+                    minimax_message.content = message['text']
+
+                yield minimax_message
@@ -1,7 +1,11 @@
 - google/gemma-7b
 - google/codegemma-7b
+- google/recurrentgemma-2b
 - meta/llama2-70b
 - meta/llama3-8b-instruct
 - meta/llama3-70b-instruct
+- mistralai/mistral-large
 - mistralai/mixtral-8x7b-instruct-v0.1
+- mistralai/mixtral-8x22b-instruct-v0.1
 - fuyu-8b
+- snowflake/arctic
@@ -0,0 +1,36 @@ (new file)
model: snowflake/arctic
label:
  zh_Hans: snowflake/arctic
  en_US: snowflake/arctic
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4000
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 1
    default: 0.5
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 1024
    default: 1024
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
@@ -22,12 +22,16 @@ from core.model_runtime.utils import helper
 class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
     MODEL_SUFFIX_MAP = {
         'fuyu-8b': 'vlm/adept/fuyu-8b',
         'mistralai/mistral-large': '',
         'mistralai/mixtral-8x7b-instruct-v0.1': '',
+        'mistralai/mixtral-8x22b-instruct-v0.1': '',
         'google/gemma-7b': '',
         'google/codegemma-7b': '',
+        'snowflake/arctic':'',
         'meta/llama2-70b': '',
         'meta/llama3-8b-instruct': '',
-        'meta/llama3-70b-instruct': ''
+        'meta/llama3-70b-instruct': '',
+        'google/recurrentgemma-2b': ''
 
     }
@@ -0,0 +1,36 @@ (new file)
model: mistralai/mistral-large
label:
  zh_Hans: mistralai/mistral-large
  en_US: mistralai/mistral-large
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 1
    default: 0.5
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 1024
    default: 1024
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
@@ -0,0 +1,36 @@ (new file)
model: mistralai/mixtral-8x22b-instruct-v0.1
label:
  zh_Hans: mistralai/mixtral-8x22b-instruct-v0.1
  en_US: mistralai/mixtral-8x22b-instruct-v0.1
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 64000
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 1
    default: 0.5
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 1024
    default: 1024
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
@@ -0,0 +1,37 @@ (new file)
model: google/recurrentgemma-2b
label:
  zh_Hans: google/recurrentgemma-2b
  en_US: google/recurrentgemma-2b
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 2048
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 1
    default: 0.2
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 0.7
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 1024
    default: 1024
  - name: random_seed
    type: int
    help:
      en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
      zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
    label:
      en_US: Seed
      zh_Hans: 种子
    default: 0
    min: 0
    max: 2147483647
@@ -1,4 +1,6 @@
 - gpt-4
+- gpt-4o
+- gpt-4o-2024-05-13
 - gpt-4-turbo
 - gpt-4-turbo-2024-04-09
 - gpt-4-turbo-preview
@@ -0,0 +1,44 @@ (new file)
model: gpt-4o-2024-05-13
label:
  zh_Hans: gpt-4o-2024-05-13
  en_US: gpt-4o-2024-05-13
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
  - vision
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 4096
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '5.00'
  output: '15.00'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,44 @@ (new file)
model: gpt-4o
label:
  zh_Hans: gpt-4o
  en_US: gpt-4o
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
  - vision
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 4096
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '5.00'
  output: '15.00'
  unit: '0.000001'
  currency: USD
@@ -1,3 +1,9 @@
 - yi-34b-chat-0205
 - yi-34b-chat-200k
 - yi-vl-plus
+- yi-large
+- yi-medium
+- yi-vision
+- yi-medium-200k
+- yi-spark
+- yi-large-turbo
@@ -0,0 +1,43 @@ (new file)
model: yi-large-turbo
label:
  zh_Hans: yi-large-turbo
  en_US: yi-large-turbo
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 16384
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。
      en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 1024
    min: 1
    max: 16384
    help:
      zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
      en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.9
    min: 0.01
    max: 1.00
    help:
      zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
      en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature.
pricing:
  input: '12'
  output: '12'
  unit: '0.000001'
  currency: RMB
api/core/model_runtime/model_providers/yi/llm/yi-large.yaml (new file, 43 lines)
model: yi-large
label:
  zh_Hans: yi-large
  en_US: yi-large
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 16384
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。
      en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 1024
    min: 1
    max: 16384
    help:
      zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
      en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.9
    min: 0.01
    max: 1.00
    help:
      zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
      en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature.
pricing:
  input: '20'
  output: '20'
  unit: '0.000001'
  currency: RMB
@@ -0,0 +1,43 @@
model: yi-medium-200k
label:
  zh_Hans: yi-medium-200k
  en_US: yi-medium-200k
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 204800
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。
      en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 1024
    min: 1
    max: 204800
    help:
      zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
      en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.9
    min: 0.01
    max: 1.00
    help:
      zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
      en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature.
pricing:
  input: '12'
  output: '12'
  unit: '0.000001'
  currency: RMB
api/core/model_runtime/model_providers/yi/llm/yi-medium.yaml (new file, 43 lines)
@@ -0,0 +1,43 @@
model: yi-medium
label:
  zh_Hans: yi-medium
  en_US: yi-medium
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 16384
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。
      en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 1024
    min: 1
    max: 16384
    help:
      zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
      en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.9
    min: 0.01
    max: 1.00
    help:
      zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
      en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature.
pricing:
  input: '2.5'
  output: '2.5'
  unit: '0.000001'
  currency: RMB
api/core/model_runtime/model_providers/yi/llm/yi-spark.yaml (new file, 43 lines)
@@ -0,0 +1,43 @@
model: yi-spark
label:
  zh_Hans: yi-spark
  en_US: yi-spark
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 16384
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。
      en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 1024
    min: 1
    max: 16384
    help:
      zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
      en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.9
    min: 0.01
    max: 1.00
    help:
      zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
      en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature.
pricing:
  input: '1'
  output: '1'
  unit: '0.000001'
  currency: RMB
api/core/model_runtime/model_providers/yi/llm/yi-vision.yaml (new file, 44 lines)
@@ -0,0 +1,44 @@
model: yi-vision
label:
  zh_Hans: yi-vision
  en_US: yi-vision
model_type: llm
features:
  - agent-thought
  - vision
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。
      en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 1024
    min: 1
    max: 4096
    help:
      zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
      en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.9
    min: 0.01
    max: 1.00
    help:
      zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
      en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature.
pricing:
  input: '6'
  output: '6'
  unit: '0.000001'
  currency: RMB
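Each rule above carries `default`/`min`/`max` bounds that the runtime enforces before calling the provider. A minimal sketch of how such a rule might be applied to a user-supplied value; the `apply_rule` helper is illustrative, not a Dify API:

```python
# Hypothetical sketch: clamping a request parameter against a YAML parameter rule.
# The rule dict mirrors the max_tokens entry above; names here are illustrative.
def apply_rule(value, rule: dict):
    if value is None:
        return rule.get("default")
    lo, hi = rule.get("min"), rule.get("max")
    if lo is not None:
        value = max(lo, value)
    if hi is not None:
        value = min(hi, value)
    return value

max_tokens_rule = {"name": "max_tokens", "default": 1024, "min": 1, "max": 16384}
assert apply_rule(None, max_tokens_rule) == 1024
assert apply_rule(999999, max_tokens_rule) == 16384
```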
@@ -2,6 +2,7 @@ from typing import Optional, Union

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file.file_obj import FileVar
from core.helper.code_executor.jinja2_formatter import Jinja2Formatter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,

@@ -80,29 +81,35 @@ class AdvancedPromptTransform(PromptTransform):

        prompt_messages = []

        prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
        if prompt_template.edition_type == 'basic' or not prompt_template.edition_type:
            prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

        prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)

        if memory and memory_config:
            role_prefix = memory_config.role_prefix
            prompt_inputs = self._set_histories_variable(
                memory=memory,
                memory_config=memory_config,
                raw_prompt=raw_prompt,
                role_prefix=role_prefix,
                prompt_template=prompt_template,
                prompt_inputs=prompt_inputs,
                model_config=model_config
            if memory and memory_config:
                role_prefix = memory_config.role_prefix
                prompt_inputs = self._set_histories_variable(
                    memory=memory,
                    memory_config=memory_config,
                    raw_prompt=raw_prompt,
                    role_prefix=role_prefix,
                    prompt_template=prompt_template,
                    prompt_inputs=prompt_inputs,
                    model_config=model_config
            )

        if query:
            prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)

        prompt = prompt_template.format(
            prompt_inputs
        )
        else:
            prompt = raw_prompt
            prompt_inputs = inputs

            if query:
                prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)

            prompt = prompt_template.format(
                prompt_inputs
            )
            prompt = Jinja2Formatter.format(prompt, prompt_inputs)

        if files:
            prompt_message_contents = [TextPromptMessageContent(data=prompt)]

@@ -135,14 +142,22 @@ class AdvancedPromptTransform(PromptTransform):
        for prompt_item in raw_prompt_list:
            raw_prompt = prompt_item.text

            prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
            if prompt_item.edition_type == 'basic' or not prompt_item.edition_type:
                prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
                prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)

            prompt = prompt_template.format(
                prompt_inputs
            )
                prompt = prompt_template.format(
                    prompt_inputs
                )
            elif prompt_item.edition_type == 'jinja2':
                prompt = raw_prompt
                prompt_inputs = inputs

                prompt = Jinja2Formatter.format(prompt, prompt_inputs)
            else:
                raise ValueError(f'Invalid edition type: {prompt_item.edition_type}')

            if prompt_item.role == PromptMessageRole.USER:
                prompt_messages.append(UserPromptMessage(content=prompt))
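The new branching keeps existing `{{#var#}}` templates on the 'basic' path (resolved by `PromptTemplateParser.format`) and routes 'jinja2' prompts through `Jinja2Formatter`. A minimal sketch of the difference, standalone and using the jinja2 package directly rather than Dify's formatter:

```python
from jinja2 import Template

inputs = {'name': 'Dify'}

# 'basic' edition: Dify's own {{#...#}} placeholder syntax, simple substitution.
basic_prompt = 'Hello, {{#name#}}!'

# 'jinja2' edition: full Jinja2 semantics (filters, loops, conditionals).
rendered = Template('Hello, {{ name | upper }}!').render(inputs)
assert rendered == 'Hello, DIFY!'
```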
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Literal, Optional

from pydantic import BaseModel

@@ -11,6 +11,7 @@ class ChatModelMessage(BaseModel):
    """
    text: str
    role: PromptMessageRole
    edition_type: Optional[Literal['basic', 'jinja2']]


class CompletionModelPromptTemplate(BaseModel):

@@ -18,6 +19,7 @@ class CompletionModelPromptTemplate(BaseModel):
    Completion Model Prompt Template.
    """
    text: str
    edition_type: Optional[Literal['basic', 'jinja2']]


class MemoryConfig(BaseModel):
api/core/rag/datasource/vdb/pgvector/__init__.py (new file, 0 lines)
api/core/rag/datasource/vdb/pgvector/pgvector.py (new file, 169 lines)
@@ -0,0 +1,169 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any

import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, root_validator

from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client


class PGVectorConfig(BaseModel):
    host: str
    port: int
    user: str
    password: str
    database: str

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        if not values["host"]:
            raise ValueError("config PGVECTOR_HOST is required")
        if not values["port"]:
            raise ValueError("config PGVECTOR_PORT is required")
        if not values["user"]:
            raise ValueError("config PGVECTOR_USER is required")
        if not values["password"]:
            raise ValueError("config PGVECTOR_PASSWORD is required")
        if not values["database"]:
            raise ValueError("config PGVECTOR_DATABASE is required")
        return values


SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
    id UUID PRIMARY KEY,
    text TEXT NOT NULL,
    meta JSONB NOT NULL,
    embedding vector({dimension}) NOT NULL
) using heap;
"""


class PGVector(BaseVector):
    def __init__(self, collection_name: str, config: PGVectorConfig):
        super().__init__(collection_name)
        self.pool = self._create_connection_pool(config)
        self.table_name = f"embedding_{collection_name}"

    def get_type(self) -> str:
        return "pgvector"

    def _create_connection_pool(self, config: PGVectorConfig):
        return psycopg2.pool.SimpleConnectionPool(
            1,
            5,
            host=config.host,
            port=config.port,
            user=config.user,
            password=config.password,
            database=config.database,
        )

    @contextmanager
    def _get_cursor(self):
        conn = self.pool.getconn()
        cur = conn.cursor()
        try:
            yield cur
        finally:
            cur.close()
            conn.commit()
            self.pool.putconn(conn)

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        dimension = len(embeddings[0])
        self._create_collection(dimension)
        return self.add_texts(texts, embeddings)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        values = []
        pks = []
        for i, doc in enumerate(documents):
            doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
            pks.append(doc_id)
            values.append(
                (
                    doc_id,
                    doc.page_content,
                    json.dumps(doc.metadata),
                    embeddings[i],
                )
            )
        with self._get_cursor() as cur:
            psycopg2.extras.execute_values(
                cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
            )
        return pks

    def text_exists(self, id: str) -> bool:
        with self._get_cursor() as cur:
            cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
            return cur.fetchone() is not None

    def get_by_ids(self, ids: list[str]) -> list[Document]:
        with self._get_cursor() as cur:
            cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
            docs = []
            for record in cur:
                docs.append(Document(page_content=record[1], metadata=record[0]))
        return docs

    def delete_by_ids(self, ids: list[str]) -> None:
        with self._get_cursor() as cur:
            cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        with self._get_cursor() as cur:
            cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """
        Search the nearest neighbors to a vector.

        :param query_vector: The input vector to search for similar items.
        :param top_k: The number of nearest neighbors to return, default is 5.
        :return: List of Documents that are nearest to the query vector.
        """
        top_k = kwargs.get("top_k", 5)

        with self._get_cursor() as cur:
            cur.execute(
                f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name} ORDER BY distance LIMIT {top_k}",
                (json.dumps(query_vector),),
            )
            docs = []
            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
            for record in cur:
                metadata, text, distance = record
                score = 1 - distance
                metadata["score"] = score
                if score > score_threshold:
                    docs.append(Document(page_content=text, metadata=metadata))
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # do not support bm25 search
        return []

    def delete(self) -> None:
        with self._get_cursor() as cur:
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

    def _create_collection(self, dimension: int):
        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return

            with self._get_cursor() as cur:
                cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
                cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
                # TODO: create index https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
            redis_client.set(collection_exist_cache_key, 1, ex=3600)
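A minimal usage sketch of the new store, assuming a running pgvector-enabled Postgres and the app's Redis (the connection values mirror the integration test later in this diff; the 3-dimensional embedding is a toy):

```python
# Hedged sketch: wiring up the new PGVector store directly.
from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
from core.rag.models.document import Document

store = PGVector(
    collection_name="demo",
    config=PGVectorConfig(host="localhost", port=5433, user="postgres",
                          password="difyai123456", database="dify"),
)
doc = Document(page_content="hello pgvector", metadata={"doc_id": "doc-1"})
store.create(texts=[doc], embeddings=[[0.1, 0.2, 0.3]])   # creates the table, inserts the row
hits = store.search_by_vector([0.1, 0.2, 0.3], top_k=1)   # cosine distance via the <=> operator
```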
@@ -164,6 +164,29 @@ class Vector:
                ),
                dim=dim
            )
        elif vector_type == "pgvector":
            from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig

            if self._dataset.index_struct_dict:
                class_prefix: str = self._dataset.index_struct_dict["vector_store"]["class_prefix"]
                collection_name = class_prefix
            else:
                dataset_id = self._dataset.id
                collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                index_struct_dict = {
                    "type": "pgvector",
                    "vector_store": {"class_prefix": collection_name}}
                self._dataset.index_struct = json.dumps(index_struct_dict)
            return PGVector(
                collection_name=collection_name,
                config=PGVectorConfig(
                    host=config.get("PGVECTOR_HOST"),
                    port=config.get("PGVECTOR_PORT"),
                    user=config.get("PGVECTOR_USER"),
                    password=config.get("PGVECTOR_PASSWORD"),
                    database=config.get("PGVECTOR_DATABASE"),
                ),
            )
        else:
            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
@@ -1,8 +1,8 @@
from base64 import b64decode
from os.path import join
from typing import Any, Union

from openai import OpenAI
from yarl import URL

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

@@ -23,7 +23,7 @@ class DallE2Tool(BuiltinTool):
        if not openai_base_url:
            openai_base_url = None
        else:
            openai_base_url = join(openai_base_url, 'v1')
            openai_base_url = str(URL(openai_base_url) / 'v1')

        client = OpenAI(
            api_key=self.runtime.credentials['openai_api_key'],
@@ -1,8 +1,8 @@
from base64 import b64decode
from os.path import join
from typing import Any, Union

from openai import OpenAI
from yarl import URL

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

@@ -23,7 +23,7 @@ class DallE3Tool(BuiltinTool):
        if not openai_base_url:
            openai_base_url = None
        else:
            openai_base_url = join(openai_base_url, 'v1')
            openai_base_url = str(URL(openai_base_url) / 'v1')

        client = OpenAI(
            api_key=self.runtime.credentials['openai_api_key'],
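The switch from `os.path.join` to `yarl` matters because `os.path.join` builds filesystem paths, not URLs, so the separator depends on the platform, while `URL / segment` always normalizes the slash. A quick illustration, consistent with the `test_yarl.py` cases added later in this diff:

```python
from os.path import join
from yarl import URL

# os.path.join treats the URL as a path; on Windows (ntpath) a missing
# trailing slash introduces a backslash into the URL:
join('https://dify.ai', 'v1')        # 'https://dify.ai\\v1' under ntpath
# yarl joins URL segments and normalizes the slash either way:
str(URL('https://dify.ai') / 'v1')   # 'https://dify.ai/v1'
str(URL('https://dify.ai/') / 'v1')  # 'https://dify.ai/v1'
```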
@@ -2,6 +2,7 @@ import os
from typing import Optional, Union, cast

from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode

@@ -61,7 +62,8 @@ class CodeNode(BaseNode):
                        "children": None
                    }
                }
            }
            },
            "available_dependencies": []
        }

        return {

@@ -84,8 +86,11 @@ class CodeNode(BaseNode):
                        "type": "string",
                        "children": None
                    }
                }
            }
            },
            "dependencies": [
            ]
        },
        "available_dependencies": jsonable_encoder(CodeExecutor.list_dependencies('python3'))
    }

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:

@@ -115,7 +120,8 @@ class CodeNode(BaseNode):
        result = CodeExecutor.execute_workflow_code_template(
            language=code_language,
            code=code,
            inputs=variables
            inputs=variables,
            dependencies=node_data.dependencies
        )

        # Transform result
@@ -2,6 +2,7 @@ from typing import Literal, Optional

from pydantic import BaseModel

from core.helper.code_executor.entities import CodeDependency
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector

@@ -17,4 +18,5 @@ class CodeNodeData(BaseNodeData):
    variables: list[VariableSelector]
    code_language: Literal['python3', 'javascript']
    code: str
    outputs: dict[str, Output]
    outputs: dict[str, Output]
    dependencies: Optional[list[CodeDependency]] = None
@@ -236,7 +236,7 @@ class HttpExecutor:
        for kv in kv_paris:
            if not kv.strip():
                continue
            kv = kv.split(':')
            kv = kv.split(':', 1)
            if len(kv) == 2:
                body[kv[0].strip()] = kv[1]
            elif len(kv) == 1:
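Limiting the split to the first colon keeps values that themselves contain colons intact, which matters for key-value bodies whose values are URLs or tokens:

```python
kv = 'endpoint:https://dify.ai:443/api'
kv.split(':')     # ['endpoint', 'https', '//dify.ai', '443/api'] -- value destroyed
kv.split(':', 1)  # ['endpoint', 'https://dify.ai:443/api']       -- key and full value
```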
@@ -4,6 +4,7 @@ from pydantic import BaseModel

from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector


class ModelConfig(BaseModel):

@@ -37,13 +38,31 @@ class VisionConfig(BaseModel):
    enabled: bool
    configs: Optional[Configs] = None


class PromptConfig(BaseModel):
    """
    Prompt Config.
    """
    jinja2_variables: Optional[list[VariableSelector]] = None


class LLMNodeChatModelMessage(ChatModelMessage):
    """
    LLM Node Chat Model Message.
    """
    jinja2_text: Optional[str] = None


class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
    """
    LLM Node Chat Model Prompt Template.
    """
    jinja2_text: Optional[str] = None


class LLMNodeData(BaseNodeData):
    """
    LLM Node Data.
    """
    model: ModelConfig
    prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
    prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
    prompt_config: Optional[PromptConfig] = None
    memory: Optional[MemoryConfig] = None
    context: ContextConfig
    vision: VisionConfig
@@ -1,4 +1,6 @@
import json
from collections.abc import Generator
from copy import deepcopy
from typing import Optional, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity

@@ -17,11 +19,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
from core.workflow.nodes.llm.entities import (
    LLMNodeChatModelMessage,
    LLMNodeCompletionModelPromptTemplate,
    LLMNodeData,
    ModelConfig,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation

@@ -39,16 +45,24 @@ class LLMNode(BaseNode):
        :param variable_pool: variable pool
        :return:
        """
        node_data = self.node_data
        node_data = cast(self._node_data_cls, node_data)
        node_data = cast(LLMNodeData, deepcopy(self.node_data))

        node_inputs = None
        process_data = None

        try:
            # init messages template
            node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)

            # fetch variables and fetch values from variable pool
            inputs = self._fetch_inputs(node_data, variable_pool)

            # fetch jinja2 inputs
            jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)

            # merge inputs
            inputs.update(jinja_inputs)

            node_inputs = {}

            # fetch files

@@ -183,6 +197,86 @@ class LLMNode(BaseNode):
            usage = LLMUsage.empty_usage()

        return full_text, usage

    def _transform_chat_messages(self,
                                 messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
                                 ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
        """
        Transform chat messages

        :param messages: chat messages
        :return:
        """

        if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
            if messages.edition_type == 'jinja2':
                messages.text = messages.jinja2_text

            return messages

        for message in messages:
            if message.edition_type == 'jinja2':
                message.text = message.jinja2_text

        return messages

    def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
        """
        Fetch jinja inputs
        :param node_data: node data
        :param variable_pool: variable pool
        :return:
        """
        variables = {}

        if not node_data.prompt_config:
            return variables

        for variable_selector in node_data.prompt_config.jinja2_variables or []:
            variable = variable_selector.variable
            value = variable_pool.get_variable_value(
                variable_selector=variable_selector.value_selector
            )

            def parse_dict(d: dict) -> str:
                """
                Parse dict into string
                """
                # check if it's a context structure
                if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
                    return d['content']

                # else, parse the dict
                try:
                    return json.dumps(d, ensure_ascii=False)
                except Exception:
                    return str(d)

            if isinstance(value, str):
                value = value
            elif isinstance(value, list):
                result = ''
                for item in value:
                    if isinstance(item, dict):
                        result += parse_dict(item)
                    elif isinstance(item, str):
                        result += item
                    elif isinstance(item, int | float):
                        result += str(item)
                    else:
                        result += str(item)
                    result += '\n'
                value = result.strip()
            elif isinstance(value, dict):
                value = parse_dict(value)
            elif isinstance(value, int | float):
                value = str(value)
            else:
                value = str(value)

            variables[variable] = value

        return variables

    def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
        """

@@ -531,25 +625,25 @@ class LLMNode(BaseNode):
        db.session.commit()

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
    def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return:
        """
        node_data = node_data
        node_data = cast(cls._node_data_cls, node_data)

        prompt_template = node_data.prompt_template

        variable_selectors = []
        if isinstance(prompt_template, list):
            for prompt in prompt_template:
                variable_template_parser = VariableTemplateParser(template=prompt.text)
                variable_selectors.extend(variable_template_parser.extract_variable_selectors())
                if prompt.edition_type != 'jinja2':
                    variable_template_parser = VariableTemplateParser(template=prompt.text)
                    variable_selectors.extend(variable_template_parser.extract_variable_selectors())
        else:
            variable_template_parser = VariableTemplateParser(template=prompt_template.text)
            variable_selectors = variable_template_parser.extract_variable_selectors()
            if prompt_template.edition_type != 'jinja2':
                variable_template_parser = VariableTemplateParser(template=prompt_template.text)
                variable_selectors = variable_template_parser.extract_variable_selectors()

        variable_mapping = {}
        for variable_selector in variable_selectors:

@@ -571,6 +665,22 @@ class LLMNode(BaseNode):
        if node_data.memory:
            variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]

        if node_data.prompt_config:
            enable_jinja = False

            if isinstance(prompt_template, list):
                for prompt in prompt_template:
                    if prompt.edition_type == 'jinja2':
                        enable_jinja = True
                        break
            else:
                if prompt_template.edition_type == 'jinja2':
                    enable_jinja = True

            if enable_jinja:
                for variable_selector in node_data.prompt_config.jinja2_variables or []:
                    variable_mapping[variable_selector.variable] = variable_selector.value_selector

        return variable_mapping

    @classmethod

@@ -588,7 +698,8 @@ class LLMNode(BaseNode):
                "prompts": [
                    {
                        "role": "system",
                        "text": "You are a helpful AI assistant."
                        "text": "You are a helpful AI assistant.",
                        "edition_type": "basic"
                    }
                ]
            },

@@ -600,7 +711,8 @@ class LLMNode(BaseNode):
                "prompt": {
                    "text": "Here is the chat histories between human and assistant, inside "
                            "<histories></histories> XML tags.\n\n<histories>\n{{"
                            "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:"
                            "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
                    "edition_type": "basic"
                },
                "stop": ["Human:"]
            }
@@ -259,7 +259,7 @@ class QuestionClassifierNode(LLMNode):
            user_prompt_message_3 = ChatModelMessage(
                role=PromptMessageRole.USER,
                text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text,
                                                              categories=json.dumps(categories),
                                                              categories=json.dumps(categories, ensure_ascii=False),
                                                              classification_instructions=instruction)
            )
            prompt_messages.append(user_prompt_message_3)

@@ -269,7 +269,7 @@ class QuestionClassifierNode(LLMNode):
                text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str,
                                                                  input_text=input_text,
                                                                  categories=json.dumps(categories),
                                                                  classification_instructions=instruction)
                                                                  classification_instructions=instruction, ensure_ascii=False)
            )

        else:
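Passing `ensure_ascii=False` keeps non-ASCII category names readable in the prompt instead of being escaped to `\uXXXX` sequences:

```python
import json

categories = [{"category_id": "1", "category_name": "客户服务"}]
json.dumps(categories)
# '[{"category_id": "1", "category_name": "\\u5ba2\\u6237\\u670d\\u52a1"}]'
json.dumps(categories, ensure_ascii=False)
# '[{"category_id": "1", "category_name": "客户服务"}]'
```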
@ -6,7 +6,7 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
|
||||
### Task
|
||||
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output.Additionally, you need to extract the key words from the text that are related to the classification.
|
||||
### Format
|
||||
The input text is in the variable text_field.Categories are specified as a category list in the variable categories or left empty for automatic determination.Classification instructions may be included to improve the classification accuracy.
|
||||
The input text is in the variable text_field.Categories are specified as a category list with two filed category_id and category_name in the variable categories .Classification instructions may be included to improve the classification accuracy.
|
||||
### Constraint
|
||||
DO NOT include anything other than the JSON array in your response.
|
||||
### Memory
|
||||
@ -24,7 +24,8 @@ QUESTION_CLASSIFIER_USER_PROMPT_1 = """
|
||||
|
||||
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """
|
||||
```json
|
||||
{"category_id": "f5660049-284f-41a7-b301-fd24176a711c",
|
||||
{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],
|
||||
"category_id": "f5660049-284f-41a7-b301-fd24176a711c",
|
||||
"category_name": "Customer Service"}
|
||||
```
|
||||
"""
|
||||
@ -37,7 +38,8 @@ QUESTION_CLASSIFIER_USER_PROMPT_2 = """
|
||||
|
||||
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """
|
||||
```json
|
||||
{"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f",
|
||||
{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],
|
||||
"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f",
|
||||
"category_name": "Experience"}
|
||||
```
|
||||
"""
|
||||
@ -54,16 +56,16 @@ You are a text classification engine that analyzes text data and assigns categor
|
||||
### Task
|
||||
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification.
|
||||
### Format
|
||||
The input text is in the variable text_field. Categories are specified as a category list in the variable categories or left empty for automatic determination. Classification instructions may be included to improve the classification accuracy.
|
||||
The input text is in the variable text_field. Categories are specified as a category list with two filed category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy.
|
||||
### Constraint
|
||||
DO NOT include anything other than the JSON array in your response.
|
||||
### Example
|
||||
Here is the chat example between human and assistant, inside <example></example> XML tags.
|
||||
<example>
|
||||
User:{{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], "categories": [{{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"}},{{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"}},{{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"}},{{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}}], "classification_instructions": ["classify the text based on the feedback provided by customer"]}}
|
||||
Assistant:{{"category_id": "f5660049-284f-41a7-b301-fd24176a711c","category_name": "Customer Service"}}
|
||||
Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"category_id": "f5660049-284f-41a7-b301-fd24176a711c","category_name": "Customer Service"}}
|
||||
User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"}},{{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"}},{{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}}], "classification_instructions": []}}
|
||||
Assistant:{{"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Customer Service"}}
|
||||
Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Customer Service"}}
|
||||
</example>
|
||||
### Memory
|
||||
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
|
||||
|
||||
@@ -148,9 +148,9 @@ class WorkflowEngineManager:

            has_entry_node = True

            # max steps 30 reached
            if len(workflow_run_state.workflow_nodes_and_results) > 30:
                raise ValueError('Max steps 30 reached.')
            # max steps 50 reached
            if len(workflow_run_state.workflow_nodes_and_results) > 50:
                raise ValueError('Max steps 50 reached.')

            # or max execution time 10min reached
            if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=600):
@@ -7,6 +7,7 @@ workflow_fields = {
    'id': fields.String,
    'graph': fields.Raw(attribute='graph_dict'),
    'features': fields.Raw(attribute='features_dict'),
    'hash': fields.String(attribute='unique_hash'),
    'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'),
    'created_at': TimestampField,
    'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True),
@@ -48,7 +48,7 @@ class PKCS1OAEP_Cipher:
                `Crypto.Hash.SHA1` is used.
        mgfunc : callable
                A mask generation function that accepts two parameters: a string to
                use as seed, and the lenth of the mask to generate, in bytes.
                use as seed, and the length of the mask to generate, in bytes.
                If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice).
        label : bytes/bytearray/memoryview
                A label to apply to this particular encryption. If not specified,

@@ -218,7 +218,7 @@ def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None):

    :param mgfunc:
        A mask generation function that accepts two parameters: a string to
        use as seed, and the lenth of the mask to generate, in bytes.
        use as seed, and the length of the mask to generate, in bytes.
        If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice).
    :type mgfunc: callable
@@ -0,0 +1,39 @@
"""modify default model name length

Revision ID: 47cc7df8c4f3
Revises: 3c7cac9521c6
Create Date: 2024-05-10 09:48:09.046298

"""
import sqlalchemy as sa
from alembic import op

import models as models

# revision identifiers, used by Alembic.
revision = '47cc7df8c4f3'
down_revision = '3c7cac9521c6'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('tenant_default_models', schema=None) as batch_op:
        batch_op.alter_column('model_name',
               existing_type=sa.VARCHAR(length=40),
               type_=sa.String(length=255),
               existing_nullable=False)

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('tenant_default_models', schema=None) as batch_op:
        batch_op.alter_column('model_name',
               existing_type=sa.String(length=255),
               type_=sa.VARCHAR(length=40),
               existing_nullable=False)

    # ### end Alembic commands ###
@@ -113,7 +113,7 @@ class TenantDefaultModel(db.Model):
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(255), nullable=False)
    model_type = db.Column(db.String(40), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -4,6 +4,7 @@ from typing import Optional, Union

from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from libs import helper
from models import StringUUID
from models.account import Account

@@ -156,6 +157,21 @@ class Workflow(db.Model):

        return variables

    @property
    def unique_hash(self) -> str:
        """
        Get hash of workflow.

        :return: hash
        """
        entity = {
            'graph': self.graph_dict,
            'features': self.features_dict
        }

        return helper.generate_text_hash(json.dumps(entity, sort_keys=True))


class WorkflowRunTriggeredFrom(Enum):
    """
    Workflow Run Triggered From Enum
@@ -3,3 +3,4 @@ pytest~=8.1.1
pytest-benchmark~=4.0.0
pytest-env~=1.1.3
pytest-mock~=3.14.0
jinja2~=3.1.2
@@ -9,8 +9,8 @@ flask-restful~=0.3.10
flask-cors~=4.0.0
gunicorn~=22.0.0
gevent~=23.9.1
openai~=1.26.0
tiktoken~=0.6.0
openai~=1.29.0
tiktoken~=0.7.0
psycopg2-binary~=2.9.6
pycryptodome==3.19.1
python-dotenv==1.0.0

@@ -83,3 +83,4 @@ pydantic~=1.10.0
pgvecto-rs==0.1.4
firecrawl-py==0.0.5
oss2==2.15.0
pgvector==0.2.5
@@ -196,6 +196,7 @@ class AppService:
            app_model=app,
            graph=workflow.get('graph'),
            features=workflow.get('features'),
            unique_hash=None,
            account=account
        )
        workflow_service.publish_workflow(
@@ -1,2 +1,6 @@
class MoreLikeThisDisabledError(Exception):
    pass


class WorkflowHashNotEqualError(Exception):
    pass
@@ -21,6 +21,7 @@ from models.workflow import (
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowType,
)
from services.errors.app import WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter


@@ -63,13 +64,20 @@ class WorkflowService:
    def sync_draft_workflow(self, app_model: App,
                            graph: dict,
                            features: dict,
                            unique_hash: Optional[str],
                            account: Account) -> Workflow:
        """
        Sync draft workflow
        @throws WorkflowHashNotEqualError
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(app_model=app_model)

        if workflow:
            # validate unique hash
            if workflow.unique_hash != unique_hash:
                raise WorkflowHashNotEqualError()

        # validate features structure
        self.validate_features_structure(
            app_model=app_model,
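Together with the `unique_hash` property and the new `hash` field in `workflow_fields`, this gives draft saves optimistic concurrency: the client echoes back the hash of the draft it loaded, and the save is rejected when someone else has synced a newer draft in the meantime. A rough client-side sketch of that loop; the fetch/save helpers and their routes are illustrative, not taken from this diff:

```python
# Hypothetical client-side sketch of the optimistic-lock flow enabled by
# unique_hash; the endpoints and error mapping are assumptions for illustration.
def save_draft(client, app_id: str, graph: dict, features: dict):
    wf = client.get(f"/apps/{app_id}/workflows/draft")   # response carries 'hash'
    resp = client.post(f"/apps/{app_id}/workflows/draft", json={
        "graph": graph,
        "features": features,
        "hash": wf["hash"],   # echoed back as unique_hash on the server side
    })
    if resp.status_code == 400:  # WorkflowHashNotEqualError surfaced by the API
        raise RuntimeError("draft changed since it was loaded; reload and retry")
    return resp.json()
```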
api/tests/integration_tests/model_runtime/localai/test_rerank.py (new file, 158 lines)
@@ -0,0 +1,158 @@
import os

import pytest
from api.core.model_runtime.entities.rerank_entities import RerankResult

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.localai.rerank.rerank import LocalaiRerankModel


def test_validate_credentials_for_chat_model():
    model = LocalaiRerankModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='bge-reranker-v2-m3',
            credentials={
                'server_url': 'hahahaha',
                'completion_type': 'completion',
            }
        )

    model.validate_credentials(
        model='bge-reranker-base',
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
            'completion_type': 'completion',
        }
    )

def test_invoke_rerank_model():
    model = LocalaiRerankModel()

    response = model.invoke(
        model='bge-reranker-base',
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL')
        },
        query='Organic skincare products for sensitive skin',
        docs=[
            "Eco-friendly kitchenware for modern homes",
            "Biodegradable cleaning supplies for eco-conscious consumers",
            "Organic cotton baby clothes for sensitive skin",
            "Natural organic skincare range for sensitive skin",
            "Tech gadgets for smart homes: 2024 edition",
            "Sustainable gardening tools and compost solutions",
            "Sensitive skin-friendly facial cleansers and toners",
            "Organic food wraps and storage solutions",
            "Yoga mats made from recycled materials"
        ],
        top_n=3,
        score_threshold=0.75,
        user="abc-123"
    )

    assert isinstance(response, RerankResult)
    assert len(response.docs) == 3
import os

import pytest
from api.core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.localai.rerank.rerank import LocalaiRerankModel


def test_validate_credentials_for_chat_model():
    model = LocalaiRerankModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='bge-reranker-v2-m3',
            credentials={
                'server_url': 'hahahaha',
                'completion_type': 'completion',
            }
        )

    model.validate_credentials(
        model='bge-reranker-base',
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
            'completion_type': 'completion',
        }
    )

def test_invoke_rerank_model():
    model = LocalaiRerankModel()

    response = model.invoke(
        model='bge-reranker-base',
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL')
        },
        query='Organic skincare products for sensitive skin',
        docs=[
            "Eco-friendly kitchenware for modern homes",
            "Biodegradable cleaning supplies for eco-conscious consumers",
            "Organic cotton baby clothes for sensitive skin",
            "Natural organic skincare range for sensitive skin",
            "Tech gadgets for smart homes: 2024 edition",
            "Sustainable gardening tools and compost solutions",
            "Sensitive skin-friendly facial cleansers and toners",
            "Organic food wraps and storage solutions",
            "Yoga mats made from recycled materials"
        ],
        top_n=3,
        score_threshold=0.75,
        user="abc-123"
    )

    assert isinstance(response, RerankResult)
    assert len(response.docs) == 3

def test__invoke():
    model = LocalaiRerankModel()

    # Test case 1: Empty docs
    result = model._invoke(
        model='bge-reranker-base',
        credentials={
            'server_url': 'https://example.com',
            'api_key': '1234567890'
        },
        query='Organic skincare products for sensitive skin',
        docs=[],
        top_n=3,
        score_threshold=0.75,
        user="abc-123"
    )
    assert isinstance(result, RerankResult)
    assert len(result.docs) == 0

    # Test case 2: Valid invocation
    result = model._invoke(
        model='bge-reranker-base',
        credentials={
            'server_url': 'https://example.com',
            'api_key': '1234567890'
        },
        query='Organic skincare products for sensitive skin',
        docs=[
            "Eco-friendly kitchenware for modern homes",
            "Biodegradable cleaning supplies for eco-conscious consumers",
            "Organic cotton baby clothes for sensitive skin",
            "Natural organic skincare range for sensitive skin",
            "Tech gadgets for smart homes: 2024 edition",
            "Sustainable gardening tools and compost solutions",
            "Sensitive skin-friendly facial cleansers and toners",
            "Organic food wraps and storage solutions",
            "Yoga mats made from recycled materials"
        ],
        top_n=3,
        score_threshold=0.75,
        user="abc-123"
    )
    assert isinstance(result, RerankResult)
    assert len(result.docs) == 3
    assert all(isinstance(doc, RerankDocument) for doc in result.docs)
api/tests/integration_tests/vdb/pgvector/test_pgvector.py (new file, 30 lines)
@@ -0,0 +1,30 @@
from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    get_example_text,
    setup_mock_redis,
)


class TestPGVector(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = PGVector(
            collection_name=self.collection_name,
            config=PGVectorConfig(
                host="localhost",
                port=5433,
                user="postgres",
                password="difyai123456",
                database="dify",
            ),
        )

    def search_by_full_text(self):
        hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 0


def test_pgvector(setup_mock_redis):
    TestPGVector().run_all_tests()
@@ -1,16 +1,19 @@
import os
from typing import Literal
from typing import Literal, Optional

import pytest
from _pytest.monkeypatch import MonkeyPatch
from jinja2 import Template

from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.entities import CodeDependency

MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'

class MockedCodeExecutor:
    @classmethod
    def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict:
    def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'],
               code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
        # invoke directly
        if language == 'python3':
            return {

@@ -18,7 +21,7 @@ class MockedCodeExecutor:
            }
        elif language == 'jinja2':
            return {
                "result": "3"
                "result": Template(code).render(inputs)
            }

    @pytest.fixture
@@ -1,3 +1,4 @@
import json
import os
from unittest.mock import MagicMock

@@ -19,6 +20,7 @@ from models.workflow import WorkflowNodeExecutionStatus

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock


@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)

@@ -116,3 +118,118 @@ def test_execute_llm(setup_openai_mock):
    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs['text'] is not None
    assert result.outputs['usage']['total_tokens'] > 0

@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True)
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
    """
    Test execute LLM node with jinja2
    """
    node = LLMNode(
        tenant_id='1',
        app_id='1',
        workflow_id='1',
        user_id='1',
        user_from=UserFrom.ACCOUNT,
        config={
            'id': 'llm',
            'data': {
                'title': '123',
                'type': 'llm',
                'model': {
                    'provider': 'openai',
                    'name': 'gpt-3.5-turbo',
                    'mode': 'chat',
                    'completion_params': {}
                },
                'prompt_config': {
                    'jinja2_variables': [{
                        'variable': 'sys_query',
                        'value_selector': ['sys', 'query']
                    }, {
                        'variable': 'output',
                        'value_selector': ['abc', 'output']
                    }]
                },
                'prompt_template': [
                    {
                        'role': 'system',
                        'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}',
                        'jinja2_text': 'you are a helpful assistant.\ntoday\'s weather is {{output}}.',
                        'edition_type': 'jinja2'
                    },
                    {
                        'role': 'user',
                        'text': '{{#sys.query#}}',
                        'jinja2_text': '{{sys_query}}',
                        'edition_type': 'basic'
                    }
                ],
                'memory': None,
                'context': {
                    'enabled': False
                },
                'vision': {
                    'enabled': False
                }
            }
        }
    )

    # construct variable pool
    pool = VariablePool(system_variables={
        SystemVariable.QUERY: 'what\'s the weather today?',
        SystemVariable.FILES: [],
        SystemVariable.CONVERSATION_ID: 'abababa',
        SystemVariable.USER_ID: 'aaa'
    }, user_inputs={})
    pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')

    credentials = {
        'openai_api_key': os.environ.get('OPENAI_API_KEY')
    }

    provider_instance = ModelProviderFactory().get_provider_instance('openai')
    model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
    provider_model_bundle = ProviderModelBundle(
        configuration=ProviderConfiguration(
            tenant_id='1',
            provider=provider_instance.get_provider_schema(),
            preferred_provider_type=ProviderType.CUSTOM,
            using_provider_type=ProviderType.CUSTOM,
            system_configuration=SystemConfiguration(
                enabled=False
            ),
            custom_configuration=CustomConfiguration(
                provider=CustomProviderConfiguration(
                    credentials=credentials
                )
            )
        ),
        provider_instance=provider_instance,
        model_type_instance=model_type_instance
    )

    model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')

    model_config = ModelConfigWithCredentialsEntity(
        model='gpt-3.5-turbo',
        provider='openai',
        mode='chat',
        credentials=credentials,
        parameters={},
        model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
        provider_model_bundle=provider_model_bundle
    )

    # Mock db.session.close()
    db.session.close = MagicMock()

    node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))

    # execute node
    result = node.run(pool)

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert 'sunny' in json.dumps(result.process_data)
    assert 'what\'s the weather today?' in json.dumps(result.process_data)
api/tests/unit_tests/libs/test_yarl.py (new file, 23 lines)
@ -0,0 +1,23 @@
|
||||
import pytest
from yarl import URL


def test_yarl_urls():
    expected_1 = 'https://dify.ai/api'
    assert str(URL('https://dify.ai') / 'api') == expected_1
    assert str(URL('https://dify.ai/') / 'api') == expected_1

    expected_2 = 'http://dify.ai:12345/api'
    assert str(URL('http://dify.ai:12345') / 'api') == expected_2
    assert str(URL('http://dify.ai:12345/') / 'api') == expected_2

    expected_3 = 'https://dify.ai/api/v1'
    assert str(URL('https://dify.ai') / 'api' / 'v1') == expected_3
    assert str(URL('https://dify.ai') / 'api/v1') == expected_3
    assert str(URL('https://dify.ai/') / 'api/v1') == expected_3
    assert str(URL('https://dify.ai/api') / 'v1') == expected_3
    assert str(URL('https://dify.ai/api/') / 'v1') == expected_3

    with pytest.raises(ValueError) as e1:
        str(URL('https://dify.ai') / '/api')
    assert str(e1.value) == "Appending path '/api' starting from slash is forbidden"
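The last case is the subtle one: yarl's `/` operator only accepts relative segments and raises `ValueError` for a path starting with a slash. If an absolute path is what you have, a small sketch of the usual workarounds:

```python
from yarl import URL

# Replace the path wholesale instead of joining an absolute segment.
assert str(URL('https://dify.ai').with_path('/api')) == 'https://dify.ai/api'
# Or strip the leading slash and join as a relative segment.
assert str(URL('https://dify.ai') / '/api'.lstrip('/')) == 'https://dify.ai/api'
```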
@@ -53,20 +53,38 @@ services:

  # The DifySandbox
  sandbox:
    image: langgenius/dify-sandbox:0.1.0
    image: langgenius/dify-sandbox:0.2.0
    restart: always
    cap_add:
      # Why is sys_admin permission needed?
      # https://docs.dify.ai/getting-started/install-self-hosted/install-faq#id-16.-why-is-sys_admin-permission-needed
      - SYS_ADMIN
    environment:
      # The DifySandbox configurations
      # Make sure you are changing this key for your deployment with a strong key.
      # You can generate a strong key using `openssl rand -base64 42`.
      API_KEY: dify-sandbox
      GIN_MODE: 'release'
      WORKER_TIMEOUT: 15
    ports:
      - "8194:8194"
      ENABLE_NETWORK: 'true'
      HTTP_PROXY: 'http://ssrf_proxy:3128'
      HTTPS_PROXY: 'http://ssrf_proxy:3128'
    volumes:
      - ./volumes/sandbox/dependencies:/dependencies
    networks:
      - ssrf_proxy_network

  # ssrf_proxy server
  # for more information, please refer to
  # https://docs.dify.ai/getting-started/install-self-hosted/install-faq#id-16.-why-is-ssrf_proxy-needed
  ssrf_proxy:
    image: ubuntu/squid:latest
    restart: always
    ports:
      - "3128:3128"
      - "8194:8194"
    volumes:
      # please modify the squid.conf file to fit your network environment.
      - ./volumes/ssrf_proxy/squid.conf:/etc/squid/squid.conf
    networks:
      - ssrf_proxy_network
      - default
  # Qdrant vector store.
  # uncomment to use qdrant as vector store.
  # (if uncommented, you need to comment out the weaviate service above,

@@ -81,3 +99,10 @@ services:
  #    ports:
  #      - "6333:6333"
  #      - "6334:6334"


networks:
  # create a network between sandbox, api and ssrf_proxy; it cannot access the outside network.
  ssrf_proxy_network:
    driver: bridge
    internal: true
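Note the topology change here: the sandbox no longer publishes port 8194 itself; squid now carries the sandbox's outbound traffic (3128) and reverse-proxies inbound traffic to it (8194). A sketch of how code running inside the sandbox reaches the network through the proxy, assuming the `HTTP_PROXY`/`HTTPS_PROXY` values above (requests also honors these as environment variables):

```python
import os
import requests

# Mirrors the HTTP_PROXY/HTTPS_PROXY environment set on the sandbox service;
# every outbound request is forced through squid on the internal network.
proxies = {
    'http': os.environ.get('HTTP_PROXY', 'http://ssrf_proxy:3128'),
    'https': os.environ.get('HTTPS_PROXY', 'http://ssrf_proxy:3128'),
}
resp = requests.get('https://example.com', proxies=proxies, timeout=10)
print(resp.status_code)
```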
24  docker/docker-compose.pgvector.yaml  Normal file
@@ -0,0 +1,24 @@
version: '3'
services:
  # The pgvector vector store.
  pgvector:
    image: pgvector/pgvector:pg16
    restart: always
    environment:
      PGUSER: postgres
      # The password for the default postgres user.
      POSTGRES_PASSWORD: difyai123456
      # The name of the default postgres database.
      POSTGRES_DB: dify
      # postgres data directory
      PGDATA: /var/lib/postgresql/data/pgdata
    volumes:
      - ./volumes/pgvector/data:/var/lib/postgresql/data
    # uncomment to expose db(postgresql) port to host
    ports:
      - "5433:5432"
    healthcheck:
      test: [ "CMD", "pg_isready" ]
      interval: 1s
      timeout: 3s
      retries: 30
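A quick sanity check that the container is healthy and the extension loads, assuming the credentials above and the host-side port mapping `5433` (psycopg2 assumed installed; it is not part of this compose file):

```python
import psycopg2

# Credentials and the 5433 host port mirror the compose file above.
conn = psycopg2.connect(
    host='localhost', port=5433,
    user='postgres', password='difyai123456', dbname='dify',
)
with conn, conn.cursor() as cur:
    cur.execute('CREATE EXTENSION IF NOT EXISTS vector;')
    cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector';")
    print(cur.fetchone())  # version tuple, depending on the image
conn.close()
```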
@@ -2,7 +2,7 @@ version: '3'
services:
  # API service
  api:
    image: langgenius/dify-api:0.6.7
    image: langgenius/dify-api:0.6.8
    restart: always
    environment:
      # Startup mode, 'api' starts the API server.

@@ -122,6 +122,12 @@ services:
      RELYT_USER: postgres
      RELYT_PASSWORD: difyai123456
      RELYT_DATABASE: postgres
      # pgvector configurations
      PGVECTOR_HOST: pgvector
      PGVECTOR_PORT: 5432
      PGVECTOR_USER: postgres
      PGVECTOR_PASSWORD: difyai123456
      PGVECTOR_DATABASE: dify
      # Mail configuration, support: resend, smtp
      MAIL_TYPE: ''
      # default send from email address, if not specified

@@ -155,6 +161,9 @@ services:
      CODE_MAX_STRING_ARRAY_LENGTH: 30
      CODE_MAX_OBJECT_ARRAY_LENGTH: 30
      CODE_MAX_NUMBER_ARRAY_LENGTH: 1000
      # SSRF Proxy server
      SSRF_PROXY_HTTP_URL: 'http://ssrf_proxy:3128'
      SSRF_PROXY_HTTPS_URL: 'http://ssrf_proxy:3128'
    depends_on:
      - db
      - redis

@@ -164,13 +173,17 @@ services:
    # uncomment to expose dify-api port to host
    # ports:
    #   - "5001:5001"
    networks:
      - ssrf_proxy_network
      - default

  # worker service
  # The Celery worker for processing the queue.
  worker:
    image: langgenius/dify-api:0.6.7
    image: langgenius/dify-api:0.6.8
    restart: always
    environment:
      CONSOLE_WEB_URL: ''
      # Startup mode, 'worker' starts the Celery worker for processing the queue.
      MODE: worker

@@ -197,7 +210,7 @@ services:
      REDIS_USE_SSL: 'false'
      # The configurations of celery broker.
      CELERY_BROKER_URL: redis://:difyai123456@redis:6379/1
      # The type of storage to use for storing user files. Supported values are `local` and `s3` and `azure-blob`, Default: `local`
      # The type of storage to use for storing user files. Supported values are `local` and `s3` and `azure-blob` and `google-storage`, Default: `local`
      STORAGE_TYPE: local
      STORAGE_LOCAL_PATH: storage
      # The S3 storage configurations, only available when STORAGE_TYPE is `s3`.

@@ -211,7 +224,10 @@ services:
      AZURE_BLOB_ACCOUNT_KEY: 'difyai'
      AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
      AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
      # The Google storage configurations, only available when STORAGE_TYPE is `google-storage`.
      GOOGLE_STORAGE_BUCKET_NAME: 'your-bucket-name'
      GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: 'your-google-service-account-json-base64-string'
      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`, `pgvector`.
      VECTOR_STORE: weaviate
      # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
      WEAVIATE_ENDPOINT: http://weaviate:8080

@@ -242,6 +258,11 @@ services:
      MAIL_TYPE: ''
      # default send from email address, if not specified
      MAIL_DEFAULT_SEND_FROM: 'YOUR EMAIL FROM (eg: no-reply <no-reply@dify.ai>)'
      SMTP_SERVER: ''
      SMTP_PORT: 587
      SMTP_USERNAME: ''
      SMTP_PASSWORD: ''
      SMTP_USE_TLS: 'true'
      # the api-key for resend (https://resend.com)
      RESEND_API_KEY: ''
      RESEND_API_URL: https://api.resend.com

@@ -251,6 +272,12 @@ services:
      RELYT_USER: postgres
      RELYT_PASSWORD: difyai123456
      RELYT_DATABASE: postgres
      # pgvector configurations
      PGVECTOR_HOST: pgvector
      PGVECTOR_PORT: 5432
      PGVECTOR_USER: postgres
      PGVECTOR_PASSWORD: difyai123456
      PGVECTOR_DATABASE: dify
      # Notion import configuration, support public and internal
      NOTION_INTEGRATION_TYPE: public
      NOTION_CLIENT_SECRET: your-client-secret

@@ -262,10 +289,13 @@ services:
    volumes:
      # Mount the storage directory to the container, for storing user files.
      - ./volumes/app/storage:/app/api/storage
    networks:
      - ssrf_proxy_network
      - default

  # Frontend web application.
  web:
    image: langgenius/dify-web:0.6.7
    image: langgenius/dify-web:0.6.8
    restart: always
    environment:
      # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is

@@ -346,18 +376,35 @@ services:

  # The DifySandbox
  sandbox:
    image: langgenius/dify-sandbox:0.1.0
    image: langgenius/dify-sandbox:0.2.0
    restart: always
    cap_add:
      # Why is sys_admin permission needed?
      # https://docs.dify.ai/getting-started/install-self-hosted/install-faq#id-16.-why-is-sys_admin-permission-needed
      - SYS_ADMIN
    environment:
      # The DifySandbox configurations
      # Make sure you are changing this key for your deployment with a strong key.
      # You can generate a strong key using `openssl rand -base64 42`.
      API_KEY: dify-sandbox
      GIN_MODE: release
      GIN_MODE: 'release'
      WORKER_TIMEOUT: 15
      ENABLE_NETWORK: 'true'
      HTTP_PROXY: 'http://ssrf_proxy:3128'
      HTTPS_PROXY: 'http://ssrf_proxy:3128'
    volumes:
      - ./volumes/sandbox/dependencies:/dependencies
    networks:
      - ssrf_proxy_network

  # ssrf_proxy server
  # for more information, please refer to
  # https://docs.dify.ai/getting-started/install-self-hosted/install-faq#id-16.-why-is-ssrf_proxy-needed
  ssrf_proxy:
    image: ubuntu/squid:latest
    restart: always
    volumes:
      # please modify the squid.conf file to fit your network environment.
      - ./volumes/ssrf_proxy/squid.conf:/etc/squid/squid.conf
    networks:
      - ssrf_proxy_network
      - default
  # Qdrant vector store.
  # uncomment to use qdrant as vector store.
  # (if uncommented, you need to comment out the weaviate service above,

@@ -374,6 +421,31 @@ services:
  #  # - "6333:6333"
  #  # - "6334:6334"

  # The pgvector vector database.
  # Uncomment to use pgvector as vector store.
  # pgvector:
  #   image: pgvector/pgvector:pg16
  #   restart: always
  #   environment:
  #     PGUSER: postgres
  #     # The password for the default postgres user.
  #     POSTGRES_PASSWORD: difyai123456
  #     # The name of the default postgres database.
  #     POSTGRES_DB: dify
  #     # postgres data directory
  #     PGDATA: /var/lib/postgresql/data/pgdata
  #   volumes:
  #     - ./volumes/pgvector/data:/var/lib/postgresql/data
  #   # uncomment to expose db(postgresql) port to host
  #   # ports:
  #   #   - "5433:5432"
  #   healthcheck:
  #     test: [ "CMD", "pg_isready" ]
  #     interval: 1s
  #     timeout: 3s
  #     retries: 30


  # The nginx reverse proxy.
  # used for reverse proxying the API service and Web service.
  nginx:

@@ -390,3 +462,8 @@ services:
    ports:
      - "80:80"
      #- "443:443"

networks:
  # create a network between sandbox, api and ssrf_proxy; it cannot access the outside network.
  ssrf_proxy_network:
    driver: bridge
    internal: true
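On the new Google storage settings: `GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64` expects the downloaded service-account key file encoded as base64. A one-off sketch for producing that value (the file name is a placeholder):

```python
import base64

# 'service-account.json' is a placeholder for your downloaded key file.
with open('service-account.json', 'rb') as f:
    print(base64.b64encode(f.read()).decode('ascii'))
```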
50  docker/volumes/ssrf_proxy/squid.conf  Normal file
@@ -0,0 +1,50 @@
acl localnet src 0.0.0.1-0.255.255.255  # RFC 1122 "this" network (LAN)
acl localnet src 10.0.0.0/8             # RFC 1918 local private network (LAN)
acl localnet src 100.64.0.0/10          # RFC 6598 shared address space (CGN)
acl localnet src 169.254.0.0/16         # RFC 3927 link-local (directly plugged) machines
acl localnet src 172.16.0.0/12          # RFC 1918 local private network (LAN)
acl localnet src 192.168.0.0/16         # RFC 1918 local private network (LAN)
acl localnet src fc00::/7               # RFC 4193 local private network range
acl localnet src fe80::/10              # RFC 4291 link-local (directly plugged) machines
acl SSL_ports port 443
acl Safe_ports port 80          # http
acl Safe_ports port 21          # ftp
acl Safe_ports port 443         # https
acl Safe_ports port 70          # gopher
acl Safe_ports port 210         # wais
acl Safe_ports port 1025-65535  # unregistered ports
acl Safe_ports port 280         # http-mgmt
acl Safe_ports port 488         # gss-http
acl Safe_ports port 591         # filemaker
acl Safe_ports port 777         # multiling http
acl CONNECT method CONNECT
http_access deny !Safe_ports
http_access deny CONNECT !SSL_ports
http_access allow localhost manager
http_access deny manager
http_access allow localhost
http_access allow localnet
http_access deny all

################################## Proxy Server ################################
http_port 3128
coredump_dir /var/spool/squid
refresh_pattern ^ftp:           1440    20%     10080
refresh_pattern ^gopher:        1440    0%      1440
refresh_pattern -i (/cgi-bin/|\?) 0     0%      0
refresh_pattern \/(Packages|Sources)(|\.bz2|\.gz|\.xz)$ 0 0% 0 refresh-ims
refresh_pattern \/Release(|\.gpg)$ 0 0% 0 refresh-ims
refresh_pattern \/InRelease$ 0 0% 0 refresh-ims
refresh_pattern \/(Translation-.*)(|\.bz2|\.gz|\.xz)$ 0 0% 0 refresh-ims
refresh_pattern .               0       20%     4320
logfile_rotate 0

# upstream proxy, set to your own upstream proxy IP to avoid SSRF attacks
# cache_peer 172.1.1.1 parent 3128 0 no-query no-digest no-netdb-exchange default


################################## Reverse Proxy To Sandbox ################################
http_port 8194 accel vhost
cache_peer sandbox parent 8194 0 no-query originserver
acl all src all
http_access allow all
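The last four lines set up the reverse proxy: squid listens on 8194 as an accelerated vhost and forwards those requests to the `sandbox` host as an origin server. A rough host-side check, assuming the proxy's 3128/8194 ports are published to the host as in the middleware snippet earlier:

```python
import requests

# Forward proxy on 3128: squid should relay this request.
proxies = {'http': 'http://localhost:3128', 'https': 'http://localhost:3128'}
print(requests.get('http://example.com', proxies=proxies, timeout=10).status_code)

# Reverse proxy on 8194: squid hands the request to the sandbox service,
# so whatever dify-sandbox answers comes back here.
print(requests.get('http://localhost:8194', timeout=10).status_code)
```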
@@ -58,6 +58,7 @@ export type IGenerationItemProps = {
  innerClassName?: string
  contentClassName?: string
  footerClassName?: string
  hideProcessDetail?: boolean
}

export const SimpleBtn = ({ className, isDisabled, onClick, children }: {

@@ -108,6 +109,7 @@ const GenerationItem: FC<IGenerationItemProps> = ({
  varList,
  innerClassName,
  contentClassName,
  hideProcessDetail,
}) => {
  const { t } = useTranslation()
  const params = useParams()

@@ -265,6 +267,8 @@ const GenerationItem: FC<IGenerationItemProps> = ({
    </>
  )

  const [currentTab, setCurrentTab] = useState<string>('DETAIL')

  return (
    <div ref={ref} className={cn(className, isTop ? `rounded-xl border ${!isError ? 'border-gray-200 bg-white' : 'border-[#FECDCA] bg-[#FEF3F2]'} ` : 'rounded-br-xl !mt-0')}
      style={isTop

@@ -291,10 +295,10 @@ const GenerationItem: FC<IGenerationItemProps> = ({
      <div className={`flex ${contentClassName}`}>
        <div className='grow w-0'>
          {workflowProcessData && (
            <WorkflowProcessItem grayBg hideInfo data={workflowProcessData} expand={workflowProcessData.expand} />
            <WorkflowProcessItem grayBg hideInfo data={workflowProcessData} expand={workflowProcessData.expand} hideProcessDetail={hideProcessDetail} />
          )}
          {workflowProcessData && !isError && (
            <ResultTab data={workflowProcessData} content={content} />
            <ResultTab data={workflowProcessData} content={content} currentTab={currentTab} onCurrentTabChange={setCurrentTab} />
          )}
          {isError && (
            <div className='text-gray-400 text-sm'>{t('share.generation.batchFailed.outputPlaceholder')}</div>

@@ -318,19 +322,23 @@ const GenerationItem: FC<IGenerationItemProps> = ({
            </SimpleBtn>
          )
        }
        <SimpleBtn
          isDisabled={isError || !messageId}
          className={cn(isMobile && '!px-1.5', 'space-x-1')}
          onClick={() => {
            if (typeof content === 'string')
              copy(content)
            else
              copy(JSON.stringify(content))
            Toast.notify({ type: 'success', message: t('common.actionMsg.copySuccessfully') })
          }}>
          <Clipboard className='w-3.5 h-3.5' />
          {!isMobile && <div>{t('common.operation.copy')}</div>}
        </SimpleBtn>
        {(currentTab === 'RESULT' || !isWorkflow) && (
          <SimpleBtn
            isDisabled={isError || !messageId}
            className={cn(isMobile && '!px-1.5', 'space-x-1')}
            onClick={() => {
              const copyContent = isWorkflow ? workflowProcessData?.resultText : content
              if (typeof copyContent === 'string')
                copy(copyContent)
              else
                copy(JSON.stringify(copyContent))
              Toast.notify({ type: 'success', message: t('common.actionMsg.copySuccessfully') })
            }}>
            <Clipboard className='w-3.5 h-3.5' />
            {!isMobile && <div>{t('common.operation.copy')}</div>}
          </SimpleBtn>
        )}

        {isInWebApp && (
          <>
            {!isWorkflow && (

@@ -1,8 +1,6 @@
import {
  memo,
  useEffect,
  // useRef,
  useState,
} from 'react'
import cn from 'classnames'
import { useTranslation } from 'react-i18next'

@@ -16,15 +14,18 @@ import type { WorkflowProcess } from '@/app/components/base/chat/types'
const ResultTab = ({
  data,
  content,
  currentTab,
  onCurrentTabChange,
}: {
  data?: WorkflowProcess
  content: any
  currentTab: string
  onCurrentTabChange: (tab: string) => void
}) => {
  const { t } = useTranslation()
  const [currentTab, setCurrentTab] = useState<string>('DETAIL')

  const switchTab = async (tab: string) => {
    setCurrentTab(tab)
    onCurrentTabChange(tab)
  }
  useEffect(() => {
    if (data?.resultText)

@@ -140,6 +140,7 @@ const ChatWrapper = () => {
      allToolIcons={appMeta?.tool_icons || {}}
      onFeedback={handleFeedback}
      suggestedQuestions={suggestedQuestions}
      hideProcessDetail
    />
  )
}

@@ -37,6 +37,17 @@ const Form = () => {
        />
      )
    }
    if (form.type === 'number') {
      return (
        <input
          className="grow h-9 rounded-lg bg-gray-100 px-2.5 outline-none appearance-none"
          type="number"
          value={newConversationInputs[variable] || ''}
          onChange={e => handleFormChange(variable, e.target.value)}
          placeholder={`${label}${!required ? `(${t('appDebug.variableTable.optional')})` : ''}`}
        />
      )
    }

    return (
      <PortalSelect

@@ -134,7 +134,7 @@ const ConfigPanel = () => {
        {site?.privacy_policy
          ? <div className={`flex items-center ${isMobile && 'w-full justify-end'}`}>{t('share.chat.privacyPolicyLeft')}
            <a
              className='text-gray-500'
              className='text-gray-500 px-1'
              href={site?.privacy_policy}
              target='_blank' rel='noopener noreferrer'>{t('share.chat.privacyPolicyMiddle')}</a>
            {t('share.chat.privacyPolicyRight')}

@@ -129,19 +129,26 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
    setNewConversationInputs(newInputs)
  }, [])
  const inputsForms = useMemo(() => {
    return (appParams?.user_input_form || []).filter((item: any) => item.paragraph || item.select || item['text-input']).map((item: any) => {
    return (appParams?.user_input_form || []).filter((item: any) => item.paragraph || item.select || item['text-input'] || item.number).map((item: any) => {
      if (item.paragraph) {
        return {
          ...item.paragraph,
          type: 'paragraph',
        }
      }
      if (item.number) {
        return {
          ...item.number,
          type: 'number',
        }
      }
      if (item.select) {
        return {
          ...item.select,
          type: 'select',
        }
      }

      return {
        ...item['text-input'],
        type: 'text-input',

@@ -226,7 +233,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
      setShowNewConversationItemInList(true)
    }
  }, [setShowConfigPanelBeforeChat, setShowNewConversationItemInList, checkInputsRequired])
  const currentChatInstanceRef = useRef<{ handleStop: () => void }>({ handleStop: () => {} })
  const currentChatInstanceRef = useRef<{ handleStop: () => void }>({ handleStop: () => { } })
  const handleChangeConversation = useCallback((conversationId: string) => {
    currentChatInstanceRef.current.handleStop()
    setNewConversationId('')

@@ -31,6 +31,7 @@ type AnswerProps = {
  allToolIcons?: Record<string, string | Emoji>
  showPromptLog?: boolean
  chatAnswerContainerInner?: string
  hideProcessDetail?: boolean
}
const Answer: FC<AnswerProps> = ({
  item,

@@ -42,6 +43,7 @@ const Answer: FC<AnswerProps> = ({
  allToolIcons,
  showPromptLog,
  chatAnswerContainerInner,
  hideProcessDetail,
}) => {
  const { t } = useTranslation()
  const {

@@ -129,7 +131,7 @@ const Answer: FC<AnswerProps> = ({
  }
  {
    workflowProcess && (
      <WorkflowProcess data={workflowProcess} hideInfo />
      <WorkflowProcess data={workflowProcess} hideInfo hideProcessDetail={hideProcessDetail} />
    )
  }
  {

@@ -18,12 +18,14 @@ type WorkflowProcessProps = {
  grayBg?: boolean
  expand?: boolean
  hideInfo?: boolean
  hideProcessDetail?: boolean
}
const WorkflowProcessItem = ({
  data,
  grayBg,
  expand = false,
  hideInfo = false,
  hideProcessDetail = false,
}: WorkflowProcessProps) => {
  const { t } = useTranslation()
  const [collapse, setCollapse] = useState(!expand)

@@ -94,6 +96,7 @@ const WorkflowProcessItem = ({
        <NodePanel
          nodeInfo={node}
          hideInfo={hideInfo}
          hideProcessDetail={hideProcessDetail}
        />
      </div>
    ))

@@ -468,7 +468,10 @@ export const useChat = (
      }))
    },
    onNodeStarted: ({ data }) => {
      responseItem.workflowProcess!.tracing!.push(data as any)
      responseItem.workflowProcess!.tracing!.push({
        ...data,
        status: WorkflowRunningStatus.Running,
      } as any)
      handleUpdateChatList(produce(chatListRef.current, (draft) => {
        const currentIndex = draft.findIndex(item => item.id === responseItem.id)
        draft[currentIndex] = {

@@ -54,6 +54,7 @@ export type ChatProps = {
  chatNode?: ReactNode
  onFeedback?: (messageId: string, feedback: Feedback) => void
  chatAnswerContainerInner?: string
  hideProcessDetail?: boolean
}
const Chat: FC<ChatProps> = ({
  config,

@@ -78,6 +79,7 @@ const Chat: FC<ChatProps> = ({
  chatNode,
  onFeedback,
  chatAnswerContainerInner,
  hideProcessDetail,
}) => {
  const { t } = useTranslation()
  const { currentLogItem, setCurrentLogItem, showPromptLogModal, setShowPromptLogModal, showAgentLogModal, setShowAgentLogModal } = useAppStore(useShallow(state => ({

@@ -204,6 +206,7 @@ const Chat: FC<ChatProps> = ({
      allToolIcons={allToolIcons}
      showPromptLog={showPromptLog}
      chatAnswerContainerInner={chatAnswerContainerInner}
      hideProcessDetail={hideProcessDetail}
    />
  )
}
@@ -0,0 +1,5 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="Icon">
<path id="Icon_2" d="M14 6.00016H5C3.34315 6.00016 2 7.34331 2 9.00016C2 10.657 3.34315 12.0002 5 12.0002H8M14 6.00016L11.3333 3.3335M14 6.00016L11.3333 8.66683" stroke="#667085" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</g>
</svg>
(new SVG icon, 16x16, 369 B)

@@ -0,0 +1,5 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="Icon">
<path id="Icon_2" d="M2.66699 4.66667H9.33366C11.5428 4.66667 13.3337 6.45753 13.3337 8.66667C13.3337 10.8758 11.5428 12.6667 9.33366 12.6667H2.66699M2.66699 4.66667L5.33366 2M2.66699 4.66667L5.33366 7.33333" stroke="#667085" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</g>
</svg>
(new SVG icon, 16x16, 416 B)

@@ -0,0 +1,5 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="Icon">
<path id="Vector" d="M2.4598 3.3093L6.05377 13.551C6.25503 14.1246 7.05599 14.1516 7.29552 13.593L9.08053 9.43022C9.14793 9.27295 9.27326 9.14762 9.43053 9.08022L13.5933 7.29522C14.1519 7.05569 14.1249 6.25472 13.5513 6.05346L3.30961 2.45949C2.78207 2.27437 2.27468 2.78176 2.4598 3.3093Z" stroke="#667085" stroke-width="1.5" stroke-linejoin="round"/>
</g>
</svg>
(new SVG icon, 16x16, 474 B)