Compare commits


26 Commits

Author SHA1 Message Date
ba79088ffc Fix SQL parser Error in MyScale vdb. (#7255) 2024-08-14 16:41:18 +08:00
3a27166c2e chore: allow download audio/video through HTTP node (#7224) 2024-08-14 16:25:59 +08:00
429e85f5d6 Fix: support hide env & conversation var in prompt editor (#7256) 2024-08-14 15:14:39 +08:00
b5d472fad7 test(*): Avoid import from api in tests. (#7251) 2024-08-14 14:09:26 +08:00
52383d0161 add support for tongyi-farui (#7248)
Co-authored-by: 雪风 <xuefeng@shifaedu.cn>
2024-08-14 14:09:13 +08:00
48d2febebf fix(api/core/tools/entities/tool_entities.py): Fix type define. (#7250) 2024-08-14 14:08:54 +08:00
ca085034de doc: add missing params (#7242) 2024-08-13 22:31:27 +08:00
f6c12b10ac chore: update package versions to 0.7.0 (#7236) 2024-08-13 22:28:06 +08:00
5b77ef01d4 chore(api/services/app_dsl_service.py): Bump DSL version to 0.1.1 (#7235) 2024-08-13 18:20:41 +08:00
5d85fad522 Revert yarn.lock (#7234) 2024-08-13 18:19:36 +08:00
2fe2e350ce add secondary sort_key when using order_by and paginate at the same time (#7225) 2024-08-13 17:39:51 +08:00
986fd5bfc6 Add gitlab support (#7179)
Co-authored-by: crazywoola <427733928@qq.com>
2024-08-13 17:36:45 +08:00
f104b930cf feat: support elasticsearch vector database (#3558)
Co-authored-by: miendinh <miendinh@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
2024-08-13 17:36:20 +08:00
4423710a13 Add ECharts feature ( #6385 ) (#6961) 2024-08-13 17:35:12 +08:00
9381c08c43 chore: not use step_boundary field (#7231) 2024-08-13 17:22:33 +08:00
0f59d76997 fix: add context_size and max_chunks to Tongyi embedding to resolve issue #7189 (#7227) 2024-08-13 16:35:22 +08:00
b3743a9ae5 chore: refactor searXNG tool (#7220) 2024-08-13 15:34:29 +08:00
13d061911b Error Exception Message Of "Message Not Exists.", Should be "Suggested Questions Is Disabled." (#7219) 2024-08-13 15:17:18 +08:00
935e72d449 Feat: conversation variable & variable assigner node (#7222)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-08-13 14:44:10 +08:00
8b55bd5828 fix: display notion document title correctly (#7215) 2024-08-13 14:05:57 +08:00
a12ddc47e7 feat: add support of speech2text function for OpenAI-API-compatible and Siliconflow (#7197) 2024-08-12 21:38:59 +08:00
57ce8449b0 feat: Support NEXT_TELEMETRY_DISABLED (#7181) 2024-08-12 19:15:41 +08:00
67b9fdaad7 siliconflow support bge-3 && bce-v1 embedding (#7198) 2024-08-12 19:14:43 +08:00
f9cf418f0f Fix/workflow run single step (#7194) 2024-08-12 17:14:17 +08:00
dfa7fe1289 chore: #7177 README_VI (#7182) 2024-08-12 15:57:21 +08:00
d2471cf6f9 Fix jp translation with new dify-doc (#7185) 2024-08-12 15:49:46 +08:00
196 changed files with 7020 additions and 987 deletions

View File

@ -76,7 +76,7 @@ jobs:
- name: Run Workflow
run: poetry run -C api bash dev/pytest/pytest_workflow.sh
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale)
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch)
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: |
@ -90,5 +90,6 @@ jobs:
pgvecto-rs
pgvector
chroma
elasticsearch
- name: Test Vector Stores
run: poetry run -C api bash dev/pytest/pytest_vdb.sh

View File

@ -6,5 +6,6 @@ yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml
yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml
yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs."
echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch"

View File

@ -152,7 +152,7 @@ Nhanh chóng chạy Dify trong môi trường của bạn với [hướng dẫn
Sử dụng [tài liệu](https://docs.dify.ai) của chúng tôi để tham khảo thêm và nhận hướng dẫn chi tiết hơn.
- **Dify cho doanh nghiệp / tổ chức</br>**
Chúng tôi cung cấp các tính năng bổ sung tập trung vào doanh nghiệp. [Lên lịch cuộc họp với chúng tôi](https://cal.com/guchenhe/30min) hoặc [gửi email cho chúng tôi](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) để thảo luận về nhu cầu doanh nghiệp. </br>
Chúng tôi cung cấp các tính năng bổ sung tập trung vào doanh nghiệp. [Ghi lại câu hỏi của bạn cho chúng tôi thông qua chatbot này](https://udify.app/chat/22L1zSxg6yW1cWQg) hoặc [gửi email cho chúng tôi](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) để thảo luận về nhu cầu doanh nghiệp. </br>
> Đối với các công ty khởi nghiệp và doanh nghiệp nhỏ sử dụng AWS, hãy xem [Dify Premium trên AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) và triển khai nó vào AWS VPC của riêng bạn chỉ với một cú nhấp chuột. Đây là một AMI giá cả phải chăng với tùy chọn tạo ứng dụng với logo và thương hiệu tùy chỉnh.
@ -221,23 +221,6 @@ Triển khai Dify lên Azure chỉ với một cú nhấp chuột bằng cách s
* [Discord](https://discord.gg/FngNHpbcY7). Tốt nhất cho: chia sẻ ứng dụng của bạn và giao lưu với cộng đồng.
* [Twitter](https://twitter.com/dify_ai). Tốt nhất cho: chia sẻ ứng dụng của bạn và giao lưu với cộng đồng.
Hoặc, lên lịch cuộc họp trực tiếp với một thành viên trong nhóm:
<table>
<tr>
<th>Điểm liên hệ</th>
<th>Mục đích</th>
</tr>
<tr>
<td><a href='https://cal.com/guchenhe/15min' target='_blank'><img class="schedule-button" src='https://github.com/langgenius/dify/assets/13230914/9ebcd111-1205-4d71-83d5-948d70b809f5' alt='Git-Hub-README-Button-3x' style="width: 180px; height: auto; object-fit: contain;"/></a></td>
<td>Yêu cầu kinh doanh & phản hồi sản phẩm</td>
</tr>
<tr>
<td><a href='https://cal.com/pinkbanana' target='_blank'><img class="schedule-button" src='https://github.com/langgenius/dify/assets/13230914/d1edd00a-d7e4-4513-be6c-e57038e143fd' alt='Git-Hub-README-Button-2x' style="width: 180px; height: auto; object-fit: contain;"/></a></td>
<td>Đóng góp, vấn đề & yêu cầu tính năng</td>
</tr>
</table>
## Lịch sử Yêu thích
[![Biểu đồ Lịch sử Yêu thích](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)

View File

@ -130,6 +130,12 @@ TENCENT_VECTOR_DB_DATABASE=dify
TENCENT_VECTOR_DB_SHARD=1
TENCENT_VECTOR_DB_REPLICAS=2
# ElasticSearch configuration
ELASTICSEARCH_HOST=127.0.0.1
ELASTICSEARCH_PORT=9200
ELASTICSEARCH_USERNAME=elastic
ELASTICSEARCH_PASSWORD=elastic
# PGVECTO_RS configuration
PGVECTO_RS_HOST=localhost
PGVECTO_RS_PORT=5431
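Note: the ELASTICSEARCH_* defaults above match the service exposed in the CI compose script earlier in this compare. A minimal, hedged sketch of reading these settings and checking connectivity with the official Python client (the variable names come from this .env block; everything else is illustrative):

import os

from elasticsearch import Elasticsearch  # elasticsearch-py 8.x client

# Defaults mirror the example .env values above.
host = os.getenv("ELASTICSEARCH_HOST", "127.0.0.1")
port = os.getenv("ELASTICSEARCH_PORT", "9200")
username = os.getenv("ELASTICSEARCH_USERNAME", "elastic")
password = os.getenv("ELASTICSEARCH_PASSWORD", "elastic")

client = Elasticsearch(hosts=f"http://{host}:{port}", basic_auth=(username, password))
print(client.info())  # raises a transport error if the service is unreachable or the credentials are wrong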

View File

@ -344,6 +344,14 @@ def migrate_knowledge_vector_database():
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.ELASTICSEARCH:
dataset_id = dataset.id
index_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'elasticsearch',
"vector_store": {"class_prefix": index_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@ -12,19 +12,14 @@ from configs.packaging import PackagingInfo
class DifyConfig(
# Packaging info
PackagingInfo,
# Deployment configs
DeploymentConfig,
# Feature configs
FeatureConfig,
# Middleware configs
MiddlewareConfig,
# Extra service configs
ExtraServiceConfig,
# Enterprise feature configs
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
@ -36,7 +31,6 @@ class DifyConfig(
env_file='.env',
env_file_encoding='utf-8',
frozen=True,
# ignore extra attributes
extra='ignore',
)
@ -67,3 +61,5 @@ class DifyConfig(
SSRF_PROXY_HTTPS_URL: str | None = None
MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.')
MAX_VARIABLE_SIZE: int = Field(default=5 * 1024, description='The maximum size of a variable. default is 5KB.')

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description='Dify version',
default='0.6.16',
default='0.7.0',
)
COMMIT_SHA: str = Field(

View File

@ -17,6 +17,7 @@ from .app import (
audio,
completion,
conversation,
conversation_variables,
generator,
message,
model_config,

View File

@ -0,0 +1,61 @@
from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from fields.conversation_variable_fields import paginated_conversation_variable_fields
from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
class ConversationVariablesApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_fields)
def get(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument('conversation_id', type=str, location='args')
args = parser.parse_args()
stmt = (
select(ConversationVariable)
.where(ConversationVariable.app_id == app_model.id)
.order_by(ConversationVariable.created_at)
)
if args['conversation_id']:
stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id'])
else:
raise ValueError('conversation_id is required')
# NOTE: This is a temporary solution to avoid performance issues.
page = 1
page_size = 100
stmt = stmt.limit(page_size).offset((page - 1) * page_size)
with Session(db.engine) as session:
rows = session.scalars(stmt).all()
return {
'page': page,
'limit': page_size,
'total': len(rows),
'has_more': False,
'data': [
{
'created_at': row.created_at,
'updated_at': row.updated_at,
**row.to_variable().model_dump(),
}
for row in rows
],
}
api.add_resource(ConversationVariablesApi, '/apps/<uuid:app_id>/conversation-variables')
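A hedged sketch of exercising the new endpoint from a script. The route and the required conversation_id parameter come from the resource above; the base URL, console API prefix, auth header, and response field names are assumptions about a typical deployment rather than something this diff defines:

import requests

BASE_URL = "http://localhost:5001/console/api"   # assumed console API prefix
APP_ID = "00000000-0000-0000-0000-000000000000"  # placeholder advanced-chat app id
TOKEN = "<console-access-token>"                 # placeholder; obtain via your deployment's login flow

resp = requests.get(
    f"{BASE_URL}/apps/{APP_ID}/conversation-variables",
    params={"conversation_id": "<conversation-uuid>"},  # required by the handler above
    headers={"Authorization": f"Bearer {TOKEN}"},
)
resp.raise_for_status()
for item in resp.json()["data"]:
    # field names assumed from the variable mappings used elsewhere in this compare
    print(item.get("name"), item.get("value_type"), item.get("value"))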

View File

@ -74,6 +74,7 @@ class DraftWorkflowApi(Resource):
parser.add_argument('hash', type=str, required=False, location='json')
# TODO: set this to required=True after frontend is updated
parser.add_argument('environment_variables', type=list, required=False, location='json')
parser.add_argument('conversation_variables', type=list, required=False, location='json')
args = parser.parse_args()
elif 'text/plain' in content_type:
try:
@ -88,7 +89,8 @@ class DraftWorkflowApi(Resource):
'graph': data.get('graph'),
'features': data.get('features'),
'hash': data.get('hash'),
'environment_variables': data.get('environment_variables')
'environment_variables': data.get('environment_variables'),
'conversation_variables': data.get('conversation_variables'),
}
except json.JSONDecodeError:
return {'message': 'Invalid JSON data'}, 400
@ -100,6 +102,8 @@ class DraftWorkflowApi(Resource):
try:
environment_variables_list = args.get('environment_variables') or []
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
conversation_variables_list = args.get('conversation_variables') or []
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
workflow = workflow_service.sync_draft_workflow(
app_model=app_model,
graph=args['graph'],
@ -107,6 +111,7 @@ class DraftWorkflowApi(Resource):
unique_hash=args.get('hash'),
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()

View File

@ -555,7 +555,7 @@ class DatasetRetrievalSettingApi(Resource):
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value,
@ -579,7 +579,7 @@ class DatasetRetrievalSettingMockApi(Resource):
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value,

View File

@ -178,11 +178,20 @@ class DatasetDocumentListApi(Resource):
.subquery()
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \
.order_by(sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)))
.order_by(
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
sort_logic(Document.position),
)
elif sort == 'created_at':
query = query.order_by(sort_logic(Document.created_at))
query = query.order_by(
sort_logic(Document.created_at),
sort_logic(Document.position),
)
else:
query = query.order_by(desc(Document.created_at))
query = query.order_by(
desc(Document.created_at),
desc(Document.position),
)
paginated_documents = query.paginate(
page=page, per_page=limit, max_per_page=100, error_out=False)
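This change matches the commit "add secondary sort_key when using order_by and paginate at the same time": when created_at or the hit count ties across rows, limit/offset pages can shuffle between requests. A self-contained sketch of the same idea, using an illustrative model rather than Dify's Document:

from sqlalchemy import Column, DateTime, Integer, create_engine, desc
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Doc(Base):  # stand-in for the Document model, for illustration only
    __tablename__ = "docs"
    id = Column(Integer, primary_key=True)
    position = Column(Integer)
    created_at = Column(DateTime)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # created_at alone is not a total order when timestamps collide;
    # adding position as a tiebreaker keeps page boundaries stable across requests.
    query = session.query(Doc).order_by(desc(Doc.created_at), desc(Doc.position))
    first_page = query.limit(25).offset(0).all()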

View File

@ -131,7 +131,7 @@ class MessageSuggestedApi(Resource):
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
except SuggestedQuestionsAfterAnswerDisabledError:
raise BadRequest("Message Not Exists.")
raise BadRequest("Suggested Questions Is Disabled.")
except Exception:
logging.exception("internal server error.")
raise InternalServerError()

View File

@ -3,8 +3,9 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.file.file_obj import FileExtraConfig
from core.model_runtime.entities.message_entities import PromptMessageRole
from models.model import AppMode
from models import AppMode
class ModelConfigEntity(BaseModel):
@ -200,11 +201,6 @@ class TracingConfigEntity(BaseModel):
tracing_provider: str
class FileExtraConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[dict[str, Any]] = None
class AppAdditionalFeatures(BaseModel):

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.app.app_config.entities import FileExtraConfig
from core.file.file_obj import FileExtraConfig
class FileUploadConfigManager:

View File

@ -113,7 +113,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=invoke_from,
@ -180,7 +179,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
@ -189,12 +187,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
stream=stream
)
def _generate(self, app_model: App,
def _generate(self, *,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation = None,
conversation: Conversation | None = None,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
is_first_conversation = False

View File

@ -4,6 +4,9 @@ import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -17,11 +20,12 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueSto
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
from models.workflow import ConversationVariable, Workflow
logger = logging.getLogger(__name__)
@ -31,10 +35,13 @@ class AdvancedChatAppRunner(AppRunner):
AdvancedChat Application Runner
"""
def run(self, application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message) -> None:
def run(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
"""
Run application
:param application_generate_entity: application generate entity
@ -48,11 +55,11 @@ class AdvancedChatAppRunner(AppRunner):
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
raise ValueError('App not found')
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
raise ValueError('Workflow not initialized')
inputs = application_generate_entity.inputs
query = application_generate_entity.query
@ -68,35 +75,66 @@ class AdvancedChatAppRunner(AppRunner):
# moderation
if self.handle_input_moderation(
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id,
):
return
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=message,
query=query,
queue_manager=queue_manager,
app_generate_entity=application_generate_entity
app_record=app_record,
message=message,
query=query,
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
):
return
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
session.commit()
# Convert database entities to variables
conversation_variables = [item.to_variable() for item in conversation_variables]
# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
@ -106,43 +144,30 @@ class AdvancedChatAppRunner(AppRunner):
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
user_inputs=inputs,
system_inputs={
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
},
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth
call_depth=application_generate_entity.call_depth,
variable_pool=variable_pool,
)
def single_iteration_run(self, app_id: str, workflow_id: str,
queue_manager: AppQueueManager,
inputs: dict, node_id: str, user_id: str) -> None:
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
"""
Single iteration run
"""
app_record: App = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError("App not found")
raise ValueError('App not found')
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
workflow_callbacks = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow,
node_id=node_id,
user_id=user_id,
user_inputs=inputs,
callbacks=workflow_callbacks
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
@ -150,22 +175,25 @@ class AdvancedChatAppRunner(AppRunner):
Get workflow
"""
# fetch workflow by workflow_id
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id
).first()
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
def handle_input_moderation(
self, queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str
self,
queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
) -> bool:
"""
Handle input moderation
@ -192,17 +220,20 @@ class AdvancedChatAppRunner(AppRunner):
queue_manager=queue_manager,
text=str(e),
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
)
return True
return False
def handle_annotation_reply(self, app_record: App,
message: Message,
query: str,
queue_manager: AppQueueManager,
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
def handle_annotation_reply(
self,
app_record: App,
message: Message,
query: str,
queue_manager: AppQueueManager,
app_generate_entity: AdvancedChatAppGenerateEntity,
) -> bool:
"""
Handle annotation reply
:param app_record: app record
@ -217,29 +248,27 @@ class AdvancedChatAppRunner(AppRunner):
message=message,
query=query,
user_id=app_generate_entity.user_id,
invoke_from=app_generate_entity.invoke_from
invoke_from=app_generate_entity.invoke_from,
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
)
self._stream_output(
queue_manager=queue_manager,
text=annotation_reply.content,
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
)
return True
return False
def _stream_output(self, queue_manager: AppQueueManager,
text: str,
stream: bool,
stopped_by: QueueStopEvent.StopBy) -> None:
def _stream_output(
self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
) -> None:
"""
Direct output
:param queue_manager: application queue manager
@ -250,21 +279,10 @@ class AdvancedChatAppRunner(AppRunner):
if stream:
index = 0
for token in text:
queue_manager.publish(
QueueTextChunkEvent(
text=token
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
index += 1
time.sleep(0.01)
else:
queue_manager.publish(
QueueTextChunkEvent(
text=text
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)
queue_manager.publish(
QueueStopEvent(stopped_by=stopped_by),
PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)

View File

@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
@ -26,8 +27,7 @@ class WorkflowAppRunner:
Workflow Application Runner
"""
def run(self, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager) -> None:
def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None:
"""
Run application
:param application_generate_entity: application generate entity
@ -47,25 +47,36 @@ class WorkflowAppRunner:
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
raise ValueError('App not found')
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
raise ValueError('Workflow not initialized')
inputs = application_generate_entity.inputs
files = application_generate_entity.files
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# Create a variable pool.
system_inputs = {
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
@ -75,44 +86,33 @@ class WorkflowAppRunner:
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
user_inputs=inputs,
system_inputs={
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id
},
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth
call_depth=application_generate_entity.call_depth,
variable_pool=variable_pool,
)
def single_iteration_run(self, app_id: str, workflow_id: str,
queue_manager: AppQueueManager,
inputs: dict, node_id: str, user_id: str) -> None:
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
"""
Single iteration run
"""
app_record: App = db.session.query(App).filter(App.id == app_id).first()
app_record = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError("App not found")
raise ValueError('App not found')
if not app_record.workflow_id:
raise ValueError("Workflow not initialized")
raise ValueError('Workflow not initialized')
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
workflow_callbacks = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow,
node_id=node_id,
user_id=user_id,
user_inputs=inputs,
callbacks=workflow_callbacks
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
@ -120,11 +120,13 @@ class WorkflowAppRunner:
Get workflow
"""
# fetch workflow by workflow_id
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id
).first()
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow

View File

@ -1,6 +1,7 @@
from .segment_group import SegmentGroup
from .segments import (
ArrayAnySegment,
ArraySegment,
FileSegment,
FloatSegment,
IntegerSegment,
@ -50,4 +51,5 @@ __all__ = [
'ArrayNumberVariable',
'ArrayObjectVariable',
'ArrayFileVariable',
'ArraySegment',
]

View File

@ -0,0 +1,2 @@
class VariableError(Exception):
pass

View File

@ -1,8 +1,10 @@
from collections.abc import Mapping
from typing import Any
from configs import dify_config
from core.file.file_obj import FileVar
from .exc import VariableError
from .segments import (
ArrayAnySegment,
FileSegment,
@ -29,39 +31,43 @@ from .variables import (
)
def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable:
if (value_type := m.get('value_type')) is None:
raise ValueError('missing value type')
if not m.get('name'):
raise ValueError('missing name')
if (value := m.get('value')) is None:
raise ValueError('missing value')
def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if (value_type := mapping.get('value_type')) is None:
raise VariableError('missing value type')
if not mapping.get('name'):
raise VariableError('missing name')
if (value := mapping.get('value')) is None:
raise VariableError('missing value')
match value_type:
case SegmentType.STRING:
return StringVariable.model_validate(m)
result = StringVariable.model_validate(mapping)
case SegmentType.SECRET:
return SecretVariable.model_validate(m)
result = SecretVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, int):
return IntegerVariable.model_validate(m)
result = IntegerVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, float):
return FloatVariable.model_validate(m)
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise ValueError(f'invalid number value {value}')
raise VariableError(f'invalid number value {value}')
case SegmentType.FILE:
return FileVariable.model_validate(m)
result = FileVariable.model_validate(mapping)
case SegmentType.OBJECT if isinstance(value, dict):
return ObjectVariable.model_validate(
{**m, 'value': {k: build_variable_from_mapping(v) for k, v in value.items()}}
)
result = ObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_STRING if isinstance(value, list):
return ArrayStringVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
result = ArrayStringVariable.model_validate(mapping)
case SegmentType.ARRAY_NUMBER if isinstance(value, list):
return ArrayNumberVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
result = ArrayNumberVariable.model_validate(mapping)
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
return ArrayObjectVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
result = ArrayObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_FILE if isinstance(value, list):
return ArrayFileVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
raise ValueError(f'not supported value type {value_type}')
mapping = dict(mapping)
mapping['value'] = [{'value': v} for v in value]
result = ArrayFileVariable.model_validate(mapping)
case _:
raise VariableError(f'not supported value type {value_type}')
if result.size > dify_config.MAX_VARIABLE_SIZE:
raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}')
return result
def build_segment(value: Any, /) -> Segment:
@ -74,12 +80,9 @@ def build_segment(value: Any, /) -> Segment:
if isinstance(value, float):
return FloatSegment(value=value)
if isinstance(value, dict):
# TODO: Limit the depth of the object
return ObjectSegment(value=value)
if isinstance(value, list):
# TODO: Limit the depth of the array
elements = [build_segment(v) for v in value]
return ArrayAnySegment(value=elements)
return ArrayAnySegment(value=value)
if isinstance(value, FileVar):
return FileSegment(value=value)
raise ValueError(f'not supported value {value}')
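A hedged usage sketch of the refactored factory. The mapping keys (name, value_type, value), the VariableError type, and the MAX_VARIABLE_SIZE check come from the code above; the import paths are assumptions and may differ in the repository:

from core.app.segments import factory            # assumed module path for the factory shown above
from core.app.segments.exc import VariableError  # assumed module path for the exc.py shown above

variable = factory.build_variable_from_mapping(
    {"name": "user_name", "value_type": "string", "value": "alice"}
)
print(variable.value)  # "alice"

try:
    factory.build_variable_from_mapping({"value_type": "string", "value": "x"})  # no name
except VariableError as e:
    print(e)  # "missing name", per the validation above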

View File

@ -1,4 +1,5 @@
import json
import sys
from collections.abc import Mapping, Sequence
from typing import Any
@ -37,6 +38,10 @@ class Segment(BaseModel):
def markdown(self) -> str:
return str(self.value)
@property
def size(self) -> int:
return sys.getsizeof(self.value)
def to_object(self) -> Any:
return self.value
@ -105,28 +110,25 @@ class ArraySegment(Segment):
def markdown(self) -> str:
return '\n'.join(['- ' + item.markdown for item in self.value])
def to_object(self):
return [v.to_object() for v in self.value]
class ArrayAnySegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_ANY
value: Sequence[Segment]
value: Sequence[Any]
class ArrayStringSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_STRING
value: Sequence[StringSegment]
value: Sequence[str]
class ArrayNumberSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_NUMBER
value: Sequence[FloatSegment | IntegerSegment]
value: Sequence[float | int]
class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[ObjectSegment]
value: Sequence[Mapping[str, Any]]
class ArrayFileSegment(ArraySegment):

View File

@ -1,14 +1,19 @@
import enum
from typing import Optional
from typing import Any, Optional
from pydantic import BaseModel
from core.app.app_config.entities import FileExtraConfig
from core.file.tool_file_parser import ToolFileParser
from core.file.upload_file_parser import UploadFileParser
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from extensions.ext_database import db
from models.model import UploadFile
class FileExtraConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[dict[str, Any]] = None
class FileType(enum.Enum):
@ -114,6 +119,7 @@ class FileVar(BaseModel):
)
def _get_data(self, force_url: bool = False) -> Optional[str]:
from models.model import UploadFile
if self.type == FileType.IMAGE:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.url

View File

@ -5,8 +5,7 @@ from urllib.parse import parse_qs, urlparse
import requests
from core.app.app_config.entities import FileExtraConfig
from core.file.file_obj import FileBelongsTo, FileTransferMethod, FileType, FileVar
from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser, MessageFile, UploadFile

View File

@ -2,7 +2,6 @@ import base64
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
def obfuscated_token(token: str):
@ -14,6 +13,7 @@ def obfuscated_token(token: str):
def encrypt_token(tenant_id: str, token: str):
from models.account import Tenant
if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
raise ValueError(f'Tenant with id {tenant_id} not found')
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)

View File

@ -7,6 +7,7 @@ description:
supported_model_types:
- llm
- text-embedding
- speech2text
configurate_methods:
- customizable-model
model_credential_schema:
@ -61,6 +62,22 @@ model_credential_schema:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
show_on:
- variable: __model_type
value: llm
type: text-input
default: '4096'
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
show_on:
- variable: __model_type
value: text-embedding
type: text-input
default: '4096'
placeholder:

View File

@ -0,0 +1,63 @@
from typing import IO, Optional
from urllib.parse import urljoin
import requests
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel):
"""
Model class for OpenAI Compatible Speech to text model.
"""
def _invoke(
self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None
) -> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
headers = {}
api_key = credentials.get("api_key")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials.get("endpoint_url")
if not endpoint_url.endswith("/"):
endpoint_url += "/"
endpoint_url = urljoin(endpoint_url, "audio/transcriptions")
payload = {"model": model}
files = [("file", file)]
response = requests.post(endpoint_url, headers=headers, data=payload, files=files)
if response.status_code != 200:
raise InvokeBadRequestError(response.text)
response_data = response.json()
return response_data["text"]
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
audio_file_path = self._get_demo_file_path()
with open(audio_file_path, "rb") as audio_file:
self._invoke(model, credentials, audio_file)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
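A hedged sketch of driving the new OpenAI-compatible speech2text model directly. The import path appears in the Siliconflow subclass later in this compare, and the credential keys (endpoint_url, api_key) are the ones read in _invoke above; the endpoint, model name, and audio file are placeholders, and in the application these classes are normally obtained through the model runtime rather than constructed by hand:

from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import (
    OAICompatSpeech2TextModel,
)

model = OAICompatSpeech2TextModel()
credentials = {
    "endpoint_url": "https://example.com/v1/",  # placeholder OpenAI-compatible server
    "api_key": "sk-...",                        # optional; sent as a Bearer token when present
}

with open("sample.mp3", "rb") as audio_file:    # placeholder audio file
    text = model._invoke("whisper-1", credentials, audio_file)  # calling the private method for illustration
print(text)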

View File

@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
logger = logging.getLogger(__name__)
class SiliconflowProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:

View File

@ -16,6 +16,7 @@ help:
supported_model_types:
- llm
- text-embedding
- speech2text
configurate_methods:
- predefined-model
provider_credential_schema:

View File

@ -0,0 +1,5 @@
model: iic/SenseVoiceSmall
model_type: speech2text
model_properties:
file_upload_limit: 1
supported_file_extensions: mp3,wav

View File

@ -0,0 +1,32 @@
from typing import IO, Optional
from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import OAICompatSpeech2TextModel
class SiliconflowSpeech2TextModel(OAICompatSpeech2TextModel):
"""
Model class for Siliconflow Speech to text model.
"""
def _invoke(
self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None
) -> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
self._add_custom_parameters(credentials)
return super()._invoke(model, credentials, file)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
return super().validate_credentials(model, credentials)
@classmethod
def _add_custom_parameters(cls, credentials: dict) -> None:
credentials["endpoint_url"] = "https://api.siliconflow.cn/v1"

View File

@ -0,0 +1,5 @@
model: netease-youdao/bce-embedding-base_v1
model_type: text-embedding
model_properties:
context_size: 512
max_chunks: 1

View File

@ -0,0 +1,5 @@
model: BAAI/bge-m3
model_type: text-embedding
model_properties:
context_size: 8192
max_chunks: 1

View File

@ -0,0 +1,81 @@
model: farui-plus
label:
en_US: farui-plus
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 12288
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 2000
min: 1
max: 2000
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: seed
required: false
type: int
default: 1234
label:
zh_Hans: 随机种子
en_US: Random seed
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
- name: enable_search
type: boolean
default: false
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
- name: response_format
use_template: response_format
pricing:
input: '0.02'
output: '0.02'
unit: '0.001'
currency: RMB

View File

@ -2,3 +2,8 @@ model: text-embedding-v1
model_type: text-embedding
model_properties:
context_size: 2048
max_chunks: 25
pricing:
input: "0.0007"
unit: "0.001"
currency: RMB

View File

@ -2,3 +2,8 @@ model: text-embedding-v2
model_type: text-embedding
model_properties:
context_size: 2048
max_chunks: 25
pricing:
input: "0.0007"
unit: "0.001"
currency: RMB

View File

@ -2,6 +2,7 @@ import time
from typing import Optional
import dashscope
import numpy as np
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import (
@ -21,11 +22,11 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
"""
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@ -37,16 +38,44 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
:return: embeddings result
"""
credentials_kwargs = self._to_credential_kwargs(credentials)
embeddings, embedding_used_tokens = self.embed_documents(
credentials_kwargs=credentials_kwargs,
model=model,
texts=texts
)
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
inputs = []
indices = []
used_tokens = 0
for i, text in enumerate(texts):
# Here token count is only an approximation based on the GPT2 tokenizer
num_tokens = self._get_num_tokens_by_gpt2(text)
if num_tokens >= context_size:
cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
# if num tokens is larger than context length, only use the start
inputs.append(text[0:cutoff])
else:
inputs.append(text)
indices += [i]
batched_embeddings = []
_iter = range(0, len(inputs), max_chunks)
for i in _iter:
embeddings_batch, embedding_used_tokens = self.embed_documents(
credentials_kwargs=credentials_kwargs,
model=model,
texts=inputs[i : i + max_chunks],
)
used_tokens += embedding_used_tokens
batched_embeddings += embeddings_batch
# calc usage
usage = self._calc_response_usage(
model=model, credentials=credentials, tokens=used_tokens
)
return TextEmbeddingResult(
embeddings=embeddings,
usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens),
model=model
embeddings=batched_embeddings, usage=usage, model=model
)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
@ -79,12 +108,16 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
credentials_kwargs = self._to_credential_kwargs(credentials)
# call embedding model
self.embed_documents(credentials_kwargs=credentials_kwargs, model=model, texts=["ping"])
self.embed_documents(
credentials_kwargs=credentials_kwargs, model=model, texts=["ping"]
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@staticmethod
def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> tuple[list[list[float]], int]:
def embed_documents(
credentials_kwargs: dict, model: str, texts: list[str]
) -> tuple[list[list[float]], int]:
"""Call out to Tongyi's embedding endpoint.
Args:
@ -102,7 +135,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
api_key=credentials_kwargs["dashscope_api_key"],
model=model,
input=text,
text_type="document"
text_type="document",
)
data = response.output["embeddings"][0]
embeddings.append(data["embedding"])
@ -111,7 +144,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
return [list(map(float, e)) for e in embeddings], embedding_used_tokens
def _calc_response_usage(
self, model: str, credentials: dict, tokens: int
self, model: str, credentials: dict, tokens: int
) -> EmbeddingUsage:
"""
Calculate response usage
@ -125,7 +158,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
tokens=tokens,
)
# transform usage
@ -136,7 +169,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage
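The batching added above truncates any input whose approximate GPT-2 token count reaches context_size and then embeds the inputs max_chunks at a time. A small worked example of the cutoff arithmetic (numbers are illustrative; text-embedding-v1/v2 declare context_size 2048 and max_chunks 25 in the YAML files above):

import numpy as np

context_size = 2048   # from the model YAML in this compare
num_tokens = 3000     # illustrative approximate GPT-2 token count of one input text
text_len = 9000       # illustrative character length of that text

if num_tokens >= context_size:
    cutoff = int(np.floor(text_len * (context_size / num_tokens)))
    print(cutoff)  # 6144: only the first 6144 characters are sent to the embedding endpoint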

View File

@ -0,0 +1,191 @@
import json
from typing import Any
import requests
from elasticsearch import Elasticsearch
from flask import current_app
from pydantic import BaseModel, model_validator
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from models.dataset import Dataset
class ElasticSearchConfig(BaseModel):
host: str
port: str
username: str
password: str
@model_validator(mode='before')
def validate_config(cls, values: dict) -> dict:
if not values['host']:
raise ValueError("config HOST is required")
if not values['port']:
raise ValueError("config PORT is required")
if not values['username']:
raise ValueError("config USERNAME is required")
if not values['password']:
raise ValueError("config PASSWORD is required")
return values
class ElasticSearchVector(BaseVector):
def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list):
super().__init__(index_name.lower())
self._client = self._init_client(config)
self._attributes = attributes
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
try:
client = Elasticsearch(
hosts=f'{config.host}:{config.port}',
basic_auth=(config.username, config.password),
request_timeout=100000,
retry_on_timeout=True,
max_retries=10000,
)
except requests.exceptions.ConnectionError:
raise ConnectionError("Vector database connection error")
return client
def get_type(self) -> str:
return 'elasticsearch'
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
if not self._client.indices.exists(index=self._collection_name):
dim = len(embeddings[0])
mapping = {
"properties": {
"text": {
"type": "text"
},
"vector": {
"type": "dense_vector",
"index": True,
"dims": dim,
"similarity": "l2_norm"
},
}
}
self._client.indices.create(index=self._collection_name, mappings=mapping)
added_ids = []
for i, text in enumerate(texts):
self._client.index(index=self._collection_name,
id=uuids[i],
document={
"text": text,
"vector": embeddings[i] if embeddings[i] else None,
"metadata": metadatas[i] if metadatas[i] else {},
})
added_ids.append(uuids[i])
self._client.indices.refresh(index=self._collection_name)
return uuids
def text_exists(self, id: str) -> bool:
return self._client.exists(index=self._collection_name, id=id).__bool__()
def delete_by_ids(self, ids: list[str]) -> None:
for id in ids:
self._client.delete(index=self._collection_name, id=id)
def delete_by_metadata_field(self, key: str, value: str) -> None:
query_str = {
'query': {
'match': {
f'metadata.{key}': f'{value}'
}
}
}
results = self._client.search(index=self._collection_name, body=query_str)
ids = [hit['_id'] for hit in results['hits']['hits']]
if ids:
self.delete_by_ids(ids)
def delete(self) -> None:
self._client.indices.delete(index=self._collection_name)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
query_str = {
"query": {
"script_score": {
"query": {
"match_all": {}
},
"script": {
"source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
"params": {
"query_vector": query_vector
}
}
}
}
}
results = self._client.search(index=self._collection_name, body=query_str)
docs_and_scores = []
for hit in results['hits']['hits']:
docs_and_scores.append(
(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']), hit['_score']))
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
if score > score_threshold:
doc.metadata['score'] = score
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {
"match": {
"text": query
}
}
results = self._client.search(index=self._collection_name, query=query_str)
docs = []
for hit in results['hits']['hits']:
docs.append(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']))
return docs
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
return self.add_texts(texts, embeddings, **kwargs)
class ElasticSearchVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
config = current_app.config
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get('ELASTICSEARCH_HOST'),
port=config.get('ELASTICSEARCH_PORT'),
username=config.get('ELASTICSEARCH_USERNAME'),
password=config.get('ELASTICSEARCH_PASSWORD'),
),
attributes=[]
)
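A hedged sketch of using the new vector store class outside the factory. The import path matches the one registered in the vector factory below, and the config field names match ElasticSearchConfig above; the index name, document, and embedding vector are placeholders, and running it needs a reachable Elasticsearch instance:

from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
    ElasticSearchConfig,
    ElasticSearchVector,
)
from core.rag.models.document import Document

vector = ElasticSearchVector(
    index_name="example_index",  # placeholder; the factory uses Dataset.gen_collection_name_by_id(...)
    config=ElasticSearchConfig(host="127.0.0.1", port="9200", username="elastic", password="elastic"),
    attributes=[],
)

docs = [Document(page_content="hello world", metadata={"doc_id": "1"})]
vector.create(docs, embeddings=[[0.1, 0.2, 0.3]])  # creates the index on first insert, then indexes the docs
hits = vector.search_by_full_text("hello")         # plain match query on the text field
print([d.page_content for d in hits])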

View File

@ -93,7 +93,7 @@ class MyScaleVector(BaseVector):
@staticmethod
def escape_str(value: Any) -> str:
return "".join(f"\\{c}" if c in ("\\", "'") else c for c in str(value))
return "".join(" " if c in ("\\", "'") else c for c in str(value))
def text_exists(self, id: str) -> bool:
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
@ -118,7 +118,7 @@ class MyScaleVector(BaseVector):
return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._search(f"TextSearch(text, '{query}')", SortOrder.DESC, **kwargs)
return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs)
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)

View File

@ -71,6 +71,9 @@ class Vector:
case VectorType.RELYT:
from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
return RelytVectorFactory
case VectorType.ELASTICSEARCH:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
return ElasticSearchVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
return TiDBVectorFactory

View File

@ -15,3 +15,4 @@ class VectorType(str, Enum):
OPENSEARCH = 'opensearch'
TENCENT = 'tencent'
ORACLE = 'oracle'
ELASTICSEARCH = 'elasticsearch'

View File

@ -46,7 +46,7 @@ class ToolProviderType(Enum):
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
class ApiProviderSchemaType(Enum):
"""
Enum class for api provider schema type.
@ -68,7 +68,7 @@ class ApiProviderSchemaType(Enum):
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
class ApiProviderAuthType(Enum):
"""
Enum class for api provider auth type.
@ -103,8 +103,8 @@ class ToolInvokeMessage(BaseModel):
"""
plain text, image url or link url
"""
message: Union[str, bytes, dict] = None
meta: dict[str, Any] = None
message: str | bytes | dict | None = None
meta: dict[str, Any] | None = None
save_as: str = ''
class ToolInvokeMessageBinary(BaseModel):
@ -154,8 +154,8 @@ class ToolParameter(BaseModel):
options: Optional[list[ToolParameterOption]] = None
@classmethod
def get_simple_instance(cls,
name: str, llm_description: str, type: ToolParameterType,
def get_simple_instance(cls,
name: str, llm_description: str, type: ToolParameterType,
required: bool, options: Optional[list[str]] = None) -> 'ToolParameter':
"""
get a simple tool parameter
@ -222,7 +222,7 @@ class ToolProviderCredentials(BaseModel):
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
@staticmethod
def default(value: str) -> str:
return ""
@ -290,7 +290,7 @@ class ToolRuntimeVariablePool(BaseModel):
'tenant_id': self.tenant_id,
'pool': [variable.model_dump() for variable in self.pool],
}
def set_text(self, tool_name: str, name: str, value: str) -> None:
"""
set a text variable
@ -301,7 +301,7 @@ class ToolRuntimeVariablePool(BaseModel):
variable = cast(ToolRuntimeTextVariable, variable)
variable.value = value
return
variable = ToolRuntimeTextVariable(
type=ToolRuntimeVariableType.TEXT,
name=name,
@ -334,7 +334,7 @@ class ToolRuntimeVariablePool(BaseModel):
variable = cast(ToolRuntimeImageVariable, variable)
variable.value = value
return
variable = ToolRuntimeImageVariable(
type=ToolRuntimeVariableType.IMAGE,
name=name,
@ -388,21 +388,21 @@ class ToolInvokeMeta(BaseModel):
Get an empty instance of ToolInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> 'ToolInvokeMeta':
"""
Get an instance of ToolInvokeMeta with error
"""
return cls(time_cost=0.0, error=error, tool_config={})
def to_dict(self) -> dict:
return {
'time_cost': self.time_cost,
'error': self.error,
'tool_config': self.tool_config,
}
class ToolLabel(BaseModel):
"""
Tool label
@ -416,4 +416,4 @@ class ToolInvokeFrom(Enum):
Enum class for tool invoke
"""
WORKFLOW = "workflow"
AGENT = "agent"
AGENT = "agent"

View File

@ -0,0 +1,2 @@
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
<svg width="24" height="25" viewBox="0 0 24 25" xmlns="http://www.w3.org/2000/svg" fill="none"><path fill="#FC6D26" d="M14.975 8.904L14.19 6.55l-1.552-4.67a.268.268 0 00-.255-.18.268.268 0 00-.254.18l-1.552 4.667H5.422L3.87 1.879a.267.267 0 00-.254-.179.267.267 0 00-.254.18l-1.55 4.667-.784 2.357a.515.515 0 00.193.583l6.78 4.812 6.778-4.812a.516.516 0 00.196-.583z"/><path fill="#E24329" d="M8 14.296l2.578-7.75H5.423L8 14.296z"/><path fill="#FC6D26" d="M8 14.296l-2.579-7.75H1.813L8 14.296z"/><path fill="#FCA326" d="M1.81 6.549l-.784 2.354a.515.515 0 00.193.583L8 14.3 1.81 6.55z"/><path fill="#E24329" d="M1.812 6.549h3.612L3.87 1.882a.268.268 0 00-.254-.18.268.268 0 00-.255.18L1.812 6.549z"/><path fill="#FC6D26" d="M8 14.296l2.578-7.75h3.614L8 14.296z"/><path fill="#FCA326" d="M14.19 6.549l.783 2.354a.514.514 0 01-.193.583L8 14.296l6.188-7.747h.001z"/><path fill="#E24329" d="M14.19 6.549H10.58l1.551-4.667a.267.267 0 01.255-.18c.115 0 .217.073.254.18l1.552 4.667z"/></svg>


View File

@ -0,0 +1,34 @@
from typing import Any
import requests
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class GitlabProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
if 'access_tokens' not in credentials or not credentials.get('access_tokens'):
raise ToolProviderCredentialValidationError("Gitlab Access Tokens is required.")
if 'site_url' not in credentials or not credentials.get('site_url'):
site_url = 'https://gitlab.com'
else:
site_url = credentials.get('site_url')
try:
headers = {
"Content-Type": "application/vnd.text+json",
"Authorization": f"Bearer {credentials.get('access_tokens')}",
}
response = requests.get(
url= f"{site_url}/api/v4/user",
headers=headers)
if response.status_code != 200:
raise ToolProviderCredentialValidationError((response.json()).get('message'))
except Exception as e:
raise ToolProviderCredentialValidationError("Gitlab Access Tokens and Api Version is invalid. {}".format(e))
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,38 @@
identity:
author: Leo.Wang
name: gitlab
label:
en_US: Gitlab
zh_Hans: Gitlab
description:
en_US: Gitlab plugin for commits
zh_Hans: 用于获取Gitlab commit的插件
icon: gitlab.svg
credentials_for_provider:
access_tokens:
type: secret-input
required: true
label:
en_US: Gitlab access token
zh_Hans: Gitlab access token
placeholder:
en_US: Please input your Gitlab access token
zh_Hans: 请输入你的 Gitlab access token
help:
en_US: Get your Gitlab access token from Gitlab
zh_Hans: 从 Gitlab 获取您的 access token
url: https://docs.gitlab.com/16.9/ee/api/oauth2.html
site_url:
type: text-input
required: false
default: 'https://gitlab.com'
label:
en_US: Gitlab site url
zh_Hans: Gitlab site url
placeholder:
en_US: Please input your Gitlab site url
zh_Hans: 请输入你的 Gitlab site url
help:
en_US: Find your Gitlab url
zh_Hans: 找到你的Gitlab url
url: https://gitlab.com/help

View File

@ -0,0 +1,101 @@
import json
from datetime import datetime, timedelta
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class GitlabCommitsTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
project = tool_parameters.get('project', '')
employee = tool_parameters.get('employee', '')
start_time = tool_parameters.get('start_time', '')
end_time = tool_parameters.get('end_time', '')
if not project:
return self.create_text_message('Project is required')
if not start_time:
start_time = (datetime.utcnow() - timedelta(days=1)).isoformat()
if not end_time:
end_time = datetime.utcnow().isoformat()
access_token = self.runtime.credentials.get('access_tokens')
site_url = self.runtime.credentials.get('site_url')
if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
return self.create_text_message("Gitlab API Access Tokens is required.")
if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'):
site_url = 'https://gitlab.com'
# Get commit content
result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time)
return self.create_text_message(json.dumps(result, ensure_ascii=False))
def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '') -> list[dict[str, Any]]:
domain = site_url
headers = {"PRIVATE-TOKEN": access_token}
results = []
try:
# Get all projects
url = f"{domain}/api/v4/projects"
response = requests.get(url, headers=headers)
response.raise_for_status()
projects = response.json()
filtered_projects = [p for p in projects if project == "*" or p['name'] == project]
for project in filtered_projects:
project_id = project['id']
project_name = project['name']
print(f"Project: {project_name}")
# Get all project commits
commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
params = {
'since': start_time,
'until': end_time
}
if employee:
params['author'] = employee
commits_response = requests.get(commits_url, headers=headers, params=params)
commits_response.raise_for_status()
commits = commits_response.json()
for commit in commits:
commit_sha = commit['id']
print(f"\tCommit SHA: {commit_sha}")
diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
diff_response = requests.get(diff_url, headers=headers)
diff_response.raise_for_status()
diffs = diff_response.json()
for diff in diffs:
# Calculate the number of changed lines of code
added_lines = diff['diff'].count('\n+')
removed_lines = diff['diff'].count('\n-')
total_changes = added_lines + removed_lines
if total_changes > 1:
final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')])
results.append({
"project": project_name,
"commit_sha": commit_sha,
"diff": final_code
})
print(f"Commit code:{final_code}")
except requests.RequestException as e:
print(f"Error fetching data from GitLab: {e}")
return results
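As a rough standalone illustration of the GitLab REST endpoints this tool drives: the endpoints and PRIVATE-TOKEN header come from the code above, while the token, timestamps, and project index are placeholders.

# Illustrative only; assumes a valid personal access token.
import requests

domain = 'https://gitlab.com'
headers = {'PRIVATE-TOKEN': '<access-token>'}  # placeholder token

# List projects visible to the token
projects = requests.get(f'{domain}/api/v4/projects', headers=headers).json()

# List commits of one project within an ISO 8601 time window
project_id = projects[0]['id']
commits = requests.get(
    f'{domain}/api/v4/projects/{project_id}/repository/commits',
    headers=headers,
    params={'since': '2024-08-12T00:00:00', 'until': '2024-08-13T00:00:00'},
).json()

# Fetch the diff of the first commit
diff = requests.get(
    f"{domain}/api/v4/projects/{project_id}/repository/commits/{commits[0]['id']}/diff",
    headers=headers,
).json()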

View File

@ -0,0 +1,56 @@
identity:
name: gitlab_commits
author: Leo.Wang
label:
en_US: Gitlab Commits
zh_Hans: Gitlab代码提交内容
description:
human:
en_US: A tool for querying gitlab commits. Input should be an existing username.
zh_Hans: 一个用于查询gitlab代码提交记录的的工具输入的内容应该是一个已存在的用户名或者项目名。
llm: A tool for querying gitlab commits. Input should be an existing username or project.
parameters:
- name: employee
type: string
required: false
label:
en_US: employee
zh_Hans: 员工用户名
human_description:
en_US: employee
zh_Hans: 员工用户名
llm_description: employee for gitlab
form: llm
- name: project
type: string
required: true
label:
en_US: project
zh_Hans: 项目名
human_description:
en_US: project
zh_Hans: 项目名
llm_description: project for gitlab
form: llm
- name: start_time
type: string
required: false
label:
en_US: start_time
zh_Hans: 开始时间
human_description:
en_US: start_time
zh_Hans: 开始时间
llm_description: start_time for gitlab
form: llm
- name: end_time
type: string
required: false
label:
en_US: end_time
zh_Hans: 结束时间
human_description:
en_US: end_time
zh_Hans: 结束时间
llm_description: end_time for gitlab
form: llm

File diff suppressed because it is too large

View File

@ -0,0 +1,54 @@
[uwsgi]
# Who will run the code
uid = searxng
gid = searxng
# Number of workers (usually CPU count)
# default value: %k (= number of CPU core, see Dockerfile)
workers = %k
# Number of threads per worker
# default value: 4 (see Dockerfile)
threads = 4
# The right granted on the created socket
chmod-socket = 666
# Plugin to use and interpreter config
single-interpreter = true
master = true
plugin = python3
lazy-apps = true
enable-threads = 4
# Module to import
module = searx.webapp
# Virtualenv and python path
pythonpath = /usr/local/searxng/
chdir = /usr/local/searxng/searx/
# automatically set processes name to something meaningful
auto-procname = true
# Disable request logging for privacy
disable-logging = true
log-5xx = true
# Set the max size of a request (request-body excluded)
buffer-size = 8192
# No keep alive
# See https://github.com/searx/searx-docker/issues/24
add-header = Connection: close
# Follow SIGTERM convention
# See https://github.com/searxng/searxng/issues/3427
die-on-term
# uwsgi serves the static files
static-map = /static=/usr/local/searxng/searx/static
# expires set to one day
static-expires = /* 86400
static-gzip-all = True
offload-threads = 4

View File

@ -17,8 +17,7 @@ class SearXNGProvider(BuiltinToolProviderController):
tool_parameters={
"query": "SearXNG",
"limit": 1,
"search_type": "page",
"result_type": "link"
"search_type": "general"
},
)
except Exception as e:

View File

@ -6,7 +6,7 @@ identity:
zh_Hans: SearXNG
description:
en_US: A free internet metasearch engine.
zh_Hans: 开源互联网元搜索引擎
zh_Hans: 开源免费的互联网元搜索引擎
icon: icon.svg
tags:
- search
@ -18,9 +18,6 @@ credentials_for_provider:
label:
en_US: SearXNG base URL
zh_Hans: SearXNG base URL
help:
en_US: Please input your SearXNG base URL
zh_Hans: 请输入您的 SearXNG base URL
placeholder:
en_US: Please input your SearXNG base URL
zh_Hans: 请输入您的 SearXNG base URL

View File

@ -1,4 +1,3 @@
import json
from typing import Any
import requests
@ -7,90 +6,11 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class SearXNGSearchResults(dict):
"""Wrapper for search results."""
def __init__(self, data: str):
super().__init__(json.loads(data))
self.__dict__ = self
@property
def results(self) -> Any:
return self.get("results", [])
class SearXNGSearchTool(BuiltinTool):
"""
Tool for performing a search using SearXNG engine.
"""
SEARCH_TYPE: dict[str, str] = {
"page": "general",
"news": "news",
"image": "images",
# "video": "videos",
# "file": "files"
}
LINK_FILED: dict[str, str] = {
"page": "url",
"news": "url",
"image": "img_src",
# "video": "iframe_src",
# "file": "magnetlink"
}
TEXT_FILED: dict[str, str] = {
"page": "content",
"news": "content",
"image": "img_src",
# "video": "iframe_src",
# "file": "magnetlink"
}
def _invoke_query(self, user_id: str, host: str, query: str, search_type: str, result_type: str, topK: int = 5) -> list[dict]:
"""Run query and return the results."""
search_type = search_type.lower()
if search_type not in self.SEARCH_TYPE.keys():
search_type= "page"
response = requests.get(host, params={
"q": query,
"format": "json",
"categories": self.SEARCH_TYPE[search_type]
})
if response.status_code != 200:
raise Exception(f'Error {response.status_code}: {response.text}')
search_results = SearXNGSearchResults(response.text).results[:topK]
if result_type == 'link':
results = []
if search_type == "page" or search_type == "news":
for r in search_results:
results.append(self.create_text_message(
text=f'{r["title"]}: {r.get(self.LINK_FILED[search_type], "")}'
))
elif search_type == "image":
for r in search_results:
results.append(self.create_image_message(
image=r.get(self.LINK_FILED[search_type], "")
))
else:
for r in search_results:
results.append(self.create_link_message(
link=r.get(self.LINK_FILED[search_type], "")
))
return results
else:
text = ''
for i, r in enumerate(search_results):
text += f'{i+1}: {r["title"]} - {r.get(self.TEXT_FILED[search_type], "")}\n'
return self.create_text_message(text=self.summary(user_id=user_id, content=text))
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""
Invoke the SearXNG search tool.
@ -103,23 +23,21 @@ class SearXNGSearchTool(BuiltinTool):
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation.
"""
host = self.runtime.credentials.get('searxng_base_url', None)
host = self.runtime.credentials.get('searxng_base_url')
if not host:
raise Exception('SearXNG api is required')
query = tool_parameters.get('query')
if not query:
return self.create_text_message('Please input query')
num_results = min(tool_parameters.get('num_results', 5), 20)
search_type = tool_parameters.get('search_type', 'page') or 'page'
result_type = tool_parameters.get('result_type', 'text') or 'text'
return self._invoke_query(
user_id=user_id,
host=host,
query=query,
search_type=search_type,
result_type=result_type,
topK=num_results
)
response = requests.get(host, params={
"q": tool_parameters.get('query'),
"format": "json",
"categories": tool_parameters.get('search_type', 'general')
})
if response.status_code != 200:
raise Exception(f'Error {response.status_code}: {response.text}')
res = response.json().get("results", [])
if not res:
return self.create_text_message(f"No results found, get response: {response.content}")
return [self.create_json_message(item) for item in res]
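The refactor reduces the tool to a single call against SearXNG's JSON API and returns each raw result as a JSON message. A standalone sketch of that request follows; the URL is a placeholder for a self-hosted instance's configured searxng_base_url.

# Illustrative only; point the URL at your own SearXNG instance.
import requests

searxng_base_url = 'http://localhost:8081/search'  # placeholder
response = requests.get(
    searxng_base_url,
    params={'q': 'Dify', 'format': 'json', 'categories': 'general'},
)
response.raise_for_status()
for item in response.json().get('results', [])[:5]:
    print(item.get('title'), item.get('url'))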

View File

@ -1,13 +1,13 @@
identity:
name: searxng_search
author: Tice
author: Junytang
label:
en_US: SearXNG Search
zh_Hans: SearXNG 搜索
description:
human:
en_US: Perform searches on SearXNG and get results.
zh_Hans: SearXNG 上进行搜索并获取结果。
en_US: SearXNG is a free internet metasearch engine which aggregates results from more than 70 search services.
zh_Hans: SearXNG 是一个免费的互联网元搜索引擎它从70多个不同的搜索服务中聚合搜索结果。
llm: Perform searches on SearXNG and get results.
parameters:
- name: query
@ -16,9 +16,6 @@ parameters:
label:
en_US: Query string
zh_Hans: 查询语句
human_description:
en_US: The search query.
zh_Hans: 搜索查询语句。
llm_description: Key words for searching
form: llm
- name: search_type
@ -27,63 +24,46 @@ parameters:
label:
en_US: search type
zh_Hans: 搜索类型
pt_BR: search type
human_description:
en_US: search type for page, news or image.
zh_Hans: 选择搜索的类型:网页,新闻,图片。
pt_BR: search type for page, news or image.
default: Page
default: general
options:
- value: Page
- value: general
label:
en_US: Page
zh_Hans: 网页
pt_BR: Page
- value: News
en_US: General
zh_Hans: 综合
- value: images
label:
en_US: Images
zh_Hans: 图片
- value: videos
label:
en_US: Videos
zh_Hans: 视频
- value: news
label:
en_US: News
zh_Hans: 新闻
pt_BR: News
- value: Image
- value: map
label:
en_US: Image
zh_Hans:
pt_BR: Image
form: form
- name: num_results
type: number
required: true
label:
en_US: Number of query results
zh_Hans: 返回查询数量
human_description:
en_US: The number of query results.
zh_Hans: 返回查询结果的数量。
form: form
default: 5
min: 1
max: 20
- name: result_type
type: select
required: true
label:
en_US: result type
zh_Hans: 结果类型
pt_BR: result type
human_description:
en_US: return a list of links or texts.
zh_Hans: 返回一个连接列表还是纯文本内容。
pt_BR: return a list of links or texts.
default: text
options:
- value: link
en_US: Map
zh_Hans:
- value: music
label:
en_US: Link
zh_Hans: 链接
pt_BR: Link
- value: text
en_US: Music
zh_Hans: 音乐
- value: it
label:
en_US: Text
zh_Hans: 文本
pt_BR: Text
en_US: It
zh_Hans: 信息技术
- value: science
label:
en_US: Science
zh_Hans: 科学
- value: files
label:
en_US: Files
zh_Hans: 文件
- value: social_media
label:
en_US: Social Media
zh_Hans: 社交媒体
form: form

View File

@ -23,10 +23,12 @@ class NodeType(Enum):
HTTP_REQUEST = 'http-request'
TOOL = 'tool'
VARIABLE_AGGREGATOR = 'variable-aggregator'
# TODO: merge this into VARIABLE_AGGREGATOR
VARIABLE_ASSIGNER = 'variable-assigner'
LOOP = 'loop'
ITERATION = 'iteration'
PARAMETER_EXTRACTOR = 'parameter-extractor'
CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
@classmethod
def value_of(cls, value: str) -> 'NodeType':

View File

@ -13,6 +13,7 @@ VariableValue = Union[str, int, float, dict, list, FileVar]
SYSTEM_VARIABLE_NODE_ID = 'sys'
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
CONVERSATION_VARIABLE_NODE_ID = 'conversation'
class VariablePool:
@ -21,6 +22,7 @@ class VariablePool:
system_variables: Mapping[SystemVariable, Any],
user_inputs: Mapping[str, Any],
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable] | None = None,
) -> None:
# system variables
# for example:
@ -44,9 +46,13 @@ class VariablePool:
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
# Add environment variables to the variable pool
for var in environment_variables or []:
for var in environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
# Add conversation variables to the variable pool
for var in conversation_variables or []:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""
Adds a variable to the variable pool.
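A short sketch of the extended constructor in use, mirroring the unit tests added later in this diff; the variable id and value are placeholders.

# Sketch only; the conversation variable below is a placeholder.
from core.app.segments import StringVariable
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool

pool = VariablePool(
    system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
    user_inputs={},
    environment_variables=[],
    conversation_variables=[
        StringVariable(id='00000000-0000-0000-0000-000000000000', name='greeting', value='hello'),
    ],
)
greeting = pool.get(['conversation', 'greeting'])
assert greeting is not None and greeting.value == 'hello'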

View File

@ -8,6 +8,7 @@ from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from models import WorkflowNodeExecutionStatus
class UserFrom(Enum):
@ -91,14 +92,19 @@ class BaseNode(ABC):
:param variable_pool: variable pool
:return:
"""
result = self._run(
variable_pool=variable_pool
)
try:
result = self._run(
variable_pool=variable_pool
)
self.node_run_result = result
return result
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
self.node_run_result = result
return result
def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None:
"""
Publish text chunk
:param text: chunk text

View File

@ -133,9 +133,6 @@ class HttpRequestNode(BaseNode):
"""
files = []
mimetype, file_binary = response.extract_file()
# if not image, return directly
if 'image' not in mimetype:
return files
if mimetype:
# extract filename from url

View File

@ -0,0 +1,109 @@
from collections.abc import Sequence
from enum import Enum
from typing import Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.segments import SegmentType, Variable, factory
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from extensions.ext_database import db
from models import ConversationVariable, WorkflowNodeExecutionStatus
class VariableAssignerNodeError(Exception):
pass
class WriteMode(str, Enum):
OVER_WRITE = 'over-write'
APPEND = 'append'
CLEAR = 'clear'
class VariableAssignerData(BaseNodeData):
title: str = 'Variable Assigner'
desc: Optional[str] = 'Assign a value to a variable'
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]
class VariableAssignerNode(BaseNode):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
data = cast(VariableAssignerData, self.node_data)
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = variable_pool.get(data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError('assigned variable not found')
match data.write_mode:
case WriteMode.OVER_WRITE:
income_value = variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_variable = original_variable.model_copy(update={'value': income_value.value})
case WriteMode.APPEND:
income_value = variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={'value': updated_value})
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
case _:
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
# Overwrite the variable.
variable_pool.add(data.assigned_variable_selector, updated_variable)
# Update conversation variable.
# TODO: Find a better way to use the database.
conversation_id = variable_pool.get(['sys', 'conversation_id'])
if not conversation_id:
raise VariableAssignerNodeError('conversation_id not found')
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
'value': income_value.to_object(),
},
)
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableAssignerNodeError('conversation variable not found in the database')
row.data = variable.model_dump_json()
session.commit()
def get_zero_value(t: SegmentType):
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return factory.build_segment([])
case SegmentType.OBJECT:
return factory.build_segment({})
case SegmentType.STRING:
return factory.build_segment('')
case SegmentType.NUMBER:
return factory.build_segment(0)
case _:
raise VariableAssignerNodeError(f'unsupported variable type: {t}')
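For reference, this is the node configuration shape the class consumes, taken from the unit tests later in this diff; the node id and selectors are placeholders, and write_mode accepts 'over-write', 'append', or 'clear'.

# Illustrative node config for VariableAssignerNode; selectors are placeholders.
node_config = {
    'id': 'node_id',
    'data': {
        'assigned_variable_selector': ['conversation', 'my_variable'],
        'write_mode': 'over-write',  # or 'append' / 'clear'
        'input_variable_selector': ['node_id', 'input_value'],
    },
}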

View File

@ -4,12 +4,11 @@ from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from configs import dify_config
from core.app.app_config.entities import FileExtraConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
from core.workflow.errors import WorkflowNodeRunFailedError
@ -30,6 +29,7 @@ from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
from core.workflow.nodes.variable_assigner import VariableAssignerNode
from extensions.ext_database import db
from models.workflow import (
Workflow,
@ -51,7 +51,8 @@ node_classes: Mapping[NodeType, type[BaseNode]] = {
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,
NodeType.ITERATION: IterationNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
}
logger = logging.getLogger(__name__)
@ -94,10 +95,9 @@ class WorkflowEngineManager:
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
user_inputs: Mapping[str, Any],
system_inputs: Mapping[SystemVariable, Any],
callbacks: Sequence[WorkflowCallback],
call_depth: int = 0
call_depth: int = 0,
variable_pool: VariablePool,
) -> None:
"""
:param workflow: Workflow instance
@ -122,12 +122,6 @@ class WorkflowEngineManager:
if not isinstance(graph.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
)
workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
if call_depth > workflow_call_max_depth:
@ -403,6 +397,7 @@ class WorkflowEngineManager:
system_variables={},
user_inputs={},
environment_variables=workflow.environment_variables,
conversation_variables=workflow.conversation_variables,
)
if node_cls is None:
@ -468,6 +463,7 @@ class WorkflowEngineManager:
system_variables={},
user_inputs={},
environment_variables=workflow.environment_variables,
conversation_variables=workflow.conversation_variables,
)
# variable selector to variable mapping

View File

@ -0,0 +1,21 @@
from flask_restful import fields
from libs.helper import TimestampField
conversation_variable_fields = {
'id': fields.String,
'name': fields.String,
'value_type': fields.String(attribute='value_type.value'),
'value': fields.String,
'description': fields.String,
'created_at': TimestampField,
'updated_at': TimestampField,
}
paginated_conversation_variable_fields = {
'page': fields.Integer,
'limit': fields.Integer,
'total': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_variable_fields), attribute='data'),
}

View File

@ -32,11 +32,12 @@ class EnvironmentVariableField(fields.Raw):
return value
environment_variable_fields = {
conversation_variable_fields = {
'id': fields.String,
'name': fields.String,
'value': fields.Raw,
'value_type': fields.String(attribute='value_type.value'),
'value': fields.Raw,
'description': fields.String,
}
workflow_fields = {
@ -50,4 +51,5 @@ workflow_fields = {
'updated_at': TimestampField,
'tool_published': fields.Boolean,
'environment_variables': fields.List(EnvironmentVariableField()),
'conversation_variables': fields.List(fields.Nested(conversation_variable_fields)),
}

View File

@ -2,10 +2,10 @@
from abc import abstractmethod
import requests
from api.models.source import DataSourceBearerBinding
from flask_login import current_user
from extensions.ext_database import db
from models.source import DataSourceBearerBinding
class BearerDataSource:

View File

@ -154,11 +154,11 @@ class NotionOAuth(OAuthDataSource):
for page_result in page_results:
page_id = page_result['id']
page_name = 'Untitled'
for key in ['Name', 'title', 'Title', 'Page']:
if key in page_result['properties']:
if len(page_result['properties'][key].get('title', [])) > 0:
page_name = page_result['properties'][key]['title'][0]['plain_text']
break
for key in page_result['properties']:
if 'title' in page_result['properties'][key] and page_result['properties'][key]['title']:
title_list = page_result['properties'][key]['title']
if len(title_list) > 0 and 'plain_text' in title_list[0]:
page_name = title_list[0]['plain_text']
page_icon = page_result['icon']
if page_icon:
icon_type = page_icon['type']

View File

@ -0,0 +1,51 @@
"""support conversation variables
Revision ID: 63a83fcf12ba
Revises: 1787fbae959a
Create Date: 2024-08-13 06:33:07.950379
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = '63a83fcf12ba'
down_revision = '1787fbae959a'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('workflow__conversation_variables',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('conversation_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('data', sa.Text(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', 'conversation_id', name=op.f('workflow__conversation_variables_pkey'))
)
with op.batch_alter_table('workflow__conversation_variables', schema=None) as batch_op:
batch_op.create_index(batch_op.f('workflow__conversation_variables_app_id_idx'), ['app_id'], unique=False)
batch_op.create_index(batch_op.f('workflow__conversation_variables_created_at_idx'), ['created_at'], unique=False)
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.add_column(sa.Column('conversation_variables', sa.Text(), server_default='{}', nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.drop_column('conversation_variables')
with op.batch_alter_table('workflow__conversation_variables', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('workflow__conversation_variables_created_at_idx'))
batch_op.drop_index(batch_op.f('workflow__conversation_variables_app_id_idx'))
op.drop_table('workflow__conversation_variables')
# ### end Alembic commands ###

View File

@ -1,15 +1,19 @@
from enum import Enum
from sqlalchemy import CHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID
from .model import AppMode
from .types import StringUUID
from .workflow import ConversationVariable, WorkflowNodeExecutionStatus
__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus']
class CreatedByRole(Enum):
"""
Enum class for createdByRole
"""
ACCOUNT = "account"
END_USER = "end_user"
ACCOUNT = 'account'
END_USER = 'end_user'
@classmethod
def value_of(cls, value: str) -> 'CreatedByRole':
@ -23,49 +27,3 @@ class CreatedByRole(Enum):
if role.value == value:
return role
raise ValueError(f'invalid createdByRole value {value}')
class CreatedFrom(Enum):
"""
Enum class for createdFrom
"""
SERVICE_API = "service-api"
WEB_APP = "web-app"
EXPLORE = "explore"
@classmethod
def value_of(cls, value: str) -> 'CreatedFrom':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for role in cls:
if role.value == value:
return role
raise ValueError(f'invalid createdFrom value {value}')
class StringUUID(TypeDecorator):
impl = CHAR
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return value
elif dialect.name == 'postgresql':
return str(value)
else:
return value.hex
def load_dialect_impl(self, dialect):
if dialect.name == 'postgresql':
return dialect.type_descriptor(UUID())
else:
return dialect.type_descriptor(CHAR(36))
def process_result_value(self, value, dialect):
if value is None:
return value
return str(value)

View File

@ -4,7 +4,8 @@ import json
from flask_login import UserMixin
from extensions.ext_database import db
from models import StringUUID
from .types import StringUUID
class AccountStatus(str, enum.Enum):

View File

@ -1,7 +1,8 @@
import enum
from extensions.ext_database import db
from models import StringUUID
from .types import StringUUID
class APIBasedExtensionPoint(enum.Enum):

View File

@ -16,9 +16,10 @@ from configs import dify_config
from core.rag.retrieval.retrival_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import StringUUID
from models.account import Account
from models.model import App, Tag, TagBinding, UploadFile
from .account import Account
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID
class Dataset(db.Model):

View File

@ -14,8 +14,8 @@ from core.file.upload_file_parser import UploadFileParser
from extensions.ext_database import db
from libs.helper import generate_string
from . import StringUUID
from .account import Account, Tenant
from .types import StringUUID
class DifySetup(db.Model):

View File

@ -1,7 +1,8 @@
from enum import Enum
from extensions.ext_database import db
from models import StringUUID
from .types import StringUUID
class ProviderType(Enum):

View File

@ -3,7 +3,8 @@ import json
from sqlalchemy.dialects.postgresql import JSONB
from extensions.ext_database import db
from models import StringUUID
from .types import StringUUID
class DataSourceOauthBinding(db.Model):

View File

@ -2,7 +2,8 @@ import json
from enum import Enum
from extensions.ext_database import db
from models import StringUUID
from .types import StringUUID
class ToolProviderName(Enum):

View File

@ -6,8 +6,9 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from extensions.ext_database import db
from models import StringUUID
from models.model import Account, App, Tenant
from .model import Account, App, Tenant
from .types import StringUUID
class BuiltinToolProvider(db.Model):

api/models/types.py
View File

@ -0,0 +1,26 @@
from sqlalchemy import CHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID
class StringUUID(TypeDecorator):
impl = CHAR
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return value
elif dialect.name == 'postgresql':
return str(value)
else:
return value.hex
def load_dialect_impl(self, dialect):
if dialect.name == 'postgresql':
return dialect.type_descriptor(UUID())
else:
return dialect.type_descriptor(CHAR(36))
def process_result_value(self, value, dialect):
if value is None:
return value
return str(value)
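For illustration, a column declared with this type, following the same pattern as the models touched in this changeset; the model and table names are hypothetical.

# Hypothetical model; only the StringUUID usage is the point here.
from extensions.ext_database import db
from models.types import StringUUID


class ExampleRecord(db.Model):
    __tablename__ = 'example_records'
    id = db.Column(StringUUID, primary_key=True)
    name = db.Column(db.String(255), nullable=False)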

View File

@ -1,7 +1,8 @@
from extensions.ext_database import db
from models import StringUUID
from models.model import Message
from .model import Message
from .types import StringUUID
class SavedMessage(db.Model):

View File

@ -3,18 +3,18 @@ from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional, Union
from sqlalchemy import func
from sqlalchemy.orm import Mapped
import contexts
from constants import HIDDEN_VALUE
from core.app.segments import (
SecretVariable,
Variable,
factory,
)
from core.app.segments import SecretVariable, Variable, factory
from core.helper import encrypter
from extensions.ext_database import db
from libs import helper
from models import StringUUID
from models.account import Account
from .account import Account
from .types import StringUUID
class CreatedByRole(Enum):
@ -122,6 +122,7 @@ class Workflow(db.Model):
updated_by = db.Column(StringUUID)
updated_at = db.Column(db.DateTime)
_environment_variables = db.Column('environment_variables', db.Text, nullable=False, server_default='{}')
_conversation_variables = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}')
@property
def created_by_account(self):
@ -249,9 +250,27 @@ class Workflow(db.Model):
'graph': self.graph_dict,
'features': self.features_dict,
'environment_variables': [var.model_dump(mode='json') for var in environment_variables],
'conversation_variables': [var.model_dump(mode='json') for var in self.conversation_variables],
}
return result
@property
def conversation_variables(self) -> Sequence[Variable]:
# TODO: find some way to init `self._conversation_variables` when the instance is created.
if self._conversation_variables is None:
self._conversation_variables = '{}'
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()]
return results
@conversation_variables.setter
def conversation_variables(self, value: Sequence[Variable]) -> None:
self._conversation_variables = json.dumps(
{var.name: var.model_dump() for var in value},
ensure_ascii=False,
)
class WorkflowRunTriggeredFrom(Enum):
"""
@ -702,3 +721,34 @@ class WorkflowAppLog(db.Model):
created_by_role = CreatedByRole.value_of(self.created_by_role)
return db.session.get(EndUser, self.created_by) \
if created_by_role == CreatedByRole.END_USER else None
class ConversationVariable(db.Model):
__tablename__ = 'workflow__conversation_variables'
id: Mapped[str] = db.Column(StringUUID, primary_key=True)
conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True)
app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True)
data = db.Column(db.Text, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp())
def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None:
self.id = id
self.app_id = app_id
self.conversation_id = conversation_id
self.data = data
@classmethod
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> 'ConversationVariable':
obj = cls(
id=variable.id,
app_id=app_id,
conversation_id=conversation_id,
data=variable.model_dump_json(),
)
return obj
def to_variable(self) -> Variable:
mapping = json.loads(self.data)
return factory.build_variable_from_mapping(mapping)
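A brief sketch of the round trip between a Variable and its database row via the two helpers above; the ids and values are placeholders.

# Sketch only; ids and values below are placeholders.
from core.app.segments import StringVariable
from models.workflow import ConversationVariable

var = StringVariable(id='00000000-0000-0000-0000-000000000000', name='greeting', value='hello')
row = ConversationVariable.from_variable(app_id='app_id', conversation_id='conversation_id', variable=var)
assert row.to_variable().value == 'hello'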

api/poetry.lock
View File

@ -2100,6 +2100,44 @@ primp = ">=0.5.5"
dev = ["mypy (>=1.11.0)", "pytest (>=8.3.1)", "pytest-asyncio (>=0.23.8)", "ruff (>=0.5.5)"]
lxml = ["lxml (>=5.2.2)"]
[[package]]
name = "elastic-transport"
version = "8.15.0"
description = "Transport classes and utilities shared among Python Elastic client libraries"
optional = false
python-versions = ">=3.8"
files = [
{file = "elastic_transport-8.15.0-py3-none-any.whl", hash = "sha256:d7080d1dada2b4eee69e7574f9c17a76b42f2895eff428e562f94b0360e158c0"},
{file = "elastic_transport-8.15.0.tar.gz", hash = "sha256:85d62558f9baafb0868c801233a59b235e61d7b4804c28c2fadaa866b6766233"},
]
[package.dependencies]
certifi = "*"
urllib3 = ">=1.26.2,<3"
[package.extras]
develop = ["aiohttp", "furo", "httpx", "opentelemetry-api", "opentelemetry-sdk", "orjson", "pytest", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "pytest-mock", "requests", "respx", "sphinx (>2)", "sphinx-autodoc-typehints", "trustme"]
[[package]]
name = "elasticsearch"
version = "8.14.0"
description = "Python client for Elasticsearch"
optional = false
python-versions = ">=3.7"
files = [
{file = "elasticsearch-8.14.0-py3-none-any.whl", hash = "sha256:cef8ef70a81af027f3da74a4f7d9296b390c636903088439087b8262a468c130"},
{file = "elasticsearch-8.14.0.tar.gz", hash = "sha256:aa2490029dd96f4015b333c1827aa21fd6c0a4d223b00dfb0fe933b8d09a511b"},
]
[package.dependencies]
elastic-transport = ">=8.13,<9"
[package.extras]
async = ["aiohttp (>=3,<4)"]
orjson = ["orjson (>=3)"]
requests = ["requests (>=2.4.0,!=2.32.2,<3.0.0)"]
vectorstore-mmr = ["numpy (>=1)", "simsimd (>=3)"]
[[package]]
name = "emoji"
version = "2.12.1"
@ -9546,4 +9584,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "2b822039247a445f72e04e967aef84f841781e2789b70071acad022f36ba26a5"
content-hash = "05dfa6b9bce9ed8ac21caf58eff1596f146080ab2ab6987924b189be673c22cf"

View File

@ -181,6 +181,7 @@ zhipuai = "1.0.7"
rank-bm25 = "~0.2.2"
openpyxl = "^3.1.5"
kaleido = "0.2.1"
elasticsearch = "8.14.0"
############################################################
# Tool dependencies required by tool implementations

View File

@ -13,9 +13,9 @@ from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
current_dsl_version = "0.1.0"
current_dsl_version = "0.1.1"
dsl_to_dify_version_mapping: dict[str, str] = {
"0.1.0": "0.6.0", # dsl version -> from dify version
"0.1.1": "0.6.0", # dsl version -> from dify version
}
@ -238,6 +238,8 @@ class AppDslService:
# init draft workflow
environment_variables_list = workflow_data.get('environment_variables') or []
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
conversation_variables_list = workflow_data.get('conversation_variables') or []
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
workflow_service = WorkflowService()
draft_workflow = workflow_service.sync_draft_workflow(
app_model=app,
@ -246,6 +248,7 @@ class AppDslService:
unique_hash=None,
account=account,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
workflow_service.publish_workflow(
app_model=app,

View File

@ -6,7 +6,6 @@ from core.app.app_config.entities import (
DatasetRetrieveConfigEntity,
EasyUIBasedAppConfig,
ExternalDataVariableEntity,
FileExtraConfig,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
@ -14,6 +13,7 @@ from core.app.app_config.entities import (
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.file.file_obj import FileExtraConfig
from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.utils.encoders import jsonable_encoder

View File

@ -72,6 +72,7 @@ class WorkflowService:
unique_hash: Optional[str],
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
) -> Workflow:
"""
Sync draft workflow
@ -99,7 +100,8 @@ class WorkflowService:
graph=json.dumps(graph),
features=json.dumps(features),
created_by=account.id,
environment_variables=environment_variables
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
db.session.add(workflow)
# update draft workflow if found
@ -109,6 +111,7 @@ class WorkflowService:
workflow.updated_by = account.id
workflow.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables
# commit db session changes
db.session.commit()
@ -145,7 +148,8 @@ class WorkflowService:
graph=draft_workflow.graph,
features=draft_workflow.features,
created_by=account.id,
environment_variables=draft_workflow.environment_variables
environment_variables=draft_workflow.environment_variables,
conversation_variables=draft_workflow.conversation_variables,
)
# commit db session changes
@ -336,8 +340,8 @@ class WorkflowService:
)
if not workflow_nodes:
return elapsed_time
for node in workflow_nodes:
elapsed_time += node.elapsed_time
return elapsed_time
return elapsed_time

View File

@ -1,8 +1,10 @@
import logging
import time
from collections.abc import Callable
import click
from celery import shared_task
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
from extensions.ext_database import db
@ -28,7 +30,7 @@ from models.model import (
)
from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage
from models.workflow import Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun
@shared_task(queue='app_deletion', bind=True, max_retries=3)
@ -54,6 +56,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
_delete_app_tag_bindings(tenant_id, app_id)
_delete_end_users(tenant_id, app_id)
_delete_trace_app_configs(tenant_id, app_id)
_delete_conversation_variables(app_id=app_id)
end_at = time.perf_counter()
logging.info(click.style(f'App and related data deleted: {app_id} latency: {end_at - start_at}', fg='green'))
@ -225,6 +228,13 @@ def _delete_app_conversations(tenant_id: str, app_id: str):
"conversation"
)
def _delete_conversation_variables(*, app_id: str):
stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
with db.engine.connect() as conn:
conn.execute(stmt)
conn.commit()
logging.info(click.style(f"Deleted conversation variables for app {app_id}", fg='green'))
def _delete_app_messages(tenant_id: str, app_id: str):
def del_message(message_id: str):
@ -299,7 +309,7 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str):
)
def _delete_records(query_sql: str, params: dict, delete_func: callable, name: str) -> None:
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
while True:
with db.engine.begin() as conn:
rs = conn.execute(db.text(query_sql), params)

View File

@ -1,5 +1,4 @@
from api.core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter
class MockTEIClass:
@ -12,7 +11,7 @@ class MockTEIClass:
model_type = 'embedding'
return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
@staticmethod
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
# Use space as token separator, and split the text into tokens

View File

@ -1,12 +1,12 @@
import os
import pytest
from api.core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import (
HuggingfaceTeiTextEmbeddingModel,
TeiHelper,
)
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass

View File

@ -0,0 +1,59 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import (
OAICompatSpeech2TextModel,
)
def test_validate_credentials():
model = OAICompatSpeech2TextModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="whisper-1",
credentials={
"api_key": "invalid_key",
"endpoint_url": "https://api.openai.com/v1/"
},
)
model.validate_credentials(
model="whisper-1",
credentials={
"api_key": os.environ.get("OPENAI_API_KEY"),
"endpoint_url": "https://api.openai.com/v1/"
},
)
def test_invoke_model():
model = OAICompatSpeech2TextModel()
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
# Construct the path to the audio file
audio_file_path = os.path.join(assets_dir, "audio.mp3")
# Open the file and get the file object
with open(audio_file_path, "rb") as audio_file:
file = audio_file
result = model.invoke(
model="whisper-1",
credentials={
"api_key": os.environ.get("OPENAI_API_KEY"),
"endpoint_url": "https://api.openai.com/v1/"
},
file=file,
user="abc-123",
)
assert isinstance(result, str)
assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'

View File

@ -0,0 +1,53 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.siliconflow.speech2text.speech2text import SiliconflowSpeech2TextModel
def test_validate_credentials():
model = SiliconflowSpeech2TextModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="iic/SenseVoiceSmall",
credentials={
"api_key": "invalid_key"
},
)
model.validate_credentials(
model="iic/SenseVoiceSmall",
credentials={
"api_key": os.environ.get("API_KEY")
},
)
def test_invoke_model():
model = SiliconflowSpeech2TextModel()
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
# Construct the path to the audio file
audio_file_path = os.path.join(assets_dir, "audio.mp3")
# Open the file and get the file object
with open(audio_file_path, "rb") as audio_file:
file = audio_file
result = model.invoke(
model="iic/SenseVoiceSmall",
credentials={
"api_key": os.environ.get("API_KEY")
},
file=file
)
assert isinstance(result, str)
assert result == '1,2,3,4,5,6,7,8,9,10.'

View File

@ -0,0 +1,25 @@
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
setup_mock_redis,
)
class ElasticSearchVectorTest(AbstractVectorTest):
def __init__(self):
super().__init__()
self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self.vector = ElasticSearchVector(
index_name=self.collection_name.lower(),
config=ElasticSearchConfig(
host='http://localhost',
port='9200',
username='elastic',
password='elastic'
),
attributes=self.attributes
)
def test_elasticsearch_vector(setup_mock_redis):
ElasticSearchVectorTest().run_all_tests()

View File

@ -7,15 +7,16 @@ from core.app.segments import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileSegment,
FileVariable,
FloatVariable,
IntegerVariable,
NoneSegment,
ObjectSegment,
SecretVariable,
StringVariable,
factory,
)
from core.app.segments.exc import VariableError
def test_string_variable():
@ -44,7 +45,7 @@ def test_secret_variable():
def test_invalid_value_type():
test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'}
with pytest.raises(ValueError):
with pytest.raises(VariableError):
factory.build_variable_from_mapping(test_data)
@ -77,26 +78,14 @@ def test_object_variable():
'name': 'test_object',
'description': 'Description of the variable.',
'value': {
'key1': {
'id': str(uuid4()),
'value_type': 'string',
'name': 'text',
'value': 'text',
'description': 'Description of the variable.',
},
'key2': {
'id': str(uuid4()),
'value_type': 'number',
'name': 'number',
'value': 1,
'description': 'Description of the variable.',
},
'key1': 'text',
'key2': 2,
},
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ObjectSegment)
assert isinstance(variable.value['key1'], StringVariable)
assert isinstance(variable.value['key2'], IntegerVariable)
assert isinstance(variable.value['key1'], str)
assert isinstance(variable.value['key2'], int)
def test_array_string_variable():
@ -106,26 +95,14 @@ def test_array_string_variable():
'name': 'test_array',
'description': 'Description of the variable.',
'value': [
{
'id': str(uuid4()),
'value_type': 'string',
'name': 'text',
'value': 'text',
'description': 'Description of the variable.',
},
{
'id': str(uuid4()),
'value_type': 'string',
'name': 'text',
'value': 'text',
'description': 'Description of the variable.',
},
'text',
'text',
],
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ArrayStringVariable)
assert isinstance(variable.value[0], StringVariable)
assert isinstance(variable.value[1], StringVariable)
assert isinstance(variable.value[0], str)
assert isinstance(variable.value[1], str)
def test_array_number_variable():
@ -135,26 +112,14 @@ def test_array_number_variable():
'name': 'test_array',
'description': 'Description of the variable.',
'value': [
{
'id': str(uuid4()),
'value_type': 'number',
'name': 'number',
'value': 1,
'description': 'Description of the variable.',
},
{
'id': str(uuid4()),
'value_type': 'number',
'name': 'number',
'value': 2.0,
'description': 'Description of the variable.',
},
1,
2.0,
],
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ArrayNumberVariable)
assert isinstance(variable.value[0], IntegerVariable)
assert isinstance(variable.value[1], FloatVariable)
assert isinstance(variable.value[0], int)
assert isinstance(variable.value[1], float)
def test_array_object_variable():
@ -165,59 +130,23 @@ def test_array_object_variable():
'description': 'Description of the variable.',
'value': [
{
'id': str(uuid4()),
'value_type': 'object',
'name': 'object',
'description': 'Description of the variable.',
'value': {
'key1': {
'id': str(uuid4()),
'value_type': 'string',
'name': 'text',
'value': 'text',
'description': 'Description of the variable.',
},
'key2': {
'id': str(uuid4()),
'value_type': 'number',
'name': 'number',
'value': 1,
'description': 'Description of the variable.',
},
},
'key1': 'text',
'key2': 1,
},
{
'id': str(uuid4()),
'value_type': 'object',
'name': 'object',
'description': 'Description of the variable.',
'value': {
'key1': {
'id': str(uuid4()),
'value_type': 'string',
'name': 'text',
'value': 'text',
'description': 'Description of the variable.',
},
'key2': {
'id': str(uuid4()),
'value_type': 'number',
'name': 'number',
'value': 1,
'description': 'Description of the variable.',
},
},
'key1': 'text',
'key2': 1,
},
],
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ArrayObjectVariable)
assert isinstance(variable.value[0], ObjectSegment)
assert isinstance(variable.value[1], ObjectSegment)
assert isinstance(variable.value[0].value['key1'], StringVariable)
assert isinstance(variable.value[0].value['key2'], IntegerVariable)
assert isinstance(variable.value[1].value['key1'], StringVariable)
assert isinstance(variable.value[1].value['key2'], IntegerVariable)
assert isinstance(variable.value[0], dict)
assert isinstance(variable.value[1], dict)
assert isinstance(variable.value[0]['key1'], str)
assert isinstance(variable.value[0]['key2'], int)
assert isinstance(variable.value[1]['key1'], str)
assert isinstance(variable.value[1]['key2'], int)
def test_file_variable():
@ -257,51 +186,53 @@ def test_array_file_variable():
'value': [
{
'id': str(uuid4()),
'name': 'file',
'value_type': 'file',
'value': {
'id': str(uuid4()),
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
{
'id': str(uuid4()),
'name': 'file',
'value_type': 'file',
'value': {
'id': str(uuid4()),
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
],
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ArrayFileVariable)
assert isinstance(variable.value[0], FileVariable)
assert isinstance(variable.value[1], FileVariable)
assert isinstance(variable.value[0], FileSegment)
assert isinstance(variable.value[1], FileSegment)
def test_variable_cannot_large_than_5_kb():
with pytest.raises(VariableError):
factory.build_variable_from_mapping(
{
'id': str(uuid4()),
'value_type': 'string',
'name': 'test_text',
'value': 'a' * 1024 * 6,
}
)
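Taken together, the hunks above flatten the `value` mappings (plain JSON values instead of nested variable mappings) and assert on primitive or segment types rather than nested `*Variable` wrappers. As a quick orientation, here is a minimal sketch of the new mapping shape and the 5 KB guard; the `VariableError` import path and the `to_object()` call on an object variable are assumptions inferred from this diff, not verified against the codebase.
from uuid import uuid4

import pytest

from core.app.segments import factory
from core.app.segments.exc import VariableError  # import path assumed, not shown in this diff

def sketch_flat_mapping_and_size_limit():
    # New, flattened shape: the object value is plain JSON, no nested variable mappings.
    variable = factory.build_variable_from_mapping(
        {
            'id': str(uuid4()),
            'value_type': 'object',
            'name': 'profile',
            'description': 'Description of the variable.',
            'value': {'key1': 'text', 'key2': 1},
        }
    )
    assert variable.to_object() == {'key1': 'text', 'key2': 1}  # to_object() assumed for object values

    # Values larger than 5 KB are rejected, as the last test above asserts.
    with pytest.raises(VariableError):
        factory.build_variable_from_mapping(
            {
                'id': str(uuid4()),
                'value_type': 'string',
                'name': 'test_text',
                'value': 'a' * 1024 * 6,
            }
        )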

View File

@@ -2,8 +2,8 @@ from unittest.mock import MagicMock
import pytest
from core.app.app_config.entities import FileExtraConfig, ModelConfigEntity
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.app.app_config.entities import ModelConfigEntity
from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, UserPromptMessage
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform

View File

@@ -0,0 +1,150 @@
from unittest import mock
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.segments import ArrayStringVariable, StringVariable
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
DEFAULT_NODE_ID = 'node_id'
def test_overwrite_string_variable():
conversation_variable = StringVariable(
id=str(uuid4()),
name='test_conversation_variable',
value='the first value',
)
input_variable = StringVariable(
id=str(uuid4()),
name='test_string_variable',
value='the second value',
)
node = VariableAssignerNode(
tenant_id='tenant_id',
app_id='app_id',
workflow_id='workflow_id',
user_id='user_id',
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'node_id',
'data': {
'assigned_variable_selector': ['conversation', conversation_variable.name],
'write_mode': WriteMode.OVER_WRITE.value,
'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
},
},
)
variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
variable_pool.add(
[DEFAULT_NODE_ID, input_variable.name],
input_variable,
)
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
node.run(variable_pool)
mock_run.assert_called_once()
got = variable_pool.get(['conversation', conversation_variable.name])
assert got is not None
assert got.value == 'the second value'
assert got.to_object() == 'the second value'
def test_append_variable_to_array():
conversation_variable = ArrayStringVariable(
id=str(uuid4()),
name='test_conversation_variable',
value=['the first value'],
)
input_variable = StringVariable(
id=str(uuid4()),
name='test_string_variable',
value='the second value',
)
node = VariableAssignerNode(
tenant_id='tenant_id',
app_id='app_id',
workflow_id='workflow_id',
user_id='user_id',
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'node_id',
'data': {
'assigned_variable_selector': ['conversation', conversation_variable.name],
'write_mode': WriteMode.APPEND.value,
'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
},
},
)
variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
variable_pool.add(
[DEFAULT_NODE_ID, input_variable.name],
input_variable,
)
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
node.run(variable_pool)
mock_run.assert_called_once()
got = variable_pool.get(['conversation', conversation_variable.name])
assert got is not None
assert got.to_object() == ['the first value', 'the second value']
def test_clear_array():
conversation_variable = ArrayStringVariable(
id=str(uuid4()),
name='test_conversation_variable',
value=['the first value'],
)
node = VariableAssignerNode(
tenant_id='tenant_id',
app_id='app_id',
workflow_id='workflow_id',
user_id='user_id',
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'node_id',
'data': {
'assigned_variable_selector': ['conversation', conversation_variable.name],
'write_mode': WriteMode.CLEAR.value,
'input_variable_selector': [],
},
},
)
variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
node.run(variable_pool)
got = variable_pool.get(['conversation', conversation_variable.name])
assert got is not None
assert got.to_object() == []
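The three tests above cover the full set of write modes for the new variable assigner node. As a conceptual summary only (the enum member values below are assumptions, and the real node also persists the change via update_conversation_variable), the behavior they assert reduces to:
from enum import Enum

class WriteMode(str, Enum):
    # Member values are illustrative; only the member names appear in this diff.
    OVER_WRITE = 'over-write'
    APPEND = 'append'
    CLEAR = 'clear'

def assign(mode: WriteMode, current, incoming):
    # Overwrite replaces the stored value, append extends an array value,
    # and clear resets the conversation variable to an empty list.
    if mode is WriteMode.OVER_WRITE:
        return incoming
    if mode is WriteMode.APPEND:
        return [*current, incoming]
    return []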

View File

@@ -0,0 +1,25 @@
from uuid import uuid4
from core.app.segments import SegmentType, factory
from models import ConversationVariable
def test_from_variable_and_to_variable():
variable = factory.build_variable_from_mapping(
{
'id': str(uuid4()),
'name': 'name',
'value_type': SegmentType.OBJECT,
'value': {
'key': {
'key': 'value',
}
},
}
)
conversation_variable = ConversationVariable.from_variable(
app_id='app_id', conversation_id='conversation_id', variable=variable
)
assert conversation_variable.to_variable() == variable

View File

@@ -7,4 +7,5 @@ pytest api/tests/integration_tests/vdb/chroma \
api/tests/integration_tests/vdb/pgvector \
api/tests/integration_tests/vdb/qdrant \
api/tests/integration_tests/vdb/weaviate \
api/tests/integration_tests/vdb/elasticsearch \
api/tests/integration_tests/vdb/test_vector_store.py

View File

@@ -2,7 +2,7 @@ version: '3'
services:
# API service
api:
image: langgenius/dify-api:0.6.16
image: langgenius/dify-api:0.7.0
restart: always
environment:
# Startup mode, 'api' starts the API server.
@@ -169,6 +169,11 @@ services:
CHROMA_DATABASE: default_database
CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider
CHROMA_AUTH_CREDENTIALS: xxxxxx
# ElasticSearch Config
ELASTICSEARCH_HOST: 127.0.0.1
ELASTICSEARCH_PORT: 9200
ELASTICSEARCH_USERNAME: elastic
ELASTICSEARCH_PASSWORD: elastic
# Mail configuration, support: resend, smtp
MAIL_TYPE: ''
# default send from email address, if not specified
@@ -224,7 +229,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.6.16
image: langgenius/dify-api:0.7.0
restart: always
environment:
CONSOLE_WEB_URL: ''
@@ -371,6 +376,11 @@ services:
CHROMA_DATABASE: default_database
CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider
CHROMA_AUTH_CREDENTIALS: xxxxxx
# ElasticSearch Config
ELASTICSEARCH_HOST: 127.0.0.1
ELASTICSEARCH_PORT: 9200
ELASTICSEARCH_USERNAME: elastic
ELASTICSEARCH_PASSWORD: elastic
# Notion import configuration, support public and internal
NOTION_INTEGRATION_TYPE: public
NOTION_CLIENT_SECRET: you-client-secret
@@ -390,7 +400,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.6.16
image: langgenius/dify-web:0.7.0
restart: always
environment:
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is

View File

@@ -125,6 +125,10 @@ x-shared-env: &shared-api-worker-env
CHROMA_DATABASE: ${CHROMA_DATABASE:-default_database}
CHROMA_AUTH_PROVIDER: ${CHROMA_AUTH_PROVIDER:-chromadb.auth.token_authn.TokenAuthClientProvider}
CHROMA_AUTH_CREDENTIALS: ${CHROMA_AUTH_CREDENTIALS:-}
ELASTICSEARCH_HOST: ${ELASTICSEARCH_HOST:-127.0.0.1}
ELASTICSEARCH_PORT: ${ELASTICSEARCH_PORT:-9200}
ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic}
ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic}
# AnalyticDB configuration
ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-}
ANALYTICDB_KEY_SECRET: ${ANALYTICDB_KEY_SECRET:-}
@@ -187,7 +191,7 @@ services:
services:
# API service
api:
image: langgenius/dify-api:0.6.16
image: langgenius/dify-api:0.7.0
restart: always
environment:
# Use the shared environment variables.
@@ -207,7 +211,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.6.16
image: langgenius/dify-api:0.7.0
restart: always
environment:
# Use the shared environment variables.
@@ -226,12 +230,13 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.6.16
image: langgenius/dify-web:0.7.0
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
APP_API_URL: ${APP_API_URL:-}
SENTRY_DSN: ${WEB_SENTRY_DSN:-}
NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0}
# The postgres database.
db:
@@ -582,7 +587,7 @@ services:
# MyScale vector database
myscale:
container_name: myscale
image: myscale/myscaledb:1.6
image: myscale/myscaledb:1.6.4
profiles:
- myscale
restart: always
@@ -594,6 +599,27 @@ services:
ports:
- "${MYSCALE_PORT:-8123}:${MYSCALE_PORT:-8123}"
elasticsearch:
image: docker.elastic.co/elasticsearch/elasticsearch:8.14.3
container_name: elasticsearch
profiles:
- elasticsearch
restart: always
environment:
- "ELASTIC_PASSWORD=${ELASTICSEARCH_USERNAME:-elastic}"
- "cluster.name=dify-es-cluster"
- "node.name=dify-es0"
- "discovery.type=single-node"
- "xpack.security.http.ssl.enabled=false"
- "xpack.license.self_generated.type=trial"
ports:
- "${ELASTICSEARCH_PORT:-9200}:${ELASTICSEARCH_PORT:-9200}"
healthcheck:
test: ["CMD", "curl", "-s", "http://localhost:9200/_cluster/health?pretty"]
interval: 30s
timeout: 10s
retries: 50
# unstructured .
# (if used, you need to set ETL_TYPE to Unstructured in the api & worker service.)
unstructured:
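The ELASTICSEARCH_* defaults added to the shared environment above (host 127.0.0.1, port 9200, user and password both elastic) plus the new single-node elasticsearch service are everything a client needs to reach the vector store. A minimal connectivity check, sketched with the official elasticsearch Python client (illustrative only, not Dify's own vector-store code), could look like this:
import os

from elasticsearch import Elasticsearch  # assumes elasticsearch-py 8.x is installed

host = os.getenv('ELASTICSEARCH_HOST', '127.0.0.1')
port = os.getenv('ELASTICSEARCH_PORT', '9200')

client = Elasticsearch(
    f'http://{host}:{port}',
    basic_auth=(
        os.getenv('ELASTICSEARCH_USERNAME', 'elastic'),
        os.getenv('ELASTICSEARCH_PASSWORD', 'elastic'),
    ),
)

# Mirrors the compose healthcheck: the cluster should report a usable status.
print(client.cluster.health()['status'])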

View File

@@ -13,3 +13,6 @@ NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api
# SENTRY
NEXT_PUBLIC_SENTRY_DSN=
# Disable Next.js Telemetry (https://nextjs.org/telemetry)
NEXT_TELEMETRY_DISABLED=1

View File

@@ -39,6 +39,7 @@ ENV DEPLOY_ENV=PRODUCTION
ENV CONSOLE_API_URL=http://127.0.0.1:5001
ENV APP_API_URL=http://127.0.0.1:5001
ENV PORT=3000
ENV NEXT_TELEMETRY_DISABLED=1
# set timezone
ENV TZ=UTC

Some files were not shown because too many files have changed in this diff.