Compare commits

..

47 Commits

Author SHA1 Message Date
Yi 61b54c4b0e fix typo in README.md 2024-08-22 14:45:35 +08:00
fef4e09dfc docs: update certbot/README.md (#7528) 2024-08-22 13:36:15 +08:00
60ef7ba855 fix: add missed modifications of <AppIcon /> (#7512) 2024-08-22 13:32:59 +08:00
6f968bafb2 feat: update the "tag delete" confirm modal (#7522) 2024-08-22 11:33:20 +08:00
9f6aab11d4 fix: tag input state lost issue (#7500) 2024-08-22 10:26:09 +08:00
0006c6f0fd fix(storage): 🐛 Create S3 bucket if it doesn't exist (#7514)
Co-authored-by: 莫岳恒 <moyueheng@datagrand.com>
2024-08-22 09:45:42 +08:00
2c427e04be Feat/7134 use dataset api create a dataset with permission (#7508) 2024-08-21 20:25:45 +08:00
f53454f81d add finish_reason to the LLM node output (#7498) 2024-08-21 17:29:30 +08:00
784b11ce19 Chore/remove python dependencies selector (#7494) 2024-08-21 16:57:14 +08:00
715eb8fa32 fix rerank mode is none (#7496) 2024-08-21 16:42:28 +08:00
a02118d5bc Fix/incorrect code template (#7490) 2024-08-21 15:31:13 +08:00
85fc0fdb51 chore: support CODE_MAX_PRECISION (#7484) 2024-08-21 15:11:56 +08:00
f7af8c7cc7 feat: gpt-4o-mini-2024-07-18 support json schema (#7489) 2024-08-21 15:11:29 +08:00
0c99a3d0c5 fix the issue of the refine_switches at param being invalid in the Novita.AI tool (#7485) 2024-08-21 15:09:05 +08:00
66dfb5c89a fix: json schema not saved correctly (#7487) 2024-08-21 14:58:14 +08:00
6435b4eb44 Separate CODE_MAX_DEPTH and set it as an environment variable (#7474) 2024-08-21 12:48:25 +08:00
4e7b6aec3a feat: support pinning, including, and excluding for model providers and tools (#7419)
Co-authored-by: GareArc <chen4851@purude.edu>
2024-08-21 11:16:43 +08:00
6c25d7bed3 chore: improve the copywrite of the assigner node append mode description (#7467) 2024-08-21 10:34:25 +08:00
028fd52c9b fix: image icon not showing correctly on left panel in workflow web app page (#7466) 2024-08-21 10:29:16 +08:00
9a715f6b68 fix(tool): tool node error (#7459)
Co-authored-by: hobo.l <hobo.l@binance.com>
2024-08-21 09:04:54 +08:00
8c32f8c77d chore: #7348 i18n (#7451) 2024-08-21 09:03:51 +08:00
b7778de224 fix: document error message can not be cleared (#7453) 2024-08-20 19:30:57 +08:00
c70d69322b feat: support dialogue count in chatflow (#7440) 2024-08-20 18:28:39 +08:00
e35e251863 feat: Sort conversations by updated_at desc (#7348)
Co-authored-by: wangpj <wangpj@hundsunc.om>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-08-20 17:55:44 +08:00
eae53e11e6 refactor(api/models/workflow.py): Add __init__ to Workflow (#7443) 2024-08-20 17:52:21 +08:00
4f5f27cf2b refactor(api/core/workflow/enums.py): Rename SystemVariable to SystemVariableKey. (#7445) 2024-08-20 17:52:06 +08:00
5e42e90abc fix(api/services/workflow/workflow_converter.py): Add NoneType checkers & format file. (#7446) 2024-08-20 17:51:49 +08:00
a10b207de2 refactor(api/core/app/app_config/entities.py): Move Type to outside and add EXTERNAL_DATA_TOOL. (#7444) 2024-08-20 17:30:14 +08:00
e2d214e030 chore: add and update theme related css variables values (#7442) 2024-08-20 16:33:40 +08:00
4f64a5d36d refactor(api/core/workflow/nodes/variable_assigner): Split into multi files. (#7434) 2024-08-20 15:40:19 +08:00
0d4753785f chore: remove .idea and .vscode from root path (#7437) 2024-08-20 15:37:29 +08:00
2e9084f369 chore(database): Rename table name from workflow__conversation_variables to workflow_conversation_variables. (#7432) 2024-08-20 14:34:03 +08:00
0f90e6df75 add pgvector full text search settting (#7427) 2024-08-20 13:20:19 +08:00
53146ad685 feat: support line break of tooltip content (#7424) 2024-08-20 11:03:55 +08:00
0223fc6fd5 feat: add pgvector full_text_search (#7396) 2024-08-20 11:01:13 +08:00
218380ba43 fix:end of day (#7426) 2024-08-20 10:57:33 +08:00
afd23f7ad8 chore: #7196 i18n (#7416) 2024-08-20 10:21:24 +08:00
6991a243aa chore: correct _tts_invoke_streaming max length (#7423) 2024-08-20 10:20:04 +08:00
1f944c6eeb feat(api): support wenxin bge-large and tao embedding model. (#7393) 2024-08-19 22:25:09 +08:00
31f9977411 Web app support sending message using numpad enter (#7414) 2024-08-19 22:24:21 +08:00
3d27d15f00 chore(*): Bump version 0.7.1 (#7389) 2024-08-19 21:24:56 +08:00
ab6499e5b7 upgrade: sandbox to 0.2.6 (#7410) 2024-08-19 21:24:15 +08:00
4ff4859036 add CrossRef builtin tool: doi query and title query (#7406) 2024-08-19 19:14:20 +08:00
53cf756207 feat: OpenRouter add gpt-4o-2024-08-06 model (#7409) 2024-08-19 19:14:08 +08:00
0087afc2e3 fix(api/core/model_runtime/model_providers/__base/large_language_model.py): Add TEXT type checker (#7407) 2024-08-19 18:45:30 +08:00
bd07e1d2fd fix:start of the period should be YYYY-MM-DD 00:00 (#7371) 2024-08-19 18:12:41 +08:00
8b06105fa1 Feat: shortcut hook (#7385) 2024-08-19 18:11:11 +08:00
193 changed files with 3482 additions and 2410 deletions

.gitignore
View File

@ -178,3 +178,4 @@ pyrightconfig.json
api/.vscode
.idea/
.vscode

View File

@ -4,7 +4,7 @@
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
<a href="https://docs.dify.ai/getting-started/install-self-hosted">Self-hosting</a> ·
<a href="https://docs.dify.ai">Documentation</a> ·
<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">Enterprise inquiry</a>
<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">Enterprise Inquiry</a>
</p>
<p align="center">
@ -41,41 +41,36 @@
<a href="./README_VI.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
</p>
Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features:
Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, Agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features:
</br> </br>
**1. Workflow**:
Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond.
**1. Workflow**:
Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond.
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
**2. Comprehensive model support**:
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
**2. Comprehensive model support**:
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3)
**3. Prompt IDE**:
Intuitive interface for crafting prompts, comparing model performance, and adding additional features such as text-to-speech to a chat-based app.
**3. Prompt IDE**:
Intuitive interface for crafting prompts, comparing model performance, and adding additional features such as text-to-speech to a chat-based app.
**4. RAG Pipeline**:
Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats.
**4. RAG Pipeline**:
Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats.
**5. Agent capabilities**:
You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DALL·E, Stable Diffusion and WolframAlpha.
**5. Agent capabilities**:
You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DALL·E, Stable Diffusion and WolframAlpha.
**6. LLMOps**:
Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations.
**7. Backend-as-a-Service**:
All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic.
**6. LLMOps**:
Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations.
**7. Backend-as-a-Service**:
All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic.
## Feature comparison
<table style="width: 100%;">
<tr>
<th align="center">Feature</th>
@ -145,30 +140,28 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
## Using Dify
- **Cloud </br>**
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
- **Self-hosting Dify Community Edition</br>**
Quickly get Dify running in your environment with this [starter guide](#quick-start).
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
Quickly get Dify running in your environment with this [starter guide](#quick-start).
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
- **Dify for enterprise / organizations</br>**
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs. </br>
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs. </br>
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
## Staying ahead
Star Dify on GitHub and be instantly notified of new releases.
![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4)
## Quick start
> Before installing Dify, make sure your machine meets the following minimum system requirements:
>
>- CPU >= 2 Core
>- RAM >= 4GB
>
> - CPU >= 2 Core
> - RAM >= 4GB
</br>
@ -197,15 +190,16 @@ If you'd like to configure a highly-available setup, there are community-contrib
#### Using Terraform for Deployment
##### Azure Global
Deploy Dify to Azure with a single click using [terraform](https://www.terraform.io/).
- [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform)
## Contributing
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
At the same time, please consider supporting Dify by sharing it on social media and at events and conferences.
> We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).
**Contributors**
@ -216,16 +210,15 @@ At the same time, please consider supporting Dify by sharing it on social media
## Community & contact
* [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions.
* [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
* [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.
* [Twitter](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community.
- [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions.
- [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
- [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.
- [Twitter](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community.
## Star history
[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)
## Security disclosure
To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer.

View File

@ -247,8 +247,8 @@ API_TOOL_DEFAULT_READ_TIMEOUT=60
HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300
HTTP_REQUEST_MAX_READ_TIMEOUT=600
HTTP_REQUEST_MAX_WRITE_TIMEOUT=600
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 # 10MB
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 # 1MB
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
# Log file path
LOG_FILE=
@ -267,4 +267,13 @@ APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
CELERY_BEAT_SCHEDULER_TIME=1
# Position configuration
POSITION_TOOL_PINS=
POSITION_TOOL_INCLUDES=
POSITION_TOOL_EXCLUDES=
POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES=

View File

(binary image changed; 1.7 KiB before, 1.7 KiB after)

View File

@ -5,8 +5,8 @@
"name": "Python: Flask",
"type": "debugpy",
"request": "launch",
"python": "${workspaceFolder}/api/.venv/bin/python",
"cwd": "${workspaceFolder}/api",
"python": "${workspaceFolder}/.venv/bin/python",
"cwd": "${workspaceFolder}",
"envFile": ".env",
"module": "flask",
"justMyCode": true,
@ -18,15 +18,15 @@
"args": [
"run",
"--host=0.0.0.0",
"--port=5001",
"--port=5001"
]
},
{
"name": "Python: Celery",
"type": "debugpy",
"request": "launch",
"python": "${workspaceFolder}/api/.venv/bin/python",
"cwd": "${workspaceFolder}/api",
"python": "${workspaceFolder}/.venv/bin/python",
"cwd": "${workspaceFolder}",
"module": "celery",
"justMyCode": true,
"envFile": ".env",

View File

@ -37,6 +37,8 @@ class DifyConfig(
CODE_MAX_NUMBER: int = 9223372036854775807
CODE_MIN_NUMBER: int = -9223372036854775808
CODE_MAX_DEPTH: int = 5
CODE_MAX_PRECISION: int = 20
CODE_MAX_STRING_LENGTH: int = 80000
CODE_MAX_STRING_ARRAY_LENGTH: int = 30
CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30

View File

@ -406,6 +406,7 @@ class DataSetConfig(BaseSettings):
default=False,
)
class WorkspaceConfig(BaseSettings):
"""
Workspace configs
@ -442,6 +443,63 @@ class CeleryBeatConfig(BaseSettings):
)
class PositionConfig(BaseSettings):
POSITION_PROVIDER_PINS: str = Field(
description='The heads of model providers',
default='',
)
POSITION_PROVIDER_INCLUDES: str = Field(
description='The included model providers',
default='',
)
POSITION_PROVIDER_EXCLUDES: str = Field(
description='The excluded model providers',
default='',
)
POSITION_TOOL_PINS: str = Field(
description='The heads of tools',
default='',
)
POSITION_TOOL_INCLUDES: str = Field(
description='The included tools',
default='',
)
POSITION_TOOL_EXCLUDES: str = Field(
description='The excluded tools',
default='',
)
@computed_field
def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != '']
@computed_field
def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''}
@computed_field
def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''}
@computed_field
def POSITION_TOOL_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != '']
@computed_field
def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''}
@computed_field
def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''}
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -466,6 +524,7 @@ class FeatureConfig(
UpdateConfig,
WorkflowConfig,
WorkspaceConfig,
PositionConfig,
# hosted services config
HostedServiceConfig,
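
Note: the computed fields above simply split comma-separated environment strings into lists and sets. A minimal standalone sketch of that parsing (the env var names are real; the helper function and sample values are illustrative only):

def parse_csv_env(raw: str) -> list[str]:
    # Same splitting rule as POSITION_TOOL_PINS_LIST / POSITION_PROVIDER_PINS_LIST above
    return [item.strip() for item in raw.split(',') if item.strip() != '']

# e.g. POSITION_TOOL_PINS="wolframalpha, google" (hypothetical values)
assert parse_csv_env('wolframalpha, google') == ['wolframalpha', 'google']
assert parse_csv_env('') == []  # unset vars produce empty pin lists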

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description='Dify version',
default='0.7.0',
default='0.7.1',
)
COMMIT_SHA: str = Field(

View File

@ -154,6 +154,8 @@ class ChatConversationApi(Resource):
parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args')
parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
required=False, default='-updated_at', location='args')
args = parser.parse_args()
subquery = (
@ -225,7 +227,17 @@ class ChatConversationApi(Resource):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
query = query.order_by(Conversation.created_at.desc())
match args['sort_by']:
case 'created_at':
query = query.order_by(Conversation.created_at.asc())
case '-created_at':
query = query.order_by(Conversation.created_at.desc())
case 'updated_at':
query = query.order_by(Conversation.updated_at.asc())
case '-updated_at':
query = query.order_by(Conversation.updated_at.desc())
case _:
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(
query,

View File

@ -24,7 +24,7 @@ from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
from libs.login import login_required
from models.dataset import Dataset, Document, DocumentSegment
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
from models.model import ApiToken, UploadFile
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@ -202,7 +202,7 @@ class DatasetApi(Resource):
nullable=True,
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
'only_me', 'all_team_members', 'partial_members'), help='Invalid permission.'
DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.'
)
parser.add_argument('embedding_model', type=str,
location='json', help='Invalid embedding model.')
@ -239,7 +239,7 @@ class DatasetApi(Resource):
tenant_id, dataset_id_str, data.get('partial_member_list')
)
# clear partial member list when permission is only_me or all_team_members
elif data.get('permission') == 'only_me' or data.get('permission') == 'all_team_members':
elif data.get('permission') == DatasetPermissionEnum.ONLY_ME or data.get('permission') == DatasetPermissionEnum.ALL_TEAM:
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
@ -573,13 +573,13 @@ class DatasetRetrievalSettingMockApi(Resource):
@account_initialization_required
def get(self, vector_type):
match vector_type:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
case VectorType.MILVUS | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value,
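
The permission hunks above swap the 'only_me' / 'all_team_members' string literals for DatasetPermissionEnum members. A sketch of why a str-backed Enum keeps those equality checks working (member names are taken from the diff; the values are an assumption inferred from the old literals):

from enum import Enum

class DatasetPermissionEnum(str, Enum):  # assumed shape of models.dataset.DatasetPermissionEnum
    ONLY_ME = 'only_me'
    ALL_TEAM = 'all_team_members'
    PARTIAL_TEAM = 'partial_members'

# str subclassing means members compare equal to the raw request strings,
# so data.get('permission') == DatasetPermissionEnum.ONLY_ME still matches 'only_me'.
assert DatasetPermissionEnum.ONLY_ME == 'only_me'
assert 'all_team_members' == DatasetPermissionEnum.ALL_TEAM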

View File

@ -25,6 +25,8 @@ class ConversationApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
required=False, default='-updated_at', location='args')
args = parser.parse_args()
try:
@ -33,7 +35,8 @@ class ConversationApi(Resource):
user=end_user,
last_id=args['last_id'],
limit=args['limit'],
invoke_from=InvokeFrom.SERVICE_API
invoke_from=InvokeFrom.SERVICE_API,
sort_by=args['sort_by']
)
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
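
With the new sort_by argument, service API clients can order conversations by last activity instead of creation time. A hedged example call (host and API key are placeholders; only parameters visible in the diff are used):

import requests

resp = requests.get(
    'https://api.dify.example/v1/conversations',       # placeholder host
    headers={'Authorization': 'Bearer app-xxxxxxxx'},  # placeholder app API key
    params={
        'limit': 20,
        'sort_by': '-updated_at',  # the new default; 'created_at', '-created_at', 'updated_at' also accepted
    },
)
resp.raise_for_status()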

View File

@ -10,7 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from libs.login import current_user
from models.dataset import Dataset
from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetService
@ -78,6 +78,8 @@ class DatasetListApi(DatasetApiResource):
parser.add_argument('indexing_technique', type=str, location='json',
choices=Dataset.INDEXING_TECHNIQUE_LIST,
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.', required=False, nullable=False)
args = parser.parse_args()
try:
@ -85,7 +87,8 @@ class DatasetListApi(DatasetApiResource):
tenant_id=tenant_id,
name=args['name'],
indexing_technique=args['indexing_technique'],
account=current_user
account=current_user,
permission=args['permission']
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()

View File

@ -26,6 +26,8 @@ class ConversationListApi(WebApiResource):
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
required=False, default='-updated_at', location='args')
args = parser.parse_args()
pinned = None
@ -40,6 +42,7 @@ class ConversationListApi(WebApiResource):
limit=args['limit'],
invoke_from=InvokeFrom.WEB_APP,
pinned=pinned,
sort_by=args['sort_by']
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")

View File

@ -93,7 +93,7 @@ class DatasetConfigManager:
reranking_model=dataset_configs.get('reranking_model'),
weights=dataset_configs.get('weights'),
reranking_enabled=dataset_configs.get('reranking_enabled', True),
rerank_mode=dataset_configs["reranking_mode"],
rerank_mode=dataset_configs.get('rerank_mode', 'reranking_model'),
)
)

View File

@ -1,6 +1,6 @@
import re
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
from core.external_data_tool.factory import ExternalDataToolFactory
@ -13,7 +13,7 @@ class BasicVariablesConfigManager:
:param config: model config args
"""
external_data_variables = []
variables = []
variable_entities = []
# old external_data_tools
external_data_tools = config.get('external_data_tools', [])
@ -30,50 +30,41 @@ class BasicVariablesConfigManager:
)
# variables and external_data_tools
for variable in config.get('user_input_form', []):
typ = list(variable.keys())[0]
if typ == 'external_data_tool':
val = variable[typ]
if 'config' not in val:
for variables in config.get('user_input_form', []):
variable_type = list(variables.keys())[0]
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
variable = variables[variable_type]
if 'config' not in variable:
continue
external_data_variables.append(
ExternalDataVariableEntity(
variable=val['variable'],
type=val['type'],
config=val['config']
variable=variable['variable'],
type=variable['type'],
config=variable['config']
)
)
elif typ in [
VariableEntity.Type.TEXT_INPUT.value,
VariableEntity.Type.PARAGRAPH.value,
VariableEntity.Type.NUMBER.value,
elif variable_type in [
VariableEntityType.TEXT_INPUT,
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.SELECT,
]:
variables.append(
variable = variables[variable_type]
variable_entities.append(
VariableEntity(
type=VariableEntity.Type.value_of(typ),
variable=variable[typ].get('variable'),
description=variable[typ].get('description'),
label=variable[typ].get('label'),
required=variable[typ].get('required', False),
max_length=variable[typ].get('max_length'),
default=variable[typ].get('default'),
)
)
elif typ == VariableEntity.Type.SELECT.value:
variables.append(
VariableEntity(
type=VariableEntity.Type.SELECT,
variable=variable[typ].get('variable'),
description=variable[typ].get('description'),
label=variable[typ].get('label'),
required=variable[typ].get('required', False),
options=variable[typ].get('options'),
default=variable[typ].get('default'),
type=variable_type,
variable=variable.get('variable'),
description=variable.get('description'),
label=variable.get('label'),
required=variable.get('required', False),
max_length=variable.get('max_length'),
options=variable.get('options'),
default=variable.get('default'),
)
)
return variables, external_data_variables
return variable_entities, external_data_variables
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
@ -183,4 +174,4 @@ class BasicVariablesConfigManager:
config=config
)
return config, ["external_data_tools"]
return config, ["external_data_tools"]

View File

@ -82,43 +82,29 @@ class PromptTemplateEntity(BaseModel):
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
class VariableEntityType(str, Enum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external-data-tool"
class VariableEntity(BaseModel):
"""
Variable Entity.
"""
class Type(Enum):
TEXT_INPUT = 'text-input'
SELECT = 'select'
PARAGRAPH = 'paragraph'
NUMBER = 'number'
@classmethod
def value_of(cls, value: str) -> 'VariableEntity.Type':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid variable type value {value}')
variable: str
label: str
description: Optional[str] = None
type: Type
type: VariableEntityType
required: bool = False
max_length: Optional[int] = None
options: Optional[list[str]] = None
default: Optional[str] = None
hint: Optional[str] = None
@property
def name(self) -> str:
return self.variable
class ExternalDataVariableEntity(BaseModel):
"""
@ -252,4 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig):
"""
Workflow UI Based App Config Entity.
"""
workflow_id: str
workflow_id: str
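
The refactor above promotes the nested VariableEntity.Type to a module-level str Enum and deletes the hand-rolled value_of() lookup, since pydantic coerces raw strings into str-backed enum fields on validation. A minimal sketch of that behavior:

from enum import Enum
from pydantic import BaseModel

class VariableEntityType(str, Enum):
    TEXT_INPUT = 'text-input'
    SELECT = 'select'
    PARAGRAPH = 'paragraph'
    NUMBER = 'number'
    EXTERNAL_DATA_TOOL = 'external-data-tool'

class VariableEntity(BaseModel):
    variable: str
    type: VariableEntityType  # pydantic turns 'text-input' into the enum member

entity = VariableEntity(variable='name', type='text-input')
assert entity.type is VariableEntityType.TEXT_INPUT  # no value_of() needed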

View File

@ -29,7 +29,7 @@ from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
@ -46,7 +46,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[dict, None, None]]:
):
"""
Generate App response.
@ -73,8 +73,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
conversation_id = args.get('conversation_id')
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
# parse files
files = args['files'] if args.get('files') else []
@ -133,8 +134,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
node_id: str,
user: Account,
args: dict,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
stream: bool = True):
"""
Generate App response.
@ -157,8 +157,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
conversation_id = args.get('conversation_id')
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
@ -200,8 +201,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation | None = None,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
stream: bool = True):
is_first_conversation = False
if not conversation:
is_first_conversation = True
@ -270,11 +270,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation_id,
SystemVariable.USER_ID: user_id,
SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: conversation_id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
@ -362,7 +362,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG", "false").lower() == 'true':
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:

View File

@ -49,7 +49,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from events.message_event import message_was_created
@ -74,7 +74,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_workflow: Workflow
_user: Union[Account, EndUser]
# Deprecated
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(
@ -108,10 +108,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message = message
# Deprecated
self._workflow_system_variables = {
SystemVariable.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id,
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_id,
}
self._task_state = AdvancedChatTaskState(

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.app.app_config.entities import AppConfig, VariableEntity
from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
class BaseAppGenerator:
@ -9,29 +9,29 @@ class BaseAppGenerator:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
return filtered_inputs
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.name)
user_input_value = inputs.get(var.variable)
if var.required and not user_input_value:
raise ValueError(f'{var.name} is required in input form')
raise ValueError(f'{var.variable} is required in input form')
if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None?
return var.default or ''
if (
var.type
in (
VariableEntity.Type.TEXT_INPUT,
VariableEntity.Type.SELECT,
VariableEntity.Type.PARAGRAPH,
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
)
and user_input_value
and not isinstance(user_input_value, str)
):
raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number
try:
if '.' in user_input_value:
@ -39,14 +39,14 @@ class BaseAppGenerator:
else:
return int(user_input_value)
except ValueError:
raise ValueError(f"{var.name} in input form must be a valid number")
if var.type == VariableEntity.Type.SELECT:
raise ValueError(f"{var.variable} in input form must be a valid number")
if var.type == VariableEntityType.SELECT:
options = var.options or []
if user_input_value not in options:
raise ValueError(f'{var.name} in input form must be one of the following: {options}')
elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
return user_input_value
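
For NUMBER variables the validator accepts a string and coerces it, using the presence of a decimal point to choose float over int. A condensed sketch of just that branch (standalone, illustrative only):

def coerce_number(user_input_value: str):
    # Mirrors the VariableEntityType.NUMBER branch of _validate_input above
    try:
        if '.' in user_input_value:
            return float(user_input_value)
        return int(user_input_value)
    except ValueError:
        raise ValueError('value in input form must be a valid number')

assert coerce_number('42') == 42      # no decimal point: int
assert coerce_number('3.14') == 3.14  # decimal point: float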

View File

@ -1,6 +1,7 @@
import json
import logging
from collections.abc import Generator
from datetime import datetime, timezone
from typing import Optional, Union
from sqlalchemy import and_
@ -36,17 +37,17 @@ logger = logging.getLogger(__name__)
class MessageBasedAppGenerator(BaseAppGenerator):
def _handle_response(
self, application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
self, application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
@ -193,6 +194,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.add(conversation)
db.session.commit()
db.session.refresh(conversation)
else:
conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
message = Message(
app_id=app_config.app_id,
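
The new else branch touches conversation.updated_at on every follow-up message so the sort-by-updated_at feature reflects actual activity; the timestamp is naive UTC (tzinfo stripped), matching how the existing datetime columns are stored. The pattern in isolation:

from datetime import datetime, timezone

# Aware UTC "now", then drop tzinfo so it fits a naive-UTC datetime column
naive_utc_now = datetime.now(timezone.utc).replace(tzinfo=None)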

View File

@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
@ -67,8 +67,8 @@ class WorkflowAppRunner:
# Create a variable pool.
system_inputs = {
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id,
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,

View File

@ -43,7 +43,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
@ -67,7 +67,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
@ -92,8 +92,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._workflow = workflow
self._workflow_system_variables = {
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.USER_ID: user_id
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_id
}
self._task_state = WorkflowTaskState(

View File

@ -2,7 +2,7 @@ from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from models.account import Account
from models.model import EndUser
from models.workflow import Workflow
@ -13,4 +13,4 @@ class WorkflowCycleStateManager:
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]

View File

@ -1,15 +1,13 @@
import logging
import time
from enum import Enum
from threading import Lock
from typing import Literal, Optional
from typing import Optional
from httpx import get, post
from httpx import Timeout, post
from pydantic import BaseModel
from yarl import URL
from configs import dify_config
from core.helper.code_executor.entities import CodeDependency
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
@ -21,7 +19,7 @@ logger = logging.getLogger(__name__)
CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT
CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY
CODE_EXECUTION_TIMEOUT = (10, 60)
CODE_EXECUTION_TIMEOUT = Timeout(connect=10, write=10, read=60, pool=None)
class CodeExecutionException(Exception):
pass
@ -66,8 +64,7 @@ class CodeExecutor:
def execute_code(cls,
language: CodeLanguage,
preload: str,
code: str,
dependencies: Optional[list[CodeDependency]] = None) -> str:
code: str) -> str:
"""
Execute code
:param language: code language
@ -87,9 +84,6 @@ class CodeExecutor:
'enable_network': True
}
if dependencies:
data['dependencies'] = [dependency.model_dump() for dependency in dependencies]
try:
response = post(str(url), json=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT)
if response.status_code == 503:
@ -116,10 +110,10 @@ class CodeExecutor:
if response.data.error:
raise CodeExecutionException(response.data.error)
return response.data.stdout
return response.data.stdout or ''
@classmethod
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict:
"""
Execute code
:param language: code language
@ -131,67 +125,12 @@ class CodeExecutor:
if not template_transformer:
raise CodeExecutionException(f'Unsupported language {language}')
runner, preload, dependencies = template_transformer.transform_caller(code, inputs, dependencies)
runner, preload = template_transformer.transform_caller(code, inputs)
try:
response = cls.execute_code(language, preload, runner, dependencies)
response = cls.execute_code(language, preload, runner)
except CodeExecutionException as e:
raise e
return template_transformer.transform_response(response)
@classmethod
def list_dependencies(cls, language: str) -> list[CodeDependency]:
if language not in cls.supported_dependencies_languages:
return []
with cls.dependencies_cache_lock:
if language in cls.dependencies_cache:
# check expiration
dependencies = cls.dependencies_cache[language]
if dependencies['expiration'] > time.time():
return dependencies['data']
# remove expired cache
del cls.dependencies_cache[language]
dependencies = cls._get_dependencies(language)
with cls.dependencies_cache_lock:
cls.dependencies_cache[language] = {
'data': dependencies,
'expiration': time.time() + 60
}
return dependencies
@classmethod
def _get_dependencies(cls, language: Literal['python3']) -> list[CodeDependency]:
"""
List dependencies
"""
url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'dependencies'
headers = {
'X-Api-Key': CODE_EXECUTION_API_KEY
}
running_language = cls.code_language_to_running_language.get(language)
if isinstance(running_language, Enum):
running_language = running_language.value
data = {
'language': running_language,
}
try:
response = get(str(url), params=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT)
if response.status_code != 200:
raise Exception(f'Failed to list dependencies, got status code {response.status_code}, please check if the sandbox service is running')
response = response.json()
dependencies = response.get('data', {}).get('dependencies', [])
return [
CodeDependency(**dependency) for dependency in dependencies
if dependency.get('name') not in Python3TemplateTransformer.get_standard_packages()
]
except Exception as e:
logger.exception(f'Failed to list dependencies: {e}')
return []
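
The first hunk in this file replaces the bare (10, 60) timeout tuple with an explicit httpx.Timeout, naming each phase instead of relying on positional meaning. A minimal sketch, assuming httpx is installed (the endpoint is a placeholder):

from httpx import Timeout, post

# connect and write capped at 10s, read at 60s; pool=None disables the pool-acquire timeout
CODE_EXECUTION_TIMEOUT = Timeout(connect=10, write=10, read=60, pool=None)

response = post(
    'https://sandbox.example/v1/sandbox/run',  # placeholder sandbox endpoint
    json={'language': 'python3', 'code': 'print(1)', 'enable_network': True},
    timeout=CODE_EXECUTION_TIMEOUT,
)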

View File

@ -2,8 +2,6 @@ from abc import abstractmethod
from pydantic import BaseModel
from core.helper.code_executor.code_executor import CodeExecutor
class CodeNodeProvider(BaseModel):
@staticmethod
@ -23,10 +21,6 @@ class CodeNodeProvider(BaseModel):
"""
pass
@classmethod
def get_default_available_packages(cls) -> list[dict]:
return [p.model_dump() for p in CodeExecutor.list_dependencies(cls.get_language())]
@classmethod
def get_default_config(cls) -> dict:
return {
@ -50,6 +44,5 @@ class CodeNodeProvider(BaseModel):
"children": None
}
}
},
"available_dependencies": cls.get_default_available_packages(),
}
}

View File

@ -1,6 +0,0 @@
from pydantic import BaseModel
class CodeDependency(BaseModel):
name: str
version: str

View File

@ -3,7 +3,7 @@ from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
class Jinja2Formatter:
@classmethod
def format(cls, template: str, inputs: str) -> str:
def format(cls, template: str, inputs: dict) -> str:
"""
Format template
:param template: template

View File

@ -1,14 +1,9 @@
from textwrap import dedent
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
from core.helper.code_executor.template_transformer import TemplateTransformer
class Jinja2TemplateTransformer(TemplateTransformer):
@classmethod
def get_standard_packages(cls) -> set[str]:
return {'jinja2'} | Python3TemplateTransformer.get_standard_packages()
@classmethod
def transform_response(cls, response: str) -> dict:
"""

View File

@ -13,7 +13,7 @@ class Python3CodeProvider(CodeNodeProvider):
def get_default_code(cls) -> str:
return dedent(
"""
def main(arg1: int, arg2: int) -> dict:
def main(arg1: str, arg2: str) -> dict:
return {
"result": arg1 + arg2,
}

View File

@ -4,30 +4,6 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class Python3TemplateTransformer(TemplateTransformer):
@classmethod
def get_standard_packages(cls) -> set[str]:
return {
'base64',
'binascii',
'collections',
'datetime',
'functools',
'hashlib',
'hmac',
'itertools',
'json',
'math',
'operator',
'os',
'random',
're',
'string',
'sys',
'time',
'traceback',
'uuid',
}
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(f"""

View File

@ -2,9 +2,6 @@ import json
import re
from abc import ABC, abstractmethod
from base64 import b64encode
from typing import Optional
from core.helper.code_executor.entities import CodeDependency
class TemplateTransformer(ABC):
@ -13,12 +10,7 @@ class TemplateTransformer(ABC):
_result_tag: str = '<<RESULT>>'
@classmethod
def get_standard_packages(cls) -> set[str]:
return set()
@classmethod
def transform_caller(cls, code: str, inputs: dict,
dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]:
def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
"""
Transform code to python runner
:param code: code
@ -28,14 +20,7 @@ class TemplateTransformer(ABC):
runner_script = cls.assemble_runner_script(code, inputs)
preload_script = cls.get_preload_script()
packages = dependencies or []
standard_packages = cls.get_standard_packages()
for package in standard_packages:
if package not in packages:
packages.append(CodeDependency(name=package, version=''))
packages = list({dep.name: dep for dep in packages if dep.name}.values())
return runner_script, preload_script, packages
return runner_script, preload_script
@classmethod
def extract_result_str_from_response(cls, response: str) -> str:

View File

@ -3,6 +3,7 @@ from collections import OrderedDict
from collections.abc import Callable
from typing import Any
from configs import dify_config
from core.tools.utils.yaml_utils import load_yaml_file
@ -19,6 +20,87 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
return {name: index for index, name in enumerate(positions)}
def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
"""
Get the mapping for tools from name to index from a YAML file.
:param folder_path:
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
position_map = get_position_map(folder_path, file_name=file_name)
return pin_position_map(
position_map,
pin_list=dify_config.POSITION_TOOL_PINS_LIST,
)
def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
"""
Get the mapping for providers from name to index from a YAML file.
:param folder_path:
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
position_map = get_position_map(folder_path, file_name=file_name)
return pin_position_map(
position_map,
pin_list=dify_config.POSITION_PROVIDER_PINS_LIST,
)
def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
"""
Pin the items in the pin list to the beginning of the position map.
Overall logic: exclude > include > pin
:param original_position_map: the position map to be sorted and filtered
:param pin_list: the list of pins to be put at the beginning
:return: the sorted position map
"""
positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x])
# Add pins to position map
position_map = {name: idx for idx, name in enumerate(pin_list)}
# Add remaining positions to position map
start_idx = len(position_map)
for name in positions:
if name not in position_map:
position_map[name] = start_idx
start_idx += 1
return position_map
def is_filtered(
include_set: set[str],
exclude_set: set[str],
data: Any,
name_func: Callable[[Any], str],
) -> bool:
"""
Check if the object should be filtered out.
Overall logic: exclude > include > pin
:param include_set: the set of names to be included
:param exclude_set: the set of names to be excluded
:param name_func: the function to get the name of the object
:param data: the data to be filtered
:return: True if the object should be filtered out, False otherwise
"""
if not data:
return False
if not include_set and not exclude_set:
return False
name = name_func(data)
if name in exclude_set: # exclude_set is prioritized
return True
if include_set and name not in include_set: # filter out only if include_set is not empty
return True
return False
def sort_by_position_map(
position_map: dict[str, int],
data: list[Any],

View File

@ -700,6 +700,7 @@ class IndexingRunner:
DatasetDocument.tokens: tokens,
DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
DatasetDocument.error: None,
}
)

View File

@ -368,6 +368,15 @@ class ModelManager:
return ModelInstance(provider_model_bundle, model)
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
"""
Return first provider and the first model in the provider
:param tenant_id: tenant id
:param model_type: model type
:return: provider name, model name
"""
return self._provider_manager.get_first_provider_first_model(tenant_id, model_type)
def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
"""
Get default model instance
@ -502,7 +511,6 @@ class LBModelManager:
config.id
)
res = redis_client.exists(cooldown_cache_key)
res = cast(bool, res)
return res

View File

@ -151,9 +151,9 @@ class AIModel(ABC):
os.path.join(provider_model_type_path, model_schema_yaml)
for model_schema_yaml in os.listdir(provider_model_type_path)
if not model_schema_yaml.startswith('__')
and not model_schema_yaml.startswith('_')
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
and model_schema_yaml.endswith('.yaml')
and not model_schema_yaml.startswith('_')
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
and model_schema_yaml.endswith('.yaml')
]
# get _position.yaml file path

View File

@ -185,7 +185,7 @@ if you are not sure about the structure.
stream=stream,
user=user
)
model_parameters.pop("response_format")
stop = stop or []
stop.extend(["\n```", "```\n"])
@ -249,10 +249,10 @@ if you are not sure about the structure.
prompt_messages=prompt_messages,
input_generator=new_generator()
)
return response
def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
input_generator: Generator[LLMResultChunk, None, None]
) -> Generator[LLMResultChunk, None, None]:
"""
@ -310,7 +310,7 @@ if you are not sure about the structure.
)
)
def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
input_generator: Generator[LLMResultChunk, None, None]) \
-> Generator[LLMResultChunk, None, None]:
"""
@ -470,7 +470,7 @@ if you are not sure about the structure.
:return: full response or stream response chunk generator result
"""
raise NotImplementedError
@abstractmethod
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
@ -792,6 +792,13 @@ if you are not sure about the structure.
if not isinstance(parameter_value, str):
raise ValueError(f"Model Parameter {parameter_name} should be string.")
# validate options
if parameter_rule.options and parameter_value not in parameter_rule.options:
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
elif parameter_rule.type == ParameterType.TEXT:
if not isinstance(parameter_value, str):
raise ValueError(f"Model Parameter {parameter_name} should be text.")
# validate options
if parameter_rule.options and parameter_value not in parameter_rule.options:
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")

View File

@ -70,7 +70,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
# doc: https://platform.openai.com/docs/guides/text-to-speech
credentials_kwargs = self._to_credential_kwargs(credentials)
client = AzureOpenAI(**credentials_kwargs)
# max font is 4096,there is 3500 limit for each request
# max length is 4096 characters, there is 3500 limit for each request
max_length = 3500
if len(content_text) > max_length:
sentences = self._split_text_into_sentences(content_text, max_length=max_length)

View File

@ -6,7 +6,7 @@ from typing import Optional
from pydantic import BaseModel, ConfigDict
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
@ -234,7 +234,7 @@ class ModelProviderFactory:
]
# get _position.yaml file path
position_map = get_position_map(model_providers_path)
position_map = get_provider_position_map(model_providers_path)
# traverse all model_provider_dir_paths
model_providers: list[ModelProviderExtension] = []

View File

@ -37,6 +37,9 @@ parameter_rules:
options:
- text
- json_object
- json_schema
- name: json_schema
use_template: json_schema
pricing:
input: '0.15'
output: '0.60'

View File

@ -428,7 +428,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
finish_reason = 'Unknown'
finish_reason = None # The default value of finish_reason is None
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
chunk = chunk.strip()
@ -437,6 +437,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
if chunk.startswith(':'):
continue
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
if decoded_chunk == '[DONE]': # Some provider returns "data: [DONE]"
continue
try:
chunk_json = json.loads(decoded_chunk)
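The `[DONE]` guard is needed because some OpenAI-compatible providers end the stream with a literal `data: [DONE]` line that is not JSON. A self-contained sketch of the same server-sent-events filtering; the chunk format is assumed to follow the OpenAI SSE convention, and `removeprefix` is used here to sidestep the character-set semantics of `str.lstrip`:

import json

def iter_sse_json(lines):
    """Yield parsed JSON payloads from an SSE stream, skipping comments,
    blank lines, and the non-JSON '[DONE]' sentinel."""
    for chunk in lines:
        chunk = chunk.strip()
        if not chunk or chunk.startswith(':'):  # SSE comment / keep-alive
            continue
        payload = chunk.removeprefix('data:').strip()
        if payload == '[DONE]':  # some providers send "data: [DONE]"
            continue
        try:
            yield json.loads(payload)
        except json.JSONDecodeError:
            continue  # tolerate a malformed chunk rather than aborting the stream

# list(iter_sse_json(['data: {"delta": "hi"}', 'data: [DONE]'])) -> [{'delta': 'hi'}]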

View File

@ -0,0 +1,44 @@
model: gpt-4o-2024-08-06
label:
zh_Hans: gpt-4o-2024-08-06
en_US: gpt-4o-2024-08-06
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '2.50'
output: '10.00'
unit: '0.000001'
currency: USD

View File

@ -118,6 +118,9 @@ class _CommonWenxin:
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1',
'bge-large-en': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en',
'bge-large-zh': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh',
'tao-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k',
}
function_calling_supports = [

View File

@ -0,0 +1,9 @@
model: bge-large-en
model_type: text-embedding
model_properties:
context_size: 512
max_chunks: 16
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -0,0 +1,9 @@
model: bge-large-zh
model_type: text-embedding
model_properties:
context_size: 512
max_chunks: 16
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -0,0 +1,9 @@
model: tao-8k
model_type: text-embedding
model_properties:
context_size: 8192
max_chunks: 1
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -5,6 +5,7 @@ from typing import Optional
from sqlalchemy.exc import IntegrityError
from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
from core.entities.provider_entities import (
@ -18,12 +19,9 @@ from core.entities.provider_entities import (
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.helper.position_helper import is_filtered
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import (
CredentialFormSchema,
FormType,
ProviderEntity,
)
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider
from extensions.ext_database import db
@ -45,6 +43,7 @@ class ProviderManager:
"""
ProviderManager is a class that manages the model providers, including hosted and customized model providers.
"""
def __init__(self) -> None:
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
@ -117,6 +116,16 @@ class ProviderManager:
# Construct ProviderConfiguration objects for each provider
for provider_entity in provider_entities:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
data=provider_entity,
name_func=lambda x: x.provider,
):
continue
provider_name = provider_entity.provider
provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
@ -271,6 +280,24 @@ class ProviderManager:
)
)
def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
"""
Get the names of the first available model and its provider
:param tenant_id: workspace id
:param model_type: model type
:return: provider name, model name
"""
provider_configurations = self.get_configurations(tenant_id)
# get available models from provider_configurations
all_models = provider_configurations.get_models(
model_type=model_type,
only_active=False
)
return all_models[0].provider.provider, all_models[0].model
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
-> TenantDefaultModel:
"""

View File

@ -152,8 +152,27 @@ class PGVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# do not support bm25 search
return []
top_k = kwargs.get("top_k", 5)
with self._get_cursor() as cur:
cur.execute(
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), to_tsquery(%s)) AS score
FROM {self.table_name}
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
ORDER BY score DESC
LIMIT {top_k}""",
# f"'{query}'" is required in order to account for whitespace in query
(f"'{query}'", f"'{query}'"),
)
docs = []
for record in cur:
metadata, text, score = record
metadata["score"] = score
docs.append(Document(page_content=text, metadata=metadata))
return docs
def delete(self) -> None:
with self._get_cursor() as cur:
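For readers unfamiliar with Postgres full-text search: the new query ranks matches with `ts_rank` over a `tsvector` of the text column. A standalone sketch with placeholder DSN and table name; it uses `plainto_tsquery` on both sides for simplicity, where the diff quotes the query string for `to_tsquery`:

import psycopg2  # assumed driver; Dify's PGVector wrapper manages its own cursors

def full_text_search(dsn: str, table: str, query: str, top_k: int = 5):
    sql = f"""
        SELECT meta, text,
               ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
        FROM {table}
        WHERE to_tsvector(text) @@ plainto_tsquery(%s)
        ORDER BY score DESC
        LIMIT %s
    """
    with psycopg2.connect(dsn) as conn, conn.cursor() as cur:
        cur.execute(sql, (query, query, top_k))
        return cur.fetchall()  # rows of (meta, text, score)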

View File

@ -1,6 +1,6 @@
import os.path
from core.helper.position_helper import get_position_map, sort_by_position_map
from core.helper.position_helper import get_tool_position_map, sort_by_position_map
from core.tools.entities.api_entities import UserToolProvider
@ -10,11 +10,11 @@ class BuiltinToolProviderSort:
@classmethod
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
if not cls._position:
cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))
cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..'))
def name_func(provider: UserToolProvider) -> str:
return provider.name
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
return sorted_providers

View File

@ -0,0 +1,49 @@
<?xml version="1.0" encoding="utf-8"?>
<!-- Generator: Adobe Illustrator 19.2.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
viewBox="0 0 200 130.2" style="enable-background:new 0 0 200 130.2;" xml:space="preserve">
<style type="text/css">
.st0{fill:#3EB1C8;}
.st1{fill:#D8D2C4;}
.st2{fill:#4F5858;}
.st3{fill:#FFC72C;}
.st4{fill:#EF3340;}
</style>
<g>
<polygon class="st0" points="111.8,95.5 111.8,66.8 135.4,59 177.2,73.3 "/>
<polygon class="st1" points="153.6,36.8 111.8,51.2 135.4,59 177.2,44.6 "/>
<polygon class="st2" points="135.4,59 177.2,44.6 177.2,73.3 "/>
<polygon class="st3" points="177.2,0.3 177.2,29 153.6,36.8 111.8,22.5 "/>
<polygon class="st4" points="153.6,36.8 111.8,51.2 111.8,22.5 "/>
<g>
<g>
<g>
<g>
<path class="st2" d="M26.3,104.8c-0.5-3.7-4.1-6.5-8.1-6.5c-7.3,0-10.1,6.2-10.1,12.7c0,6.2,2.8,12.4,10.1,12.4
c5,0,7.8-3.4,8.4-8.3h7.9c-0.8,9.2-7.2,15.2-16.3,15.2C6.8,130.2,0,121.7,0,111c0-11,6.8-19.6,18.2-19.6c8.2,0,15,4.8,16,13.3
H26.3z"/>
<path class="st2" d="M37.4,102.5h7v5h0.1c1.4-3.4,5-5.7,8.6-5.7c0.5,0,1.1,0.1,1.6,0.3v6.9c-0.7-0.2-1.8-0.3-2.6-0.3
c-5.4,0-7.3,3.9-7.3,8.6v12.1h-7.4V102.5z"/>
<path class="st2" d="M68.7,101.8c8.5,0,13.9,5.6,13.9,14.2c0,8.5-5.5,14.1-13.9,14.1c-8.4,0-13.9-5.6-13.9-14.1
C54.9,107.4,60.3,101.8,68.7,101.8z M68.7,124.5c5,0,6.5-4.3,6.5-8.6c0-4.3-1.5-8.6-6.5-8.6c-5,0-6.5,4.3-6.5,8.6
C62.2,120.2,63.8,124.5,68.7,124.5z"/>
<path class="st2" d="M91.2,120.6c0.1,3.2,2.8,4.5,5.7,4.5c2.1,0,4.8-0.8,4.8-3.4c0-2.2-3.1-3-8.4-4.2c-4.3-0.9-8.5-2.4-8.5-7.2
c0-6.9,5.9-8.6,11.7-8.6c5.9,0,11.3,2,11.8,8.6h-7c-0.2-2.9-2.4-3.6-5-3.6c-1.7,0-4.1,0.3-4.1,2.5c0,2.6,4.2,3,8.4,4
c4.3,1,8.5,2.5,8.5,7.5c0,7.1-6.1,9.3-12.3,9.3c-6.2,0-12.3-2.3-12.6-9.5H91.2z"/>
<path class="st2" d="M118.1,120.6c0.1,3.2,2.8,4.5,5.7,4.5c2.1,0,4.8-0.8,4.8-3.4c0-2.2-3.1-3-8.4-4.2
c-4.3-0.9-8.5-2.4-8.5-7.2c0-6.9,5.9-8.6,11.7-8.6c5.9,0,11.3,2,11.8,8.6h-7c-0.2-2.9-2.4-3.6-5-3.6c-1.7,0-4.1,0.3-4.1,2.5
c0,2.6,4.2,3,8.4,4c4.3,1,8.5,2.5,8.5,7.5c0,7.1-6.1,9.3-12.3,9.3c-6.2,0-12.3-2.3-12.6-9.5H118.1z"/>
<path class="st2" d="M138.4,102.5h7v5h0.1c1.4-3.4,5-5.7,8.6-5.7c0.5,0,1.1,0.1,1.6,0.3v6.9c-0.7-0.2-1.8-0.3-2.6-0.3
c-5.4,0-7.3,3.9-7.3,8.6v12.1h-7.4V102.5z"/>
<path class="st2" d="M163.7,117.7c0.2,4.7,2.5,6.8,6.6,6.8c3,0,5.3-1.8,5.8-3.5h6.5c-2.1,6.3-6.5,9-12.6,9
c-8.5,0-13.7-5.8-13.7-14.1c0-8,5.6-14.2,13.7-14.2c9.1,0,13.6,7.7,13,15.9H163.7z M175.7,113.1c-0.7-3.7-2.3-5.7-5.9-5.7
c-4.7,0-6,3.6-6.1,5.7H175.7z"/>
<path class="st2" d="M187.2,107.5h-4.4v-4.9h4.4v-2.1c0-4.7,3-8.2,9-8.2c1.3,0,2.6,0.2,3.9,0.2V98c-0.9-0.1-1.8-0.2-2.7-0.2
c-2,0-2.8,0.8-2.8,3.1v1.6h5.1v4.9h-5.1v21.9h-7.4V107.5z"/>
</g>
</g>
</g>
</g>
</g>
</svg>

(image diff for icon.svg: 3.0 KiB)

View File

@ -0,0 +1,20 @@
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.crossref.tools.query_doi import CrossRefQueryDOITool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class CrossRefProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
CrossRefQueryDOITool().fork_tool_runtime(
runtime={
"credentials": credentials,
}
).invoke(
user_id='',
tool_parameters={
"doi": '10.1007/s00894-022-05373-8',
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,29 @@
identity:
author: Sakura4036
name: crossref
label:
en_US: CrossRef
zh_Hans: CrossRef
description:
en_US: Crossref is a cross-publisher reference linking registration query system using DOI technology created in 2000. Crossref establishes cross-database links between the reference list and citation full text of papers, making it very convenient for readers to access the full text of papers.
zh_Hans: Crossref是于2000年创建的使用DOI技术的跨出版商参考文献链接注册查询系统。Crossref建立了在论文的参考文献列表和引文全文之间的跨数据库链接使得读者能够非常便捷地获取文献全文。
icon: icon.svg
tags:
- search
credentials_for_provider:
mailto:
type: text-input
required: true
label:
en_US: email address
zh_Hans: email地址
pt_BR: email address
placeholder:
en_US: Please input your email address
zh_Hans: 请输入你的email地址
pt_BR: Please input your email address
help:
en_US: According to the requirements of Crossref, an email address is required
zh_Hans: 根据Crossref的要求需要提供一个邮箱地址
pt_BR: According to the requirements of Crossref, an email address is required
url: https://api.crossref.org/swagger-ui/index.html

View File

@ -0,0 +1,25 @@
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolParameterValidationError
from core.tools.tool.builtin_tool import BuiltinTool
class CrossRefQueryDOITool(BuiltinTool):
"""
Tool for querying the metadata of a publication using its DOI.
"""
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
doi = tool_parameters.get('doi')
if not doi:
raise ToolParameterValidationError('doi is required.')
# doc: https://github.com/CrossRef/rest-api-doc
url = f"https://api.crossref.org/works/{doi}"
response = requests.get(url)
response.raise_for_status()
response = response.json()
message = response.get('message', {})
return self.create_json_message(message)
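The underlying REST call is easy to reproduce outside the tool; a standalone equivalent of what `_invoke` does, using the same test DOI as the provider's credential validator:

import requests

def query_doi(doi: str) -> dict:
    # CrossRef works endpoint, per https://github.com/CrossRef/rest-api-doc
    response = requests.get(f"https://api.crossref.org/works/{doi}")
    response.raise_for_status()
    return response.json().get("message", {})

# query_doi("10.1007/s00894-022-05373-8")["title"] -> the paper's title list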

View File

@ -0,0 +1,23 @@
identity:
name: crossref_query_doi
author: Sakura4036
label:
en_US: CrossRef Query DOI
zh_Hans: CrossRef DOI 查询
pt_BR: CrossRef Query DOI
description:
human:
en_US: A tool for searching literature information using CrossRef by DOI.
zh_Hans: 一个使用CrossRef通过DOI获取文献信息的工具。
pt_BR: A tool for searching literature information using CrossRef by DOI.
llm: A tool for searching literature information using CrossRef by DOI.
parameters:
- name: doi
type: string
required: true
label:
en_US: DOI
zh_Hans: DOI
pt_BR: DOI
llm_description: DOI for searching in CrossRef
form: llm

View File

@ -0,0 +1,120 @@
import time
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
def convert_time_str_to_seconds(time_str: str) -> int:
"""
Convert a time string to seconds.
example: 1s -> 1, 1m30s -> 90, 1h30m -> 5400, 1h30m30s -> 5430
"""
time_str = time_str.lower().strip().replace(' ', '')
seconds = 0
if 'h' in time_str:
hours, time_str = time_str.split('h')
seconds += int(hours) * 3600
if 'm' in time_str:
minutes, time_str = time_str.split('m')
seconds += int(minutes) * 60
if 's' in time_str:
seconds += int(time_str.replace('s', ''))
return seconds
class CrossRefQueryTitleAPI:
"""
Tool for querying the metadata of a publication using its title.
Crossref API doc: https://github.com/CrossRef/rest-api-doc
"""
query_url_template: str = "https://api.crossref.org/works?query.bibliographic={query}&rows={rows}&offset={offset}&sort={sort}&order={order}&mailto={mailto}"
rate_limit: int = 50
rate_interval: float = 1
max_limit: int = 1000
def __init__(self, mailto: str):
self.mailto = mailto
def _query(self, query: str, rows: int = 5, offset: int = 0, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]:
"""
Query the metadata of a publication using its title.
:param query: the title of the publication
:param rows: the number of results to return
:param sort: the sort field
:param order: the sort order
:param fuzzy_query: whether to return all items that match the query
"""
url = self.query_url_template.format(query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto)
response = requests.get(url)
response.raise_for_status()
rate_limit = int(response.headers['x-ratelimit-limit'])
# convert time string to seconds
rate_interval = convert_time_str_to_seconds(response.headers['x-ratelimit-interval'])
self.rate_limit = rate_limit
self.rate_interval = rate_interval
response = response.json()
if response['status'] != 'ok':
return []
message = response['message']
if fuzzy_query:
# fuzzy query return all items
return message['items']
else:
for paper in message['items']:
title = paper['title'][0]
if title.lower() != query.lower():
continue
return [paper]
return []
def query(self, query: str, rows: int = 5, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]:
"""
Query the metadata of a publication using its title.
:param query: the title of the publication
:param rows: the number of results to return
:param sort: the sort field
:param order: the sort order
:param fuzzy_query: whether to return all items that match the query
"""
rows = min(rows, self.max_limit)
if rows > self.rate_limit:
# query multiple times
query_times = rows // self.rate_limit + 1
results = []
for i in range(query_times):
result = self._query(query, rows=self.rate_limit, offset=i * self.rate_limit, sort=sort, order=order, fuzzy_query=fuzzy_query)
if fuzzy_query:
results.extend(result)
else:
# fuzzy_query=False, only one result
if result:
return result
time.sleep(self.rate_interval)
return results
else:
# query once
return self._query(query, rows, sort=sort, order=order, fuzzy_query=fuzzy_query)
class CrossRefQueryTitleTool(BuiltinTool):
"""
Tool for querying the metadata of a publication using its title.
"""
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
query = tool_parameters.get('query')
fuzzy_query = tool_parameters.get('fuzzy_query', False)
rows = tool_parameters.get('rows', 3)
sort = tool_parameters.get('sort', 'relevance')
order = tool_parameters.get('order', 'desc')
mailto = self.runtime.credentials['mailto']
result = CrossRefQueryTitleAPI(mailto).query(query, rows, sort, order, fuzzy_query)
return [self.create_json_message(r) for r in result]
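The paging loop above is driven by CrossRef's `x-ratelimit-limit` and `x-ratelimit-interval` response headers; the interval parser can be exercised on its own:

# Using convert_time_str_to_seconds as defined above:
assert convert_time_str_to_seconds('1s') == 1
assert convert_time_str_to_seconds('1m30s') == 90
assert convert_time_str_to_seconds('1h30m') == 5400
assert convert_time_str_to_seconds('1H 30M 30S') == 5430  # case- and space-tolerant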

View File

@ -0,0 +1,105 @@
identity:
name: crossref_query_title
author: Sakura4036
label:
en_US: CrossRef Title Query
zh_Hans: CrossRef 标题查询
pt_BR: CrossRef Title Query
description:
human:
en_US: A tool for querying literature information using CrossRef by title.
zh_Hans: 一个使用CrossRef通过标题搜索文献信息的工具。
pt_BR: A tool for querying literature information using CrossRef by title.
llm: A tool for querying literature information using CrossRef by title.
parameters:
- name: query
type: string
required: true
label:
en_US: Query
zh_Hans: 查询语句
pt_BR: Query
human_description:
en_US: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years
zh_Hans: 用于搜索文献信息有助于查找引用。包括标题作者ISSN和出版年份
pt_BR: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years
llm_description: keywords for querying in CrossRef
form: llm
- name: fuzzy_query
type: boolean
default: false
label:
en_US: Whether to fuzzy search
zh_Hans: 是否模糊搜索
pt_BR: Whether to fuzzy search
human_description:
en_US: used for selecting the query type, fuzzy query returns more results, precise query returns 1 or none
zh_Hans: 用于选择搜索类型模糊搜索返回更多结果精确搜索返回1条结果或无
pt_BR: used for selecting the query type, fuzzy query returns more results, precise query returns 1 or none
form: form
- name: limit
type: number
required: false
label:
en_US: max query number
zh_Hans: 最大搜索数
pt_BR: max query number
human_description:
en_US: max query number (fuzzy search returns the maximum number of results; precise search returns the maximum number of matches)
zh_Hans: 最大搜索数(模糊搜索返回的最大结果数或精确搜索最大匹配数)
pt_BR: max query number (fuzzy search returns the maximum number of results; precise search returns the maximum number of matches)
form: llm
default: 50
- name: sort
type: select
required: true
options:
- value: relevance
label:
en_US: relevance
zh_Hans: 相关性
pt_BR: relevance
- value: published
label:
en_US: publication date
zh_Hans: 出版日期
pt_BR: publication date
- value: references-count
label:
en_US: references-count
zh_Hans: 引用次数
pt_BR: references-count
default: relevance
label:
en_US: sorting field
zh_Hans: 排序字段
pt_BR: sorting field
human_description:
en_US: Sorting of query results
zh_Hans: 检索结果的排序字段
pt_BR: Sorting of query results
form: form
- name: order
type: select
required: true
options:
- value: desc
label:
en_US: descending
zh_Hans: 降序
pt_BR: descending
- value: asc
label:
en_US: ascending
zh_Hans: 升序
pt_BR: ascending
default: desc
label:
en_US: Order
zh_Hans: 排序
pt_BR: Order
human_description:
en_US: Order of query results
zh_Hans: 检索结果的排序方式
pt_BR: Order of query results
form: form

View File

@ -60,5 +60,11 @@ parameters:
label:
en_US: Tokenizer
human_description:
en_US: cl100k_base - gpt-4,gpt-3.5-turbo,gpt-3.5; o200k_base - gpt-4o,gpt-4o-mini; p50k_base - text-davinci-003,text-davinci-002
en_US: |
· cl100k_base --- gpt-4, gpt-3.5-turbo, gpt-3.5
· o200k_base --- gpt-4o, gpt-4o-mini
· p50k_base --- text-davinci-003, text-davinci-002
· r50k_base --- text-davinci-001, text-curie-001
· p50k_edit --- text-davinci-edit-001, code-davinci-edit-001
· gpt2 --- gpt-2
form: form
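Assuming a `tiktoken` version that ships the `o200k_base` encoding, the mapping in the expanded help text corresponds to lookups like:

import tiktoken

enc = tiktoken.get_encoding("o200k_base")  # encoding used by gpt-4o / gpt-4o-mini
tokens = enc.encode("hello world")
assert enc.decode(tokens) == "hello world"

# tiktoken can also resolve the encoding from a model name:
assert tiktoken.encoding_for_model("gpt-4").name == "cl100k_base"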

View File

@ -0,0 +1,73 @@
from novita_client import (
Txt2ImgV3Embedding,
Txt2ImgV3HiresFix,
Txt2ImgV3LoRA,
Txt2ImgV3Refiner,
V3TaskImage,
)
class NovitaAiToolBase:
def _extract_loras(self, loras_str: str):
if not loras_str:
return []
loras_ori_list = loras_str.strip().split(';')
result_list = []
for lora_str in loras_ori_list:
lora_info = lora_str.strip().split(',')
lora = Txt2ImgV3LoRA(
model_name=lora_info[0].strip(),
strength=float(lora_info[1]),
)
result_list.append(lora)
return result_list
def _extract_embeddings(self, embeddings_str: str):
if not embeddings_str:
return []
embeddings_ori_list = embeddings_str.strip().split(';')
result_list = []
for embedding_str in embeddings_ori_list:
embedding = Txt2ImgV3Embedding(
model_name=embedding_str.strip()
)
result_list.append(embedding)
return result_list
def _extract_hires_fix(self, hires_fix_str: str):
hires_fix_info = hires_fix_str.strip().split(',')
if 'upscaler' in hires_fix_info:
hires_fix = Txt2ImgV3HiresFix(
target_width=int(hires_fix_info[0]),
target_height=int(hires_fix_info[1]),
strength=float(hires_fix_info[2]),
upscaler=hires_fix_info[3].strip()
)
else:
hires_fix = Txt2ImgV3HiresFix(
target_width=int(hires_fix_info[0]),
target_height=int(hires_fix_info[1]),
strength=float(hires_fix_info[2])
)
return hires_fix
def _extract_refiner(self, switch_at: str):
refiner = Txt2ImgV3Refiner(
switch_at=float(switch_at)
)
return refiner
def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
"""
Check whether the image hits NSFW detection at the given confidence threshold.
"""
if image.nsfw_detection_result is None:
return False
if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
return True
return False
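With the helpers extracted into this base class, the tools parse their string parameters in one place. A hedged usage sketch with placeholder model names; the `model,strength;model,strength` format is taken from the code above:

base = NovitaAiToolBase()
loras = base._extract_loras("add_detail,0.6;paper_art,0.3")  # placeholder LoRA names
# -> [Txt2ImgV3LoRA(model_name='add_detail', strength=0.6),
#     Txt2ImgV3LoRA(model_name='paper_art', strength=0.3)]
embeddings = base._extract_embeddings("bad_hands")           # placeholder embedding name
refiner = base._extract_refiner("0.8")                       # switch_at passed as a string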

View File

@ -4,19 +4,15 @@ from typing import Any, Union
from novita_client import (
NovitaClient,
Txt2ImgV3Embedding,
Txt2ImgV3HiresFix,
Txt2ImgV3LoRA,
Txt2ImgV3Refiner,
V3TaskImage,
)
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.novitaai._novita_tool_base import NovitaAiToolBase
from core.tools.tool.builtin_tool import BuiltinTool
class NovitaAiTxt2ImgTool(BuiltinTool):
class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
@ -73,65 +69,19 @@ class NovitaAiTxt2ImgTool(BuiltinTool):
# process loras
if 'loras' in res_parameters:
loras_ori_list = res_parameters.get('loras').strip().split(';')
locals_list = []
for lora_str in loras_ori_list:
lora_info = lora_str.strip().split(',')
lora = Txt2ImgV3LoRA(
model_name=lora_info[0].strip(),
strength=float(lora_info[1]),
)
locals_list.append(lora)
res_parameters['loras'] = locals_list
res_parameters['loras'] = self._extract_loras(res_parameters.get('loras'))
# process embeddings
if 'embeddings' in res_parameters:
embeddings_ori_list = res_parameters.get('embeddings').strip().split(';')
locals_list = []
for embedding_str in embeddings_ori_list:
embedding = Txt2ImgV3Embedding(
model_name=embedding_str.strip()
)
locals_list.append(embedding)
res_parameters['embeddings'] = locals_list
res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings'))
# process hires_fix
if 'hires_fix' in res_parameters:
hires_fix_ori = res_parameters.get('hires_fix')
hires_fix_info = hires_fix_ori.strip().split(',')
if 'upscaler' in hires_fix_info:
hires_fix = Txt2ImgV3HiresFix(
target_width=int(hires_fix_info[0]),
target_height=int(hires_fix_info[1]),
strength=float(hires_fix_info[2]),
upscaler=hires_fix_info[3].strip()
)
else:
hires_fix = Txt2ImgV3HiresFix(
target_width=int(hires_fix_info[0]),
target_height=int(hires_fix_info[1]),
strength=float(hires_fix_info[2])
)
res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix'))
res_parameters['hires_fix'] = hires_fix
if 'refiner_switch_at' in res_parameters:
refiner = Txt2ImgV3Refiner(
switch_at=float(res_parameters.get('refiner_switch_at'))
)
del res_parameters['refiner_switch_at']
res_parameters['refiner'] = refiner
# process refiner
if 'refiner_switch_at' in res_parameters:
res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at'))
del res_parameters['refiner_switch_at']
return res_parameters
def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
"""
is hit nsfw
"""
if image.nsfw_detection_result is None:
return False
if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
return True
return False

View File

@ -1,6 +1,6 @@
from typing import Optional
from core.app.app_config.entities import VariableEntity
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
@ -18,6 +18,13 @@ from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow
VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
}
class WorkflowToolProviderController(ToolProviderController):
provider_id: str
@ -28,7 +35,7 @@ class WorkflowToolProviderController(ToolProviderController):
if not app:
raise ValueError('app not found')
controller = WorkflowToolProviderController(**{
'identity': {
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
@ -46,7 +53,7 @@ class WorkflowToolProviderController(ToolProviderController):
'credentials_schema': {},
'provider_id': db_provider.id or '',
})
# init tools
controller.tools = [controller._get_db_provider_tool(db_provider, app)]
@ -56,7 +63,7 @@ class WorkflowToolProviderController(ToolProviderController):
@property
def provider_type(self) -> ToolProviderType:
return ToolProviderType.WORKFLOW
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
"""
get db provider tool
@ -93,23 +100,11 @@ class WorkflowToolProviderController(ToolProviderController):
if variable:
parameter_type = None
options = None
if variable.type in [
VariableEntity.Type.TEXT_INPUT,
VariableEntity.Type.PARAGRAPH,
]:
parameter_type = ToolParameter.ToolParameterType.STRING
elif variable.type in [
VariableEntity.Type.SELECT
]:
parameter_type = ToolParameter.ToolParameterType.SELECT
elif variable.type in [
VariableEntity.Type.NUMBER
]:
parameter_type = ToolParameter.ToolParameterType.NUMBER
else:
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
raise ValueError(f'unsupported variable type {variable.type}')
if variable.type == VariableEntity.Type.SELECT and variable.options:
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
if variable.type == VariableEntityType.SELECT and variable.options:
options = [
ToolParameterOption(
value=option,
@ -200,7 +195,7 @@ class WorkflowToolProviderController(ToolProviderController):
"""
if self.tools is not None:
return self.tools
db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
@ -208,11 +203,11 @@ class WorkflowToolProviderController(ToolProviderController):
if not db_providers:
return []
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
return self.tools
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
"""
get tool by name
@ -226,5 +221,5 @@ class WorkflowToolProviderController(ToolProviderController):
for tool in self.tools:
if tool.identity.name == tool_name:
return tool
return None
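The lookup table replaces the removed if/elif chain, so unsupported variable types now fail fast on a dictionary miss. A short sketch, assuming the mapping is imported from the module shown above:

from core.app.app_config.entities import VariableEntityType
from core.tools.entities.tool_entities import ToolParameter

variable_type = VariableEntityType.PARAGRAPH
if variable_type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
    raise ValueError(f'unsupported variable type {variable_type}')
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable_type]
# parameter_type is ToolParameter.ToolParameterType.STRING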

View File

@ -10,14 +10,11 @@ from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolInvokeFrom,
ToolParameter,
)
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter
from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
@ -26,10 +23,7 @@ from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ToolConfigurationManager,
ToolParameterConfigurationManager,
)
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from core.workflow.nodes.tool.entities import ToolEntity
from extensions.ext_database import db
@ -38,6 +32,7 @@ from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class ToolManager:
_builtin_provider_lock = Lock()
_builtin_providers = {}
@ -107,7 +102,7 @@ class ToolManager:
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
-> Union[BuiltinTool, ApiTool]:
"""
get the tool runtime
@ -346,7 +341,7 @@ class ToolManager:
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'provider', 'builtin', provider, f'{provider}.py'),
parent_type=BuiltinToolProviderController)
provider: BuiltinToolProviderController = provider_class()
cls._builtin_providers[provider.identity.name] = provider
@ -414,6 +409,15 @@ class ToolManager:
# append builtin providers
for provider in builtin_providers:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider,
name_func=lambda x: x.identity.name
):
continue
user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider,
db_provider=find_db_builtin_provider(provider.identity.name),
@ -473,7 +477,7 @@ class ToolManager:
@classmethod
def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
ApiToolProviderController, dict[str, Any]]:
"""
get the api provider
@ -593,4 +597,5 @@ class ToolManager:
else:
raise ValueError(f"provider type {provider_type} not found")
ToolManager.load_builtin_providers_cache()

View File

@ -6,20 +6,20 @@ from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
VariableValue = Union[str, int, float, dict, list, FileVar]
SYSTEM_VARIABLE_NODE_ID = 'sys'
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
CONVERSATION_VARIABLE_NODE_ID = 'conversation'
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
class VariablePool:
def __init__(
self,
system_variables: Mapping[SystemVariable, Any],
system_variables: Mapping[SystemVariableKey, Any],
user_inputs: Mapping[str, Any],
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable] | None = None,
@ -68,7 +68,7 @@ class VariablePool:
None
"""
if len(selector) < 2:
raise ValueError('Invalid selector')
raise ValueError("Invalid selector")
if value is None:
return
@ -95,13 +95,13 @@ class VariablePool:
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
raise ValueError('Invalid selector')
raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
return value
@deprecated('This method is deprecated, use `get` instead.')
@deprecated("This method is deprecated, use `get` instead.")
def get_any(self, selector: Sequence[str], /) -> Any | None:
"""
Retrieves the value from the variable pool based on the given selector.
@ -116,7 +116,7 @@ class VariablePool:
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
raise ValueError('Invalid selector')
raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
return value.to_object() if value else None

View File

@ -1,25 +1,13 @@
from enum import Enum
class SystemVariable(str, Enum):
class SystemVariableKey(str, Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
DIALOGUE_COUNT = 'dialogue_count'
@classmethod
def value_of(cls, value: str):
"""
Get value of given system variable.
:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')
QUERY = "query"
FILES = "files"
CONVERSATION_ID = "conversation_id"
USER_ID = "user_id"
DIALOGUE_COUNT = "dialogue_count"
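Dropping the `value_of` helper is safe because a `str`-based Enum already provides the same lookup and error behavior:

from core.workflow.enums import SystemVariableKey

key = SystemVariableKey("query")  # replaces the removed SystemVariable.value_of("query")
assert key is SystemVariableKey.QUERY
assert key == "query"             # str-based members compare equal to their values

try:
    SystemVariableKey("nope")
except ValueError:
    pass  # invalid values still raise ValueError, as value_of did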

View File

@ -13,8 +13,8 @@ from models.workflow import WorkflowNodeExecutionStatus
MAX_NUMBER = dify_config.CODE_MAX_NUMBER
MIN_NUMBER = dify_config.CODE_MIN_NUMBER
MAX_PRECISION = 20
MAX_DEPTH = 5
MAX_PRECISION = dify_config.CODE_MAX_PRECISION
MAX_DEPTH = dify_config.CODE_MAX_DEPTH
MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH
MAX_STRING_ARRAY_LENGTH = dify_config.CODE_MAX_STRING_ARRAY_LENGTH
MAX_OBJECT_ARRAY_LENGTH = dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH
@ -23,7 +23,7 @@ MAX_NUMBER_ARRAY_LENGTH = dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH
class CodeNode(BaseNode):
_node_data_cls = CodeNodeData
node_type = NodeType.CODE
_node_type = NodeType.CODE
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
@ -48,8 +48,7 @@ class CodeNode(BaseNode):
:param variable_pool: variable pool
:return:
"""
node_data = self.node_data
node_data: CodeNodeData = cast(self._node_data_cls, node_data)
node_data = cast(CodeNodeData, self.node_data)
# Get code language
code_language = node_data.code_language
@ -68,7 +67,6 @@ class CodeNode(BaseNode):
language=code_language,
code=code,
inputs=variables,
dependencies=node_data.dependencies
)
# Transform result

View File

@ -3,7 +3,6 @@ from typing import Literal, Optional
from pydantic import BaseModel
from core.helper.code_executor.code_executor import CodeLanguage
from core.helper.code_executor.entities import CodeDependency
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
@ -16,8 +15,12 @@ class CodeNodeData(BaseNodeData):
type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]']
children: Optional[dict[str, 'Output']] = None
class Dependency(BaseModel):
name: str
version: str
variables: list[VariableSelector]
code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
code: str
outputs: dict[str, Output]
dependencies: Optional[list[CodeDependency]] = None
dependencies: Optional[list[Dependency]] = None

View File

@ -24,7 +24,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import (
LLMNodeChatModelMessage,
@ -94,7 +94,7 @@ class LLMNode(BaseNode):
# fetch prompt messages
prompt_messages, stop = self._fetch_prompt_messages(
node_data=node_data,
query=variable_pool.get_any(['sys', SystemVariable.QUERY.value])
query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value])
if node_data.memory else None,
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
inputs=inputs,
@ -113,7 +113,7 @@ class LLMNode(BaseNode):
}
# handle invoke result
result_text, usage = self._invoke_llm(
result_text, usage, finish_reason = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
@ -129,7 +129,8 @@ class LLMNode(BaseNode):
outputs = {
'text': result_text,
'usage': jsonable_encoder(usage)
'usage': jsonable_encoder(usage),
'finish_reason': finish_reason
}
return NodeRunResult(
@ -167,14 +168,14 @@ class LLMNode(BaseNode):
)
# handle invoke result
text, usage = self._handle_invoke_result(
text, usage, finish_reason = self._handle_invoke_result(
invoke_result=invoke_result
)
# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
return text, usage
return text, usage, finish_reason
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
"""
@ -186,6 +187,7 @@ class LLMNode(BaseNode):
prompt_messages = []
full_text = ''
usage = None
finish_reason = None
for result in invoke_result:
text = result.delta.message.content
full_text += text
@ -201,10 +203,13 @@ class LLMNode(BaseNode):
if not usage and result.delta.usage:
usage = result.delta.usage
if not finish_reason and result.delta.finish_reason:
finish_reason = result.delta.finish_reason
if not usage:
usage = LLMUsage.empty_usage()
return full_text, usage
return full_text, usage, finish_reason
def _transform_chat_messages(self,
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
@ -335,7 +340,7 @@ class LLMNode(BaseNode):
if not node_data.vision.enabled:
return []
files = variable_pool.get_any(['sys', SystemVariable.FILES.value])
files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value])
if not files:
return []
@ -500,7 +505,7 @@ class LLMNode(BaseNode):
return None
# get conversation id
conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value])
conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value])
if conversation_id is None:
return None
@ -672,10 +677,10 @@ class LLMNode(BaseNode):
variable_mapping['#context#'] = node_data.context.variable_selector
if node_data.vision.enabled:
variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]
variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value]
if node_data.memory:
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value]
if node_data.prompt_config:
enable_jinja = False

View File

@ -63,7 +63,7 @@ class QuestionClassifierNode(LLMNode):
)
# handle invoke result
result_text, usage = self._invoke_llm(
result_text, usage, finish_reason = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
@ -93,6 +93,7 @@ class QuestionClassifierNode(LLMNode):
prompt_messages=prompt_messages
),
'usage': jsonable_encoder(usage),
'finish_reason': finish_reason
}
outputs = {
'class_name': category_name

View File

@ -1,3 +1,7 @@
from collections.abc import Sequence
from pydantic import Field
from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData
@ -6,4 +10,4 @@ class StartNodeData(BaseNodeData):
"""
Start Node Data
"""
variables: list[VariableEntity] = []
variables: Sequence[VariableEntity] = Field(default_factory=list)

View File

@ -1,7 +1,7 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus
@ -17,16 +17,16 @@ class StartNode(BaseNode):
:param variable_pool: variable pool
:return:
"""
# Get cleaned inputs
cleaned_inputs = dict(variable_pool.user_inputs)
node_inputs = dict(variable_pool.user_inputs)
system_inputs = variable_pool.system_variables
for var in variable_pool.system_variables:
cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var]
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=cleaned_inputs,
outputs=cleaned_inputs
inputs=node_inputs,
outputs=node_inputs
)
@classmethod

View File

@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence
from os import path
from typing import Any, cast
from core.app.segments import ArrayAnyVariable, parser
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
@ -141,8 +141,8 @@ class ToolNode(BaseNode):
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
assert isinstance(variable, ArrayAnyVariable)
variable = variable_pool.get(['sys', SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):

View File

@ -1,109 +1,8 @@
from collections.abc import Sequence
from enum import Enum
from typing import Optional, cast
from .node import VariableAssignerNode
from .node_data import VariableAssignerData, WriteMode
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.segments import SegmentType, Variable, factory
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from extensions.ext_database import db
from models import ConversationVariable, WorkflowNodeExecutionStatus
class VariableAssignerNodeError(Exception):
pass
class WriteMode(str, Enum):
OVER_WRITE = 'over-write'
APPEND = 'append'
CLEAR = 'clear'
class VariableAssignerData(BaseNodeData):
title: str = 'Variable Assigner'
desc: Optional[str] = 'Assign a value to a variable'
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]
class VariableAssignerNode(BaseNode):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
data = cast(VariableAssignerData, self.node_data)
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = variable_pool.get(data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError('assigned variable not found')
match data.write_mode:
case WriteMode.OVER_WRITE:
income_value = variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_variable = original_variable.model_copy(update={'value': income_value.value})
case WriteMode.APPEND:
income_value = variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={'value': updated_value})
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
case _:
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
# Over write the variable.
variable_pool.add(data.assigned_variable_selector, updated_variable)
# Update conversation variable.
# TODO: Find a better way to use the database.
conversation_id = variable_pool.get(['sys', 'conversation_id'])
if not conversation_id:
raise VariableAssignerNodeError('conversation_id not found')
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
'value': income_value.to_object(),
},
)
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableAssignerNodeError('conversation variable not found in the database')
row.data = variable.model_dump_json()
session.commit()
def get_zero_value(t: SegmentType):
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return factory.build_segment([])
case SegmentType.OBJECT:
return factory.build_segment({})
case SegmentType.STRING:
return factory.build_segment('')
case SegmentType.NUMBER:
return factory.build_segment(0)
case _:
raise VariableAssignerNodeError(f'unsupported variable type: {t}')
__all__ = [
'VariableAssignerNode',
'VariableAssignerData',
'WriteMode',
]

View File

@ -0,0 +1,2 @@
class VariableAssignerNodeError(Exception):
pass

View File

@ -0,0 +1,92 @@
from typing import cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.segments import SegmentType, Variable, factory
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from extensions.ext_database import db
from models import ConversationVariable, WorkflowNodeExecutionStatus
from .exc import VariableAssignerNodeError
from .node_data import VariableAssignerData, WriteMode
class VariableAssignerNode(BaseNode):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
data = cast(VariableAssignerData, self.node_data)
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = variable_pool.get(data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError('assigned variable not found')
match data.write_mode:
case WriteMode.OVER_WRITE:
income_value = variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_variable = original_variable.model_copy(update={'value': income_value.value})
case WriteMode.APPEND:
income_value = variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={'value': updated_value})
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
case _:
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
# Over write the variable.
variable_pool.add(data.assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline.
# Update conversation variable.
conversation_id = variable_pool.get(['sys', 'conversation_id'])
if not conversation_id:
raise VariableAssignerNodeError('conversation_id not found')
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
'value': income_value.to_object(),
},
)
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableAssignerNodeError('conversation variable not found in the database')
row.data = variable.model_dump_json()
session.commit()
def get_zero_value(t: SegmentType):
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return factory.build_segment([])
case SegmentType.OBJECT:
return factory.build_segment({})
case SegmentType.STRING:
return factory.build_segment('')
case SegmentType.NUMBER:
return factory.build_segment(0)
case _:
raise VariableAssignerNodeError(f'unsupported variable type: {t}')

View File

@ -0,0 +1,19 @@
from collections.abc import Sequence
from enum import Enum
from typing import Optional
from core.workflow.entities.base_node_data_entities import BaseNodeData
class WriteMode(str, Enum):
OVER_WRITE = 'over-write'
APPEND = 'append'
CLEAR = 'clear'
class VariableAssignerData(BaseNodeData):
title: str = 'Variable Assigner'
desc: Optional[str] = 'Assign a value to a variable'
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]

View File

@ -28,6 +28,16 @@ class S3Storage(BaseStorage):
region_name=app_config.get("S3_REGION"),
config=Config(s3={"addressing_style": app_config.get("S3_ADDRESS_STYLE")}),
)
# create bucket
try:
self.client.head_bucket(Bucket=self.bucket_name)
except ClientError as e:
# if bucket not exists, create it
if e.response["Error"]["Code"] == "404":
self.client.create_bucket(Bucket=self.bucket_name)
else:
# other error, raise exception
raise
def save(self, filename, data):
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
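A standalone boto3 sketch of the same ensure-bucket pattern; the bucket name is a placeholder, and note that outside us-east-1 `create_bucket` also needs a `CreateBucketConfiguration` with a `LocationConstraint`, which this sketch omits:

import boto3
from botocore.exceptions import ClientError

def ensure_bucket(client, bucket_name: str) -> None:
    try:
        client.head_bucket(Bucket=bucket_name)
    except ClientError as e:
        if e.response["Error"]["Code"] == "404":  # bucket does not exist
            client.create_bucket(Bucket=bucket_name)
        else:
            raise  # permission or transport errors should surface

ensure_bucket(boto3.client("s3"), "my-dify-storage")  # placeholder bucket name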

View File

@ -150,6 +150,7 @@ conversation_with_summary_fields = {
"summary": fields.String(attribute="summary_or_query"),
"read_at": TimestampField,
"created_at": TimestampField,
"updated_at": TimestampField,
"annotated": fields.Boolean,
"model_config": fields.Nested(simple_model_config_fields),
"message_count": fields.Integer,

View File

@ -0,0 +1,28 @@
"""rename workflow__conversation_variables to workflow_conversation_variables
Revision ID: 2dbe42621d96
Revises: a6be81136580
Create Date: 2024-08-20 04:55:38.160010
"""
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = '2dbe42621d96'
down_revision = 'a6be81136580'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.rename_table('workflow__conversation_variables', 'workflow_conversation_variables')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.rename_table('workflow_conversation_variables', 'workflow__conversation_variables')
# ### end Alembic commands ###

View File

@ -1,4 +1,5 @@
import base64
import enum
import hashlib
import hmac
import json
@ -22,6 +23,11 @@ from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID
class DatasetPermissionEnum(str, enum.Enum):
ONLY_ME = 'only_me'
ALL_TEAM = 'all_team_members'
PARTIAL_TEAM = 'partial_members'
class Dataset(db.Model):
__tablename__ = 'datasets'
__table_args__ = (

View File

@ -1,5 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum
from typing import Any, Optional, Union
@ -110,19 +111,32 @@ class Workflow(db.Model):
db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'),
)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
app_id = db.Column(StringUUID, nullable=False)
type = db.Column(db.String(255), nullable=False)
version = db.Column(db.String(255), nullable=False)
graph = db.Column(db.Text)
features = db.Column(db.Text)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_by = db.Column(StringUUID)
updated_at = db.Column(db.DateTime)
_environment_variables = db.Column('environment_variables', db.Text, nullable=False, server_default='{}')
_conversation_variables = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}')
id: Mapped[str] = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
app_id: Mapped[str] = db.Column(StringUUID, nullable=False)
type: Mapped[str] = db.Column(db.String(255), nullable=False)
version: Mapped[str] = db.Column(db.String(255), nullable=False)
graph: Mapped[str] = db.Column(db.Text)
features: Mapped[str] = db.Column(db.Text)
created_by: Mapped[str] = db.Column(StringUUID, nullable=False)
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_by: Mapped[str] = db.Column(StringUUID)
updated_at: Mapped[datetime] = db.Column(db.DateTime)
_environment_variables: Mapped[str] = db.Column('environment_variables', db.Text, nullable=False, server_default='{}')
_conversation_variables: Mapped[str] = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}')
def __init__(self, *, tenant_id: str, app_id: str, type: str, version: str, graph: str,
features: str, created_by: str, environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable]):
self.tenant_id = tenant_id
self.app_id = app_id
self.type = type
self.version = version
self.graph = graph
self.features = features
self.created_by = created_by
self.environment_variables = environment_variables or []
self.conversation_variables = conversation_variables or []
@property
def created_by_account(self):
@ -724,7 +738,7 @@ class WorkflowAppLog(db.Model):
class ConversationVariable(db.Model):
__tablename__ = 'workflow__conversation_variables'
__tablename__ = 'workflow_conversation_variables'
id: Mapped[str] = db.Column(StringUUID, primary_key=True)
conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True)

api/poetry.lock (generated, 1556 changed lines): diff suppressed because it is too large.

View File

@ -153,7 +153,7 @@ langfuse = "^2.36.1"
langsmith = "^0.1.77"
mailchimp-transactional = "~1.0.50"
markdown = "~3.5.1"
novita-client = "^0.5.6"
novita-client = "^0.5.7"
numpy = "~1.26.4"
openai = "~1.29.0"
openpyxl = "~3.1.5"

View File

@ -111,6 +111,12 @@ class AppService:
'completion_params': {}
}
else:
provider, model = model_manager.get_default_provider_model_name(
tenant_id=account.current_tenant_id,
model_type=ModelType.LLM
)
default_model_config['model']['provider'] = provider
default_model_config['model']['name'] = model
default_model_dict = default_model_config['model']
default_model_config['model'] = json.dumps(default_model_dict)
@ -190,13 +196,14 @@ class AppService:
"""
Modified App class
"""
def __init__(self, app):
self.__dict__.update(app.__dict__)
@property
def app_model_config(self):
return model_config
app = ModifiedApp(app)
return app

View File

@ -1,6 +1,7 @@
from datetime import datetime, timezone
from typing import Optional, Union
from sqlalchemy import or_
from sqlalchemy import asc, desc, or_
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
@ -18,7 +19,8 @@ class ConversationService:
last_id: Optional[str], limit: int,
invoke_from: InvokeFrom,
include_ids: Optional[list] = None,
exclude_ids: Optional[list] = None) -> InfiniteScrollPagination:
exclude_ids: Optional[list] = None,
sort_by: str = '-updated_at') -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
@ -37,28 +39,28 @@ class ConversationService:
if exclude_ids is not None:
base_query = base_query.filter(~Conversation.id.in_(exclude_ids))
if last_id:
last_conversation = base_query.filter(
Conversation.id == last_id,
).first()
# define sort fields and directions
sort_field, sort_direction = cls._get_sort_params(sort_by)
if last_id:
last_conversation = base_query.filter(Conversation.id == last_id).first()
if not last_conversation:
raise LastConversationNotExistsError()
conversations = base_query.filter(
Conversation.created_at < last_conversation.created_at,
Conversation.id != last_conversation.id
).order_by(Conversation.created_at.desc()).limit(limit).all()
else:
conversations = base_query.order_by(Conversation.created_at.desc()).limit(limit).all()
# build filters based on sorting
filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation)
base_query = base_query.filter(filter_condition)
base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field)))
conversations = base_query.limit(limit).all()
has_more = False
if len(conversations) == limit:
current_page_first_conversation = conversations[-1]
rest_count = base_query.filter(
Conversation.created_at < current_page_first_conversation.created_at,
Conversation.id != current_page_first_conversation.id
).count()
current_page_last_conversation = conversations[-1]
rest_filter_condition = cls._build_filter_condition(sort_field, sort_direction,
current_page_last_conversation, is_next_page=True)
rest_count = base_query.filter(rest_filter_condition).count()
if rest_count > 0:
has_more = True
@ -69,6 +71,21 @@ class ConversationService:
has_more=has_more
)
@classmethod
def _get_sort_params(cls, sort_by: str) -> tuple[str, callable]:
if sort_by.startswith('-'):
return sort_by[1:], desc
return sort_by, asc
@classmethod
def _build_filter_condition(cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation,
is_next_page: bool = False):
field_value = getattr(reference_conversation, sort_field)
if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page):
return getattr(Conversation, sort_field) < field_value
else:
return getattr(Conversation, sort_field) > field_value
@classmethod
def rename(cls, app_model: App, conversation_id: str,
user: Optional[Union[Account, EndUser]], name: str, auto_generate: bool):
@ -78,6 +95,7 @@ class ConversationService:
return cls.auto_generate_name(app_model, conversation)
else:
conversation.name = name
conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
return conversation
@ -87,9 +105,9 @@ class ConversationService:
# get conversation first message
message = db.session.query(Message) \
.filter(
Message.app_id == app_model.id,
Message.conversation_id == conversation.id
).order_by(Message.created_at.asc()).first()
Message.app_id == app_model.id,
Message.conversation_id == conversation.id
).order_by(Message.created_at.asc()).first()
if not message:
raise MessageNotExistsError()
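sort_by now drives both the ORDER BY and the keyset (cursor) filter: _get_sort_params turns '-updated_at' into (updated_at, desc), and _build_filter_condition picks < or > so the page after the cursor row is fetched in the same direction. A minimal standalone sketch of that cursor logic on plain dicts, assuming DESC means newest-first; the service applies the same comparisons to SQLAlchemy columns:

from operator import gt, lt

def get_sort_params(sort_by: str):
    # '-updated_at' -> ('updated_at', 'desc'); 'created_at' -> ('created_at', 'asc')
    if sort_by.startswith('-'):
        return sort_by[1:], 'desc'
    return sort_by, 'asc'

def cursor_comparator(direction: str):
    # Paging forward through a DESC-ordered list fetches rows whose sort
    # value is strictly smaller than the cursor row's; ASC is the mirror.
    return lt if direction == 'desc' else gt

field, direction = get_sort_params('-updated_at')
compare = cursor_comparator(direction)
rows = [{'updated_at': t} for t in (5, 4, 3, 2, 1)]  # already sorted DESC
cursor = rows[1]  # last row of the previous page
next_page = [r for r in rows if compare(r[field], cursor[field])]
assert next_page == rows[2:]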


@ -27,6 +27,7 @@ from models.dataset import (
Dataset,
DatasetCollectionBinding,
DatasetPermission,
DatasetPermissionEnum,
DatasetProcessRule,
DatasetQuery,
Document,
@ -80,21 +81,21 @@ class DatasetService:
if permitted_dataset_ids:
query = query.filter(
db.or_(
Dataset.permission == 'all_team_members',
db.and_(Dataset.permission == 'only_me', Dataset.created_by == user.id),
db.and_(Dataset.permission == 'partial_members', Dataset.id.in_(permitted_dataset_ids))
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id),
db.and_(Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM, Dataset.id.in_(permitted_dataset_ids))
)
)
else:
query = query.filter(
db.or_(
Dataset.permission == 'all_team_members',
db.and_(Dataset.permission == 'only_me', Dataset.created_by == user.id)
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id)
)
)
else:
# if no user, only show datasets that are shared with all team members
query = query.filter(Dataset.permission == 'all_team_members')
query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
if search:
query = query.filter(Dataset.name.ilike(f'%{search}%'))
@ -330,7 +331,7 @@ class DatasetService:
raise NoPermissionError(
'You do not have permission to access this dataset.'
)
if dataset.permission == 'only_me' and dataset.created_by != user.id:
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
logging.debug(
f'User {user.id} does not have permission to access dataset {dataset.id}'
)
@ -351,11 +352,11 @@ class DatasetService:
@staticmethod
def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None):
if dataset.permission == 'only_me':
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
if dataset.created_by != user.id:
raise NoPermissionError('You do not have permission to access this dataset.')
elif dataset.permission == 'partial_members':
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
if not any(
dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
):
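The string literals 'only_me', 'all_team_members', and 'partial_members' are replaced with members of DatasetPermissionEnum throughout. A sketch of an enum consistent with those literals, assuming a str-based Enum so comparisons against stored database values still hold; the real definition lives in models.dataset:

from enum import Enum

class DatasetPermissionEnum(str, Enum):
    # Values match the literals the service previously compared against,
    # so existing rows keep working while call sites gain symbolic names.
    ONLY_ME = 'only_me'
    ALL_TEAM = 'all_team_members'
    PARTIAL_TEAM = 'partial_members'

assert DatasetPermissionEnum.ONLY_ME == 'only_me'  # str subclass keeps equality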


@ -30,6 +30,7 @@ class ModelProviderService:
"""
Model Provider Service
"""
def __init__(self) -> None:
self.provider_manager = ProviderManager()
@ -387,18 +388,21 @@ class ModelProviderService:
tenant_id=tenant_id,
model_type=model_type_enum
)
return DefaultModelResponse(
model=result.model,
model_type=result.model_type,
provider=SimpleProviderEntityResponse(
provider=result.provider.provider,
label=result.provider.label,
icon_small=result.provider.icon_small,
icon_large=result.provider.icon_large,
supported_model_types=result.provider.supported_model_types
)
) if result else None
try:
return DefaultModelResponse(
model=result.model,
model_type=result.model_type,
provider=SimpleProviderEntityResponse(
provider=result.provider.provider,
label=result.provider.label,
icon_small=result.provider.icon_small,
icon_large=result.provider.icon_large,
supported_model_types=result.provider.supported_model_types
)
) if result else None
except Exception as e:
logger.info(f"get_default_model_of_model_type error: {e}")
return None
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
"""


@ -1,6 +1,8 @@
import json
import logging
from configs import dify_config
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
@ -43,14 +45,14 @@ class BuiltinToolManageService:
result = []
for tool in tools:
result.append(ToolTransformService.tool_to_user_tool(
tool=tool,
credentials=credentials,
tool=tool,
credentials=credentials,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller)
))
return result
@staticmethod
def list_builtin_provider_credentials_schema(
provider_name
@ -78,7 +80,7 @@ class BuiltinToolManageService:
BuiltinToolProvider.provider == provider_name,
).first()
try:
try:
# get provider
provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials:
@ -119,8 +121,8 @@ class BuiltinToolManageService:
# delete cache
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' }
return {'result': 'success'}
@staticmethod
def get_builtin_tool_provider_credentials(
user_id: str, tenant_id: str, provider: str
@ -135,7 +137,7 @@ class BuiltinToolManageService:
if provider is None:
return {}
provider_controller = ToolManager.get_builtin_provider(provider.provider)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
@ -156,7 +158,7 @@ class BuiltinToolManageService:
if provider is None:
raise ValueError(f'you have not added provider {provider_name}')
db.session.delete(provider)
db.session.commit()
@ -165,8 +167,8 @@ class BuiltinToolManageService:
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' }
return {'result': 'success'}
@staticmethod
def get_builtin_tool_provider_icon(
provider: str
@ -179,7 +181,7 @@ class BuiltinToolManageService:
icon_bytes = f.read()
return icon_bytes, mime_type
@staticmethod
def list_builtin_tools(
user_id: str, tenant_id: str
@ -202,6 +204,15 @@ class BuiltinToolManageService:
for provider_controller in provider_controllers:
try:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider_controller,
name_func=lambda x: x.identity.name
):
continue
# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
@ -226,4 +237,3 @@ class BuiltinToolManageService:
raise e
return BuiltinToolProviderSort.sort(result)
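list_builtin_tools now consults is_filtered with POSITION_TOOL_INCLUDES_SET / POSITION_TOOL_EXCLUDES_SET so operators can control which tool providers are exposed. A minimal sketch of the contract that call implies; the actual helper lives in core.helper.position_helper and may differ in detail:

def is_filtered(include_set: set, exclude_set: set, data, name_func) -> bool:
    # Skip an item when an include list exists and omits it, or when the
    # exclude list names it. Empty sets mean "no restriction".
    name = name_func(data)
    if include_set and name not in include_set:
        return True
    if exclude_set and name in exclude_set:
        return True
    return False

assert is_filtered({'google'}, set(), object(), lambda _: 'bing') is True
assert is_filtered(set(), {'bing'}, object(), lambda _: 'bing') is True
assert is_filtered(set(), set(), object(), lambda _: 'bing') is False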


@ -13,7 +13,8 @@ class WebConversationService:
@classmethod
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
last_id: Optional[str], limit: int, invoke_from: InvokeFrom,
pinned: Optional[bool] = None) -> InfiniteScrollPagination:
pinned: Optional[bool] = None,
sort_by='-updated_at') -> InfiniteScrollPagination:
include_ids = None
exclude_ids = None
if pinned is not None:
@ -36,6 +37,7 @@ class WebConversationService:
invoke_from=invoke_from,
include_ids=include_ids,
exclude_ids=exclude_ids,
sort_by=sort_by
)
@classmethod


@ -32,12 +32,9 @@ class WorkflowConverter:
App Convert to Workflow Mode
"""
def convert_to_workflow(self, app_model: App,
account: Account,
name: str,
icon_type: str,
icon: str,
icon_background: str) -> App:
def convert_to_workflow(
self, app_model: App, account: Account, name: str, icon_type: str, icon: str, icon_background: str
):
"""
Convert app to workflow
@ -56,18 +53,18 @@ class WorkflowConverter:
:return: new App instance
"""
# convert app model config
if not app_model.app_model_config:
raise ValueError("App model config is required")
workflow = self.convert_app_model_config_to_workflow(
app_model=app_model,
app_model_config=app_model.app_model_config,
account_id=account.id
app_model=app_model, app_model_config=app_model.app_model_config, account_id=account.id
)
# create new app
new_app = App()
new_app.tenant_id = app_model.tenant_id
new_app.name = name if name else app_model.name + '(workflow)'
new_app.mode = AppMode.ADVANCED_CHAT.value \
if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
new_app.name = name if name else app_model.name + "(workflow)"
new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
new_app.icon_type = icon_type if icon_type else app_model.icon_type
new_app.icon = icon if icon else app_model.icon
new_app.icon_background = icon_background if icon_background else app_model.icon_background
@ -88,30 +85,21 @@ class WorkflowConverter:
return new_app
def convert_app_model_config_to_workflow(self, app_model: App,
app_model_config: AppModelConfig,
account_id: str) -> Workflow:
def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: AppModelConfig, account_id: str):
"""
Convert app model config to workflow mode
:param app_model: App instance
:param app_model_config: AppModelConfig instance
:param account_id: Account ID
:return:
"""
# get new app mode
new_app_mode = self._get_new_app_mode(app_model)
# convert app model config
app_config = self._convert_to_app_config(
app_model=app_model,
app_model_config=app_model_config
)
app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config)
# init workflow graph
graph = {
"nodes": [],
"edges": []
}
graph = {"nodes": [], "edges": []}
# Convert list:
# - variables -> start
@ -123,11 +111,9 @@ class WorkflowConverter:
# - show_retrieve_source -> knowledge-retrieval
# convert to start node
start_node = self._convert_to_start_node(
variables=app_config.variables
)
start_node = self._convert_to_start_node(variables=app_config.variables)
graph['nodes'].append(start_node)
graph["nodes"].append(start_node)
# convert to http request node
external_data_variable_node_mapping = {}
@ -135,7 +121,7 @@ class WorkflowConverter:
http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node(
app_model=app_model,
variables=app_config.variables,
external_data_variables=app_config.external_data_variables
external_data_variables=app_config.external_data_variables,
)
for http_request_node in http_request_nodes:
@ -144,9 +130,7 @@ class WorkflowConverter:
# convert to knowledge retrieval node
if app_config.dataset:
knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node(
new_app_mode=new_app_mode,
dataset_config=app_config.dataset,
model_config=app_config.model
new_app_mode=new_app_mode, dataset_config=app_config.dataset, model_config=app_config.model
)
if knowledge_retrieval_node:
@ -160,7 +144,7 @@ class WorkflowConverter:
model_config=app_config.model,
prompt_template=app_config.prompt_template,
file_upload=app_config.additional_features.file_upload,
external_data_variable_node_mapping=external_data_variable_node_mapping
external_data_variable_node_mapping=external_data_variable_node_mapping,
)
graph = self._append_node(graph, llm_node)
@ -199,7 +183,7 @@ class WorkflowConverter:
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type=WorkflowType.from_app_mode(new_app_mode).value,
version='draft',
version="draft",
graph=json.dumps(graph),
features=json.dumps(features),
created_by=account_id,
@ -212,24 +196,18 @@ class WorkflowConverter:
return workflow
def _convert_to_app_config(self, app_model: App,
app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
app_mode = AppMode.value_of(app_model.mode)
if app_mode == AppMode.AGENT_CHAT or app_model.is_agent:
app_model.mode = AppMode.AGENT_CHAT.value
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config
app_model=app_model, app_model_config=app_model_config
)
elif app_mode == AppMode.CHAT:
app_config = ChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config
)
app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
elif app_mode == AppMode.COMPLETION:
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config
app_model=app_model, app_model_config=app_model_config
)
else:
raise ValueError("Invalid app mode")
@ -248,14 +226,13 @@ class WorkflowConverter:
"data": {
"title": "START",
"type": NodeType.START.value,
"variables": [jsonable_encoder(v) for v in variables]
}
"variables": [jsonable_encoder(v) for v in variables],
},
}
def _convert_to_http_request_node(self, app_model: App,
variables: list[VariableEntity],
external_data_variables: list[ExternalDataVariableEntity]) \
-> tuple[list[dict], dict[str, str]]:
def _convert_to_http_request_node(
self, app_model: App, variables: list[VariableEntity], external_data_variables: list[ExternalDataVariableEntity]
) -> tuple[list[dict], dict[str, str]]:
"""
Convert API Based Extension to HTTP Request Node
:param app_model: App instance
@ -277,40 +254,33 @@ class WorkflowConverter:
# get params from config
api_based_extension_id = tool_config.get("api_based_extension_id")
if not api_based_extension_id:
continue
# get api_based_extension
api_based_extension = self._get_api_based_extension(
tenant_id=tenant_id,
api_based_extension_id=api_based_extension_id
tenant_id=tenant_id, api_based_extension_id=api_based_extension_id
)
if not api_based_extension:
raise ValueError("[External data tool] API query failed, variable: {}, "
"error: api_based_extension_id is invalid"
.format(tool_variable))
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=tenant_id,
token=api_based_extension.api_key
)
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=api_based_extension.api_key)
inputs = {}
for v in variables:
inputs[v.variable] = '{{#start.' + v.variable + '#}}'
inputs[v.variable] = "{{#start." + v.variable + "#}}"
request_body = {
'point': APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value,
'params': {
'app_id': app_model.id,
'tool_variable': tool_variable,
'inputs': inputs,
'query': '{{#sys.query#}}' if app_model.mode == AppMode.CHAT.value else ''
}
"point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value,
"params": {
"app_id": app_model.id,
"tool_variable": tool_variable,
"inputs": inputs,
"query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "",
},
}
request_body_json = json.dumps(request_body)
request_body_json = request_body_json.replace(r'\{\{', '{{').replace(r'\}\}', '}}')
request_body_json = request_body_json.replace(r"\{\{", "{{").replace(r"\}\}", "}}")
http_request_node = {
"id": f"http_request_{index}",
@ -320,20 +290,11 @@ class WorkflowConverter:
"type": NodeType.HTTP_REQUEST.value,
"method": "post",
"url": api_based_extension.api_endpoint,
"authorization": {
"type": "api-key",
"config": {
"type": "bearer",
"api_key": api_key
}
},
"authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}},
"headers": "",
"params": "",
"body": {
"type": "json",
"data": request_body_json
}
}
"body": {"type": "json", "data": request_body_json},
},
}
nodes.append(http_request_node)
@ -345,32 +306,24 @@ class WorkflowConverter:
"data": {
"title": f"Parse {api_based_extension.name} Response",
"type": NodeType.CODE.value,
"variables": [{
"variable": "response_json",
"value_selector": [http_request_node['id'], "body"]
}],
"variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}],
"code_language": "python3",
"code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads("
"response_json)\n return {\n \"result\": response_body[\"result\"]\n }",
"outputs": {
"result": {
"type": "string"
}
}
}
'response_json)\n return {\n "result": response_body["result"]\n }',
"outputs": {"result": {"type": "string"}},
},
}
nodes.append(code_node)
external_data_variable_node_mapping[external_data_variable.variable] = code_node['id']
external_data_variable_node_mapping[external_data_variable.variable] = code_node["id"]
index += 1
return nodes, external_data_variable_node_mapping
def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode,
dataset_config: DatasetEntity,
model_config: ModelConfigEntity) \
-> Optional[dict]:
def _convert_to_knowledge_retrieval_node(
self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity
) -> Optional[dict]:
"""
Convert datasets to Knowledge Retrieval Node
:param new_app_mode: new app mode
@ -404,7 +357,7 @@ class WorkflowConverter:
"completion_params": {
**model_config.parameters,
"stop": model_config.stop,
}
},
}
}
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
@ -412,20 +365,23 @@ class WorkflowConverter:
"multiple_retrieval_config": {
"top_k": retrieve_config.top_k,
"score_threshold": retrieve_config.score_threshold,
"reranking_model": retrieve_config.reranking_model
"reranking_model": retrieve_config.reranking_model,
}
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE
else None,
}
},
}
def _convert_to_llm_node(self, original_app_mode: AppMode,
new_app_mode: AppMode,
graph: dict,
model_config: ModelConfigEntity,
prompt_template: PromptTemplateEntity,
file_upload: Optional[FileExtraConfig] = None,
external_data_variable_node_mapping: dict[str, str] = None) -> dict:
def _convert_to_llm_node(
self,
original_app_mode: AppMode,
new_app_mode: AppMode,
graph: dict,
model_config: ModelConfigEntity,
prompt_template: PromptTemplateEntity,
file_upload: Optional[FileExtraConfig] = None,
external_data_variable_node_mapping: dict[str, str] | None = None,
) -> dict:
"""
Convert to LLM Node
:param original_app_mode: original app mode
@ -437,17 +393,18 @@ class WorkflowConverter:
:param external_data_variable_node_mapping: external data variable node mapping
"""
# fetch start and knowledge retrieval node
start_node = next(filter(lambda n: n['data']['type'] == NodeType.START.value, graph['nodes']))
knowledge_retrieval_node = next(filter(
lambda n: n['data']['type'] == NodeType.KNOWLEDGE_RETRIEVAL.value,
graph['nodes']
), None)
start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START.value, graph["nodes"]))
knowledge_retrieval_node = next(
filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL.value, graph["nodes"]), None
)
role_prefix = None
# Chat Model
if model_config.mode == LLMMode.CHAT.value:
if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
if not prompt_template.simple_prompt_template:
raise ValueError("Simple prompt template is required")
# get prompt template
prompt_transform = SimplePromptTransform()
prompt_template_config = prompt_transform.get_prompt_template(
@ -456,45 +413,35 @@ class WorkflowConverter:
model=model_config.model,
pre_prompt=prompt_template.simple_prompt_template,
has_context=knowledge_retrieval_node is not None,
query_in_prompt=False
query_in_prompt=False,
)
template = prompt_template_config['prompt_template'].template
template = prompt_template_config["prompt_template"].template
if not template:
prompts = []
else:
template = self._replace_template_variables(
template,
start_node['data']['variables'],
external_data_variable_node_mapping
template, start_node["data"]["variables"], external_data_variable_node_mapping
)
prompts = [
{
"role": 'user',
"text": template
}
]
prompts = [{"role": "user", "text": template}]
else:
advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template
prompts = []
for m in advanced_chat_prompt_template.messages:
if advanced_chat_prompt_template:
if advanced_chat_prompt_template:
for m in advanced_chat_prompt_template.messages:
text = m.text
text = self._replace_template_variables(
text,
start_node['data']['variables'],
external_data_variable_node_mapping
text, start_node["data"]["variables"], external_data_variable_node_mapping
)
prompts.append({
"role": m.role.value,
"text": text
})
prompts.append({"role": m.role.value, "text": text})
# Completion Model
else:
if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
if not prompt_template.simple_prompt_template:
raise ValueError("Simple prompt template is required")
# get prompt template
prompt_transform = SimplePromptTransform()
prompt_template_config = prompt_transform.get_prompt_template(
@ -503,57 +450,50 @@ class WorkflowConverter:
model=model_config.model,
pre_prompt=prompt_template.simple_prompt_template,
has_context=knowledge_retrieval_node is not None,
query_in_prompt=False
query_in_prompt=False,
)
template = prompt_template_config['prompt_template'].template
template = prompt_template_config["prompt_template"].template
template = self._replace_template_variables(
template,
start_node['data']['variables'],
external_data_variable_node_mapping
template=template,
variables=start_node["data"]["variables"],
external_data_variable_node_mapping=external_data_variable_node_mapping,
)
prompts = {
"text": template
}
prompts = {"text": template}
prompt_rules = prompt_template_config['prompt_rules']
prompt_rules = prompt_template_config["prompt_rules"]
role_prefix = {
"user": prompt_rules.get('human_prefix', 'Human'),
"assistant": prompt_rules.get('assistant_prefix', 'Assistant')
"user": prompt_rules.get("human_prefix", "Human"),
"assistant": prompt_rules.get("assistant_prefix", "Assistant"),
}
else:
advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template
if advanced_completion_prompt_template:
text = advanced_completion_prompt_template.prompt
text = self._replace_template_variables(
text,
start_node['data']['variables'],
external_data_variable_node_mapping
template=text,
variables=start_node["data"]["variables"],
external_data_variable_node_mapping=external_data_variable_node_mapping,
)
else:
text = ""
text = text.replace('{{#query#}}', '{{#sys.query#}}')
text = text.replace("{{#query#}}", "{{#sys.query#}}")
prompts = {
"text": text,
}
if advanced_completion_prompt_template.role_prefix:
if advanced_completion_prompt_template and advanced_completion_prompt_template.role_prefix:
role_prefix = {
"user": advanced_completion_prompt_template.role_prefix.user,
"assistant": advanced_completion_prompt_template.role_prefix.assistant
"assistant": advanced_completion_prompt_template.role_prefix.assistant,
}
memory = None
if new_app_mode == AppMode.ADVANCED_CHAT:
memory = {
"role_prefix": role_prefix,
"window": {
"enabled": False
}
}
memory = {"role_prefix": role_prefix, "window": {"enabled": False}}
completion_params = model_config.parameters
completion_params.update({"stop": model_config.stop})
@ -567,28 +507,29 @@ class WorkflowConverter:
"provider": model_config.provider,
"name": model_config.model,
"mode": model_config.mode,
"completion_params": completion_params
"completion_params": completion_params,
},
"prompt_template": prompts,
"memory": memory,
"context": {
"enabled": knowledge_retrieval_node is not None,
"variable_selector": ["knowledge_retrieval", "result"]
if knowledge_retrieval_node is not None else None
if knowledge_retrieval_node is not None
else None,
},
"vision": {
"enabled": file_upload is not None,
"variable_selector": ["sys", "files"] if file_upload is not None else None,
"configs": {
"detail": file_upload.image_config['detail']
} if file_upload is not None else None
}
}
"configs": {"detail": file_upload.image_config["detail"]}
if file_upload is not None and file_upload.image_config is not None
else None,
},
},
}
def _replace_template_variables(self, template: str,
variables: list[dict],
external_data_variable_node_mapping: dict[str, str] = None) -> str:
def _replace_template_variables(
self, template: str, variables: list[dict], external_data_variable_node_mapping: dict[str, str] | None = None
) -> str:
"""
Replace Template Variables
:param template: template
@ -597,12 +538,11 @@ class WorkflowConverter:
:return:
"""
for v in variables:
template = template.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}')
template = template.replace("{{" + v["variable"] + "}}", "{{#start." + v["variable"] + "#}}")
if external_data_variable_node_mapping:
for variable, code_node_id in external_data_variable_node_mapping.items():
template = template.replace('{{' + variable + '}}',
'{{#' + code_node_id + '.result#}}')
template = template.replace("{{" + variable + "}}", "{{#" + code_node_id + ".result#}}")
return template
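_replace_template_variables is the heart of the conversion: EasyUI '{{var}}' placeholders become workflow selectors, so '{{name}}' turns into '{{#start.name#}}', and variables backed by an external data tool point at the code node that parses the extension response. Restated standalone with a worked example; the node id 'code_0' is hypothetical:

def replace_template_variables(template, variables, node_mapping=None):
    # Plain inputs resolve against the start node...
    for v in variables:
        template = template.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}')
    # ...external-data variables resolve against their parsing code node.
    for variable, code_node_id in (node_mapping or {}).items():
        template = template.replace('{{' + variable + '}}', '{{#' + code_node_id + '.result#}}')
    return template

out = replace_template_variables(
    'Hi {{name}}, the weather is {{weather}}.',
    [{'variable': 'name'}],
    {'weather': 'code_0'},
)
assert out == 'Hi {{#start.name#}}, the weather is {{#code_0.result#}}.'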
@ -618,11 +558,8 @@ class WorkflowConverter:
"data": {
"title": "END",
"type": NodeType.END.value,
"outputs": [{
"variable": "result",
"value_selector": ["llm", "text"]
}]
}
"outputs": [{"variable": "result", "value_selector": ["llm", "text"]}],
},
}
def _convert_to_answer_node(self) -> dict:
@ -634,11 +571,7 @@ class WorkflowConverter:
return {
"id": "answer",
"position": None,
"data": {
"title": "ANSWER",
"type": NodeType.ANSWER.value,
"answer": "{{#llm.text#}}"
}
"data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"},
}
def _create_edge(self, source: str, target: str) -> dict:
@ -648,11 +581,7 @@ class WorkflowConverter:
:param target: target node id
:return:
"""
return {
"id": f"{source}-{target}",
"source": source,
"target": target
}
return {"id": f"{source}-{target}", "source": source, "target": target}
def _append_node(self, graph: dict, node: dict) -> dict:
"""
@ -662,9 +591,9 @@ class WorkflowConverter:
:param node: Node to append
:return:
"""
previous_node = graph['nodes'][-1]
graph['nodes'].append(node)
graph['edges'].append(self._create_edge(previous_node['id'], node['id']))
previous_node = graph["nodes"][-1]
graph["nodes"].append(node)
graph["edges"].append(self._create_edge(previous_node["id"], node["id"]))
return graph
def _get_new_app_mode(self, app_model: App) -> AppMode:
@ -678,14 +607,20 @@ class WorkflowConverter:
else:
return AppMode.ADVANCED_CHAT
def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str):
"""
Get API Based Extension
:param tenant_id: tenant id
:param api_based_extension_id: api based extension id
:return:
"""
return db.session.query(APIBasedExtension).filter(
APIBasedExtension.tenant_id == tenant_id,
APIBasedExtension.id == api_based_extension_id
).first()
api_based_extension = (
db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
)
if not api_based_extension:
raise ValueError(f"API Based Extension not found, id: {api_based_extension_id}")
return api_based_extension


@ -5,7 +5,7 @@ from core.model_runtime.entities.text_embedding_entities import TextEmbeddingRes
from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import WenxinTextEmbeddingModel
def test_invoke_embedding_model():
def test_invoke_embedding_v1():
sleep(3)
model = WenxinTextEmbeddingModel()
@ -21,4 +21,61 @@ def test_invoke_embedding_model():
assert isinstance(response, TextEmbeddingResult)
assert len(response.embeddings) == 3
assert isinstance(response.embeddings[0], list)
assert isinstance(response.embeddings[0], list)
def test_invoke_embedding_bge_large_en():
sleep(3)
model = WenxinTextEmbeddingModel()
response = model.invoke(
model='bge-large-en',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
texts=['hello', '你好', 'xxxxx'],
user="abc-123"
)
assert isinstance(response, TextEmbeddingResult)
assert len(response.embeddings) == 3
assert isinstance(response.embeddings[0], list)
def test_invoke_embedding_bge_large_zh():
sleep(3)
model = WenxinTextEmbeddingModel()
response = model.invoke(
model='bge-large-zh',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
texts=['hello', '你好', 'xxxxx'],
user="abc-123"
)
assert isinstance(response, TextEmbeddingResult)
assert len(response.embeddings) == 3
assert isinstance(response.embeddings[0], list)
def test_invoke_embedding_tao_8k():
sleep(3)
model = WenxinTextEmbeddingModel()
response = model.invoke(
model='tao-8k',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
texts=['hello', '你好', 'xxxxx'],
user="abc-123"
)
assert isinstance(response, TextEmbeddingResult)
assert len(response.embeddings) == 3
assert isinstance(response.embeddings[0], list)
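The three new Wenxin embedding tests differ from the first only in the model name. A parametrized equivalent, as a sketch, assuming the renamed v1 test targets a model called 'embedding-v1' (the hunk above does not show that name) and reusing the file's imports:

import os
from time import sleep

import pytest

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import WenxinTextEmbeddingModel

@pytest.mark.parametrize('model', ['embedding-v1', 'bge-large-en', 'bge-large-zh', 'tao-8k'])
def test_invoke_embedding(model):
    sleep(3)  # stay under the provider's rate limit between cases
    response = WenxinTextEmbeddingModel().invoke(
        model=model,
        credentials={
            'api_key': os.environ.get('WENXIN_API_KEY'),
            'secret_key': os.environ.get('WENXIN_SECRET_KEY'),
        },
        texts=['hello', '你好', 'xxxxx'],
        user='abc-123',
    )
    assert isinstance(response, TextEmbeddingResult)
    assert len(response.embeddings) == 3
    assert isinstance(response.embeddings[0], list)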


@ -21,10 +21,6 @@ class PGVectorTest(AbstractVectorTest):
),
)
def search_by_full_text(self):
hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
def test_pgvector(setup_mock_redis):
PGVectorTest().run_all_tests()


@ -6,14 +6,13 @@ from _pytest.monkeypatch import MonkeyPatch
from jinja2 import Template
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
from core.helper.code_executor.entities import CodeDependency
MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
class MockedCodeExecutor:
@classmethod
def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'],
code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
code: str, inputs: dict) -> dict:
# invoke directly
match language:
case CodeLanguage.PYTHON3:
@ -24,6 +23,8 @@ class MockedCodeExecutor:
return {
"result": Template(code).render(inputs)
}
case _:
raise Exception("Language not supported")
@pytest.fixture
def setup_code_executor_mock(request, monkeypatch: MonkeyPatch):


@ -28,14 +28,6 @@ def test_javascript_with_code_template():
inputs={'arg1': 'Hello', 'arg2': 'World'})
assert result == {'result': 'HelloWorld'}
def test_javascript_list_default_available_packages():
packages = JavascriptCodeProvider.get_default_available_packages()
# no default packages available for javascript
assert len(packages) == 0
def test_javascript_get_runner_script():
runner_script = NodeJsTemplateTransformer.get_runner_script()
assert runner_script.count(NodeJsTemplateTransformer._code_placeholder) == 1


@ -29,15 +29,6 @@ def test_python3_with_code_template():
assert result == {'result': 'HelloWorld'}
def test_python3_list_default_available_packages():
packages = Python3CodeProvider.get_default_available_packages()
assert len(packages) > 0
assert {'requests', 'httpx'}.issubset(p['name'] for p in packages)
# check JSON serializable
assert len(str(json.dumps(packages))) > 0
def test_python3_get_runner_script():
runner_script = Python3TemplateTransformer.get_runner_script()
assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1


@ -11,7 +11,7 @@ from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import ModelProviderFactory
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.llm.llm_node import LLMNode
from extensions.ext_database import db
@ -66,10 +66,10 @@ def test_execute_llm(setup_openai_mock):
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather today?',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather today?',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['abc', 'output'], 'sunny')
@ -181,10 +181,10 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather today?',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather today?',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['abc', 'output'], 'sunny')
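Every VariablePool construction in the tests swaps SystemVariable for SystemVariableKey, matching the rename in core.workflow.enums. A sketch of what the renamed enum plausibly looks like, with values inferred from usage here ('user_id' is confirmed later in this diff by the SystemVariableKey('user_id') call); the real definition is in api/core/workflow/enums.py:

from enum import Enum

class SystemVariableKey(str, Enum):
    # Keys for entries in VariablePool.system_variables.
    QUERY = 'query'
    FILES = 'files'
    CONVERSATION_ID = 'conversation_id'
    USER_ID = 'user_id'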


@ -13,7 +13,7 @@ from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from extensions.ext_database import db
@ -119,10 +119,10 @@ def test_function_calling_parameter_extractor(setup_openai_mock):
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
result = node.run(pool)
@ -177,10 +177,10 @@ def test_instructions(setup_openai_mock):
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
result = node.run(pool)
@ -243,10 +243,10 @@ def test_chat_parameter_extractor(setup_anthropic_mock):
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
result = node.run(pool)
@ -307,10 +307,10 @@ def test_completion_parameter_extractor(setup_openai_mock):
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
result = node.run(pool)
@ -420,10 +420,10 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
result = node.run(pool)


@ -1,13 +1,13 @@
from core.app.segments import SecretVariable, StringSegment, parser
from core.helper import encrypter
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
def test_segment_group_to_text():
variable_pool = VariablePool(
system_variables={
SystemVariable('user_id'): 'fake-user-id',
SystemVariableKey('user_id'): 'fake-user-id',
},
user_inputs={},
environment_variables=[
@ -42,7 +42,7 @@ def test_convert_constant_to_segment_group():
def test_convert_variable_to_segment_group():
variable_pool = VariablePool(
system_variables={
SystemVariable('user_id'): 'fake-user-id',
SystemVariableKey('user_id'): 'fake-user-id',
},
user_inputs={},
environment_variables=[],


@ -2,7 +2,7 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base_node import UserFrom
from extensions.ext_database import db
@ -29,8 +29,8 @@ def test_execute_answer():
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'weather'], 'sunny')
pool.add(['llm', 'text'], 'You are a helpful AI.')


@ -2,7 +2,7 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from extensions.ext_database import db
@ -119,8 +119,8 @@ def test_execute_if_else_result_true():
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'array_contains'], ['ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ac', 'def'])
@ -182,8 +182,8 @@ def test_execute_if_else_result_false():
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'array_contains'], ['1ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ab', 'def'])


@ -4,7 +4,7 @@ from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.segments import ArrayStringVariable, StringVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
@ -42,7 +42,7 @@ def test_overwrite_string_variable():
)
variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@ -52,7 +52,7 @@ def test_overwrite_string_variable():
input_variable,
)
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
node.run(variable_pool)
mock_run.assert_called_once()
@ -93,7 +93,7 @@ def test_append_variable_to_array():
)
variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@ -103,7 +103,7 @@ def test_append_variable_to_array():
input_variable,
)
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
node.run(variable_pool)
mock_run.assert_called_once()
@ -137,7 +137,7 @@ def test_clear_array():
)
variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],


@ -11,7 +11,17 @@ def test_environment_variables():
contexts.tenant_id.set('tenant_id')
# Create a Workflow instance
workflow = Workflow()
workflow = Workflow(
tenant_id='tenant_id',
app_id='app_id',
type='workflow',
version='draft',
graph='{}',
features='{}',
created_by='account_id',
environment_variables=[],
conversation_variables=[],
)
# Create some EnvironmentVariable instances
variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())})
@ -35,7 +45,17 @@ def test_update_environment_variables():
contexts.tenant_id.set('tenant_id')
# Create a Workflow instance
workflow = Workflow()
workflow = Workflow(
tenant_id='tenant_id',
app_id='app_id',
type='workflow',
version='draft',
graph='{}',
features='{}',
created_by='account_id',
environment_variables=[],
conversation_variables=[],
)
# Create some EnvironmentVariable instances
variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())})
@ -70,9 +90,17 @@ def test_to_dict():
contexts.tenant_id.set('tenant_id')
# Create a Workflow instance
workflow = Workflow()
workflow.graph = '{}'
workflow.features = '{}'
workflow = Workflow(
tenant_id='tenant_id',
app_id='app_id',
type='workflow',
version='draft',
graph='{}',
features='{}',
created_by='account_id',
environment_variables=[],
conversation_variables=[],
)
# Create some EnvironmentVariable instances
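These test updates follow from the keyword-only __init__ added to Workflow earlier in this diff: a bare Workflow() now fails at construction time instead of yielding a half-initialized row. A hypothetical extra test making that explicit, as a sketch:

import pytest

def test_workflow_requires_kwargs():
    # The constructor takes required keyword-only arguments, so calling it
    # bare fails fast; the message text varies by Python version.
    with pytest.raises(TypeError):
        Workflow()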

Some files were not shown because too many files have changed in this diff.