mirror of
https://github.com/langgenius/dify.git
synced 2026-01-20 12:09:27 +08:00
Compare commits
47 Commits
fix/note-n
...
feat/updat
| Author | SHA1 | Date | |
|---|---|---|---|
| 61b54c4b0e | |||
| fef4e09dfc | |||
| 60ef7ba855 | |||
| 6f968bafb2 | |||
| 9f6aab11d4 | |||
| 0006c6f0fd | |||
| 2c427e04be | |||
| f53454f81d | |||
| 784b11ce19 | |||
| 715eb8fa32 | |||
| a02118d5bc | |||
| 85fc0fdb51 | |||
| f7af8c7cc7 | |||
| 0c99a3d0c5 | |||
| 66dfb5c89a | |||
| 6435b4eb44 | |||
| 4e7b6aec3a | |||
| 6c25d7bed3 | |||
| 028fd52c9b | |||
| 9a715f6b68 | |||
| 8c32f8c77d | |||
| b7778de224 | |||
| c70d69322b | |||
| e35e251863 | |||
| eae53e11e6 | |||
| 4f5f27cf2b | |||
| 5e42e90abc | |||
| a10b207de2 | |||
| e2d214e030 | |||
| 4f64a5d36d | |||
| 0d4753785f | |||
| 2e9084f369 | |||
| 0f90e6df75 | |||
| 53146ad685 | |||
| 0223fc6fd5 | |||
| 218380ba43 | |||
| afd23f7ad8 | |||
| 6991a243aa | |||
| 1f944c6eeb | |||
| 31f9977411 | |||
| 3d27d15f00 | |||
| ab6499e5b7 | |||
| 4ff4859036 | |||
| 53cf756207 | |||
| 0087afc2e3 | |||
| bd07e1d2fd | |||
| 8b06105fa1 |
1
.gitignore
vendored
1
.gitignore
vendored
@ -178,3 +178,4 @@ pyrightconfig.json
|
||||
api/.vscode
|
||||
|
||||
.idea/
|
||||
.vscode
|
||||
73
README.md
73
README.md
@ -4,7 +4,7 @@
|
||||
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
|
||||
<a href="https://docs.dify.ai/getting-started/install-self-hosted">Self-hosting</a> ·
|
||||
<a href="https://docs.dify.ai">Documentation</a> ·
|
||||
<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">Enterprise inquiry</a>
|
||||
<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">Enterprise Inquiry</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
@ -41,41 +41,36 @@
|
||||
<a href="./README_VI.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
|
||||
</p>
|
||||
|
||||
|
||||
Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features:
|
||||
Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, Agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features:
|
||||
</br> </br>
|
||||
|
||||
**1. Workflow**:
|
||||
Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond.
|
||||
**1. Workflow**:
|
||||
Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond.
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
|
||||
|
||||
**2. Comprehensive model support**:
|
||||
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
**2. Comprehensive model support**:
|
||||
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||

|
||||
|
||||
**3. Prompt IDE**:
|
||||
Intuitive interface for crafting prompts, comparing model performance, and adding additional features such as text-to-speech to a chat-based app.
|
||||
|
||||
**3. Prompt IDE**:
|
||||
Intuitive interface for crafting prompts, comparing model performance, and adding additional features such as text-to-speech to a chat-based app.
|
||||
**4. RAG Pipeline**:
|
||||
Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats.
|
||||
|
||||
**4. RAG Pipeline**:
|
||||
Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats.
|
||||
**5. Agent capabilities**:
|
||||
You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DALL·E, Stable Diffusion and WolframAlpha.
|
||||
|
||||
**5. Agent capabilities**:
|
||||
You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DALL·E, Stable Diffusion and WolframAlpha.
|
||||
|
||||
**6. LLMOps**:
|
||||
Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations.
|
||||
|
||||
**7. Backend-as-a-Service**:
|
||||
All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic.
|
||||
**6. LLMOps**:
|
||||
Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations.
|
||||
|
||||
**7. Backend-as-a-Service**:
|
||||
All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic.
|
||||
|
||||
## Feature comparison
|
||||
|
||||
<table style="width: 100%;">
|
||||
<tr>
|
||||
<th align="center">Feature</th>
|
||||
@ -145,30 +140,28 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
|
||||
## Using Dify
|
||||
|
||||
- **Cloud </br>**
|
||||
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
|
||||
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
|
||||
|
||||
- **Self-hosting Dify Community Edition</br>**
|
||||
Quickly get Dify running in your environment with this [starter guide](#quick-start).
|
||||
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
|
||||
Quickly get Dify running in your environment with this [starter guide](#quick-start).
|
||||
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
|
||||
|
||||
- **Dify for enterprise / organizations</br>**
|
||||
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs. </br>
|
||||
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs. </br>
|
||||
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
|
||||
|
||||
|
||||
## Staying ahead
|
||||
|
||||
Star Dify on GitHub and be instantly notified of new releases.
|
||||
|
||||

|
||||
|
||||
|
||||
|
||||
## Quick start
|
||||
|
||||
> Before installing Dify, make sure your machine meets the following minimum system requirements:
|
||||
>
|
||||
>- CPU >= 2 Core
|
||||
>- RAM >= 4GB
|
||||
>
|
||||
> - CPU >= 2 Core
|
||||
> - RAM >= 4GB
|
||||
|
||||
</br>
|
||||
|
||||
@ -197,15 +190,16 @@ If you'd like to configure a highly-available setup, there are community-contrib
|
||||
#### Using Terraform for Deployment
|
||||
|
||||
##### Azure Global
|
||||
|
||||
Deploy Dify to Azure with a single click using [terraform](https://www.terraform.io/).
|
||||
|
||||
- [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform)
|
||||
|
||||
## Contributing
|
||||
|
||||
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
At the same time, please consider supporting Dify by sharing it on social media and at events and conferences.
|
||||
|
||||
|
||||
> We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).
|
||||
|
||||
**Contributors**
|
||||
@ -216,16 +210,15 @@ At the same time, please consider supporting Dify by sharing it on social media
|
||||
|
||||
## Community & contact
|
||||
|
||||
* [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions.
|
||||
* [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
* [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.
|
||||
* [Twitter](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community.
|
||||
- [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions.
|
||||
- [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
- [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.
|
||||
- [Twitter](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community.
|
||||
|
||||
## Star history
|
||||
|
||||
[](https://star-history.com/#langgenius/dify&Date)
|
||||
|
||||
|
||||
## Security disclosure
|
||||
|
||||
To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer.
|
||||
|
||||
@ -247,8 +247,8 @@ API_TOOL_DEFAULT_READ_TIMEOUT=60
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT=600
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT=600
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 # 10MB
|
||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 # 1MB
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
|
||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
|
||||
|
||||
# Log file path
|
||||
LOG_FILE=
|
||||
@ -267,4 +267,13 @@ APP_MAX_ACTIVE_REQUESTS=0
|
||||
|
||||
|
||||
# Celery beat configuration
|
||||
CELERY_BEAT_SCHEDULER_TIME=1
|
||||
CELERY_BEAT_SCHEDULER_TIME=1
|
||||
|
||||
# Position configuration
|
||||
POSITION_TOOL_PINS=
|
||||
POSITION_TOOL_INCLUDES=
|
||||
POSITION_TOOL_EXCLUDES=
|
||||
|
||||
POSITION_PROVIDER_PINS=
|
||||
POSITION_PROVIDER_INCLUDES=
|
||||
POSITION_PROVIDER_EXCLUDES=
|
||||
|
||||
0
.idea/icon.png → api/.idea/icon.png
generated
0
.idea/icon.png → api/.idea/icon.png
generated
|
Before Width: | Height: | Size: 1.7 KiB After Width: | Height: | Size: 1.7 KiB |
0
.idea/vcs.xml → api/.idea/vcs.xml
generated
0
.idea/vcs.xml → api/.idea/vcs.xml
generated
@ -5,8 +5,8 @@
|
||||
"name": "Python: Flask",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"python": "${workspaceFolder}/api/.venv/bin/python",
|
||||
"cwd": "${workspaceFolder}/api",
|
||||
"python": "${workspaceFolder}/.venv/bin/python",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"envFile": ".env",
|
||||
"module": "flask",
|
||||
"justMyCode": true,
|
||||
@ -18,15 +18,15 @@
|
||||
"args": [
|
||||
"run",
|
||||
"--host=0.0.0.0",
|
||||
"--port=5001",
|
||||
"--port=5001"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Python: Celery",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"python": "${workspaceFolder}/api/.venv/bin/python",
|
||||
"cwd": "${workspaceFolder}/api",
|
||||
"python": "${workspaceFolder}/.venv/bin/python",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"module": "celery",
|
||||
"justMyCode": true,
|
||||
"envFile": ".env",
|
||||
@ -37,6 +37,8 @@ class DifyConfig(
|
||||
|
||||
CODE_MAX_NUMBER: int = 9223372036854775807
|
||||
CODE_MIN_NUMBER: int = -9223372036854775808
|
||||
CODE_MAX_DEPTH: int = 5
|
||||
CODE_MAX_PRECISION: int = 20
|
||||
CODE_MAX_STRING_LENGTH: int = 80000
|
||||
CODE_MAX_STRING_ARRAY_LENGTH: int = 30
|
||||
CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30
|
||||
|
||||
@ -406,6 +406,7 @@ class DataSetConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceConfig(BaseSettings):
|
||||
"""
|
||||
Workspace configs
|
||||
@ -442,6 +443,63 @@ class CeleryBeatConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class PositionConfig(BaseSettings):
|
||||
|
||||
POSITION_PROVIDER_PINS: str = Field(
|
||||
description='The heads of model providers',
|
||||
default='',
|
||||
)
|
||||
|
||||
POSITION_PROVIDER_INCLUDES: str = Field(
|
||||
description='The included model providers',
|
||||
default='',
|
||||
)
|
||||
|
||||
POSITION_PROVIDER_EXCLUDES: str = Field(
|
||||
description='The excluded model providers',
|
||||
default='',
|
||||
)
|
||||
|
||||
POSITION_TOOL_PINS: str = Field(
|
||||
description='The heads of tools',
|
||||
default='',
|
||||
)
|
||||
|
||||
POSITION_TOOL_INCLUDES: str = Field(
|
||||
description='The included tools',
|
||||
default='',
|
||||
)
|
||||
|
||||
POSITION_TOOL_EXCLUDES: str = Field(
|
||||
description='The excluded tools',
|
||||
default='',
|
||||
)
|
||||
|
||||
@computed_field
|
||||
def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
|
||||
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != '']
|
||||
|
||||
@computed_field
|
||||
def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
|
||||
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''}
|
||||
|
||||
@computed_field
|
||||
def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
|
||||
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''}
|
||||
|
||||
@computed_field
|
||||
def POSITION_TOOL_PINS_LIST(self) -> list[str]:
|
||||
return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != '']
|
||||
|
||||
@computed_field
|
||||
def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
|
||||
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''}
|
||||
|
||||
@computed_field
|
||||
def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
|
||||
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''}
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
@ -466,6 +524,7 @@ class FeatureConfig(
|
||||
UpdateConfig,
|
||||
WorkflowConfig,
|
||||
WorkspaceConfig,
|
||||
PositionConfig,
|
||||
|
||||
# hosted services config
|
||||
HostedServiceConfig,
|
||||
|
||||
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description='Dify version',
|
||||
default='0.7.0',
|
||||
default='0.7.1',
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
||||
@ -154,6 +154,8 @@ class ChatConversationApi(Resource):
|
||||
parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args')
|
||||
parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
|
||||
required=False, default='-updated_at', location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
subquery = (
|
||||
@ -225,7 +227,17 @@ class ChatConversationApi(Resource):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
|
||||
|
||||
query = query.order_by(Conversation.created_at.desc())
|
||||
match args['sort_by']:
|
||||
case 'created_at':
|
||||
query = query.order_by(Conversation.created_at.asc())
|
||||
case '-created_at':
|
||||
query = query.order_by(Conversation.created_at.desc())
|
||||
case 'updated_at':
|
||||
query = query.order_by(Conversation.updated_at.asc())
|
||||
case '-updated_at':
|
||||
query = query.order_by(Conversation.updated_at.desc())
|
||||
case _:
|
||||
query = query.order_by(Conversation.created_at.desc())
|
||||
|
||||
conversations = db.paginate(
|
||||
query,
|
||||
|
||||
@ -24,7 +24,7 @@ from fields.app_fields import related_app_list
|
||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||
from fields.document_fields import document_status_fields
|
||||
from libs.login import login_required
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
|
||||
from models.model import ApiToken, UploadFile
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
@ -202,7 +202,7 @@ class DatasetApi(Resource):
|
||||
nullable=True,
|
||||
help='Invalid indexing technique.')
|
||||
parser.add_argument('permission', type=str, location='json', choices=(
|
||||
'only_me', 'all_team_members', 'partial_members'), help='Invalid permission.'
|
||||
DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.'
|
||||
)
|
||||
parser.add_argument('embedding_model', type=str,
|
||||
location='json', help='Invalid embedding model.')
|
||||
@ -239,7 +239,7 @@ class DatasetApi(Resource):
|
||||
tenant_id, dataset_id_str, data.get('partial_member_list')
|
||||
)
|
||||
# clear partial member list when permission is only_me or all_team_members
|
||||
elif data.get('permission') == 'only_me' or data.get('permission') == 'all_team_members':
|
||||
elif data.get('permission') == DatasetPermissionEnum.ONLY_ME or data.get('permission') == DatasetPermissionEnum.ALL_TEAM:
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
|
||||
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
@ -573,13 +573,13 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, vector_type):
|
||||
match vector_type:
|
||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
|
||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value
|
||||
]
|
||||
}
|
||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
|
||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
|
||||
@ -25,6 +25,8 @@ class ConversationApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
|
||||
required=False, default='-updated_at', location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -33,7 +35,8 @@ class ConversationApi(Resource):
|
||||
user=end_user,
|
||||
last_id=args['last_id'],
|
||||
limit=args['limit'],
|
||||
invoke_from=InvokeFrom.SERVICE_API
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
sort_by=args['sort_by']
|
||||
)
|
||||
except services.errors.conversation.LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
@ -10,7 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
|
||||
@ -78,6 +78,8 @@ class DatasetListApi(DatasetApiResource):
|
||||
parser.add_argument('indexing_technique', type=str, location='json',
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
help='Invalid indexing technique.')
|
||||
parser.add_argument('permission', type=str, location='json', choices=(
|
||||
DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.', required=False, nullable=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -85,7 +87,8 @@ class DatasetListApi(DatasetApiResource):
|
||||
tenant_id=tenant_id,
|
||||
name=args['name'],
|
||||
indexing_technique=args['indexing_technique'],
|
||||
account=current_user
|
||||
account=current_user,
|
||||
permission=args['permission']
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
@ -26,6 +26,8 @@ class ConversationListApi(WebApiResource):
|
||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
|
||||
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
|
||||
required=False, default='-updated_at', location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
pinned = None
|
||||
@ -40,6 +42,7 @@ class ConversationListApi(WebApiResource):
|
||||
limit=args['limit'],
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
pinned=pinned,
|
||||
sort_by=args['sort_by']
|
||||
)
|
||||
except LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
@ -93,7 +93,7 @@ class DatasetConfigManager:
|
||||
reranking_model=dataset_configs.get('reranking_model'),
|
||||
weights=dataset_configs.get('weights'),
|
||||
reranking_enabled=dataset_configs.get('reranking_enabled', True),
|
||||
rerank_mode=dataset_configs["reranking_mode"],
|
||||
rerank_mode=dataset_configs.get('rerank_mode', 'reranking_model'),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import re
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ class BasicVariablesConfigManager:
|
||||
:param config: model config args
|
||||
"""
|
||||
external_data_variables = []
|
||||
variables = []
|
||||
variable_entities = []
|
||||
|
||||
# old external_data_tools
|
||||
external_data_tools = config.get('external_data_tools', [])
|
||||
@ -30,50 +30,41 @@ class BasicVariablesConfigManager:
|
||||
)
|
||||
|
||||
# variables and external_data_tools
|
||||
for variable in config.get('user_input_form', []):
|
||||
typ = list(variable.keys())[0]
|
||||
if typ == 'external_data_tool':
|
||||
val = variable[typ]
|
||||
if 'config' not in val:
|
||||
for variables in config.get('user_input_form', []):
|
||||
variable_type = list(variables.keys())[0]
|
||||
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
|
||||
variable = variables[variable_type]
|
||||
if 'config' not in variable:
|
||||
continue
|
||||
|
||||
external_data_variables.append(
|
||||
ExternalDataVariableEntity(
|
||||
variable=val['variable'],
|
||||
type=val['type'],
|
||||
config=val['config']
|
||||
variable=variable['variable'],
|
||||
type=variable['type'],
|
||||
config=variable['config']
|
||||
)
|
||||
)
|
||||
elif typ in [
|
||||
VariableEntity.Type.TEXT_INPUT.value,
|
||||
VariableEntity.Type.PARAGRAPH.value,
|
||||
VariableEntity.Type.NUMBER.value,
|
||||
elif variable_type in [
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
VariableEntityType.NUMBER,
|
||||
VariableEntityType.SELECT,
|
||||
]:
|
||||
variables.append(
|
||||
variable = variables[variable_type]
|
||||
variable_entities.append(
|
||||
VariableEntity(
|
||||
type=VariableEntity.Type.value_of(typ),
|
||||
variable=variable[typ].get('variable'),
|
||||
description=variable[typ].get('description'),
|
||||
label=variable[typ].get('label'),
|
||||
required=variable[typ].get('required', False),
|
||||
max_length=variable[typ].get('max_length'),
|
||||
default=variable[typ].get('default'),
|
||||
)
|
||||
)
|
||||
elif typ == VariableEntity.Type.SELECT.value:
|
||||
variables.append(
|
||||
VariableEntity(
|
||||
type=VariableEntity.Type.SELECT,
|
||||
variable=variable[typ].get('variable'),
|
||||
description=variable[typ].get('description'),
|
||||
label=variable[typ].get('label'),
|
||||
required=variable[typ].get('required', False),
|
||||
options=variable[typ].get('options'),
|
||||
default=variable[typ].get('default'),
|
||||
type=variable_type,
|
||||
variable=variable.get('variable'),
|
||||
description=variable.get('description'),
|
||||
label=variable.get('label'),
|
||||
required=variable.get('required', False),
|
||||
max_length=variable.get('max_length'),
|
||||
options=variable.get('options'),
|
||||
default=variable.get('default'),
|
||||
)
|
||||
)
|
||||
|
||||
return variables, external_data_variables
|
||||
return variable_entities, external_data_variables
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||
@ -183,4 +174,4 @@ class BasicVariablesConfigManager:
|
||||
config=config
|
||||
)
|
||||
|
||||
return config, ["external_data_tools"]
|
||||
return config, ["external_data_tools"]
|
||||
|
||||
@ -82,43 +82,29 @@ class PromptTemplateEntity(BaseModel):
|
||||
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
|
||||
|
||||
|
||||
class VariableEntityType(str, Enum):
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
PARAGRAPH = "paragraph"
|
||||
NUMBER = "number"
|
||||
EXTERNAL_DATA_TOOL = "external-data-tool"
|
||||
|
||||
|
||||
class VariableEntity(BaseModel):
|
||||
"""
|
||||
Variable Entity.
|
||||
"""
|
||||
class Type(Enum):
|
||||
TEXT_INPUT = 'text-input'
|
||||
SELECT = 'select'
|
||||
PARAGRAPH = 'paragraph'
|
||||
NUMBER = 'number'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'VariableEntity.Type':
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid variable type value {value}')
|
||||
|
||||
variable: str
|
||||
label: str
|
||||
description: Optional[str] = None
|
||||
type: Type
|
||||
type: VariableEntityType
|
||||
required: bool = False
|
||||
max_length: Optional[int] = None
|
||||
options: Optional[list[str]] = None
|
||||
default: Optional[str] = None
|
||||
hint: Optional[str] = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.variable
|
||||
|
||||
|
||||
class ExternalDataVariableEntity(BaseModel):
|
||||
"""
|
||||
@ -252,4 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig):
|
||||
"""
|
||||
Workflow UI Based App Config Entity.
|
||||
"""
|
||||
workflow_id: str
|
||||
workflow_id: str
|
||||
|
||||
@ -29,7 +29,7 @@ from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
@ -46,7 +46,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> Union[dict, Generator[dict, None, None]]:
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -73,8 +73,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
if args.get('conversation_id'):
|
||||
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
|
||||
conversation_id = args.get('conversation_id')
|
||||
if conversation_id:
|
||||
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
|
||||
|
||||
# parse files
|
||||
files = args['files'] if args.get('files') else []
|
||||
@ -133,8 +134,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
node_id: str,
|
||||
user: Account,
|
||||
args: dict,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator[dict, None, None]]:
|
||||
stream: bool = True):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -157,8 +157,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
if args.get('conversation_id'):
|
||||
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
|
||||
conversation_id = args.get('conversation_id')
|
||||
if conversation_id:
|
||||
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
|
||||
|
||||
# convert to app config
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(
|
||||
@ -200,8 +201,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
conversation: Conversation | None = None,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator[dict, None, None]]:
|
||||
stream: bool = True):
|
||||
is_first_conversation = False
|
||||
if not conversation:
|
||||
is_first_conversation = True
|
||||
@ -270,11 +270,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariable.QUERY: query,
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.CONVERSATION_ID: conversation_id,
|
||||
SystemVariable.USER_ID: user_id,
|
||||
SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
|
||||
SystemVariableKey.QUERY: query,
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.CONVERSATION_ID: conversation_id,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
|
||||
}
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
@ -362,7 +362,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
|
||||
if os.environ.get("DEBUG", "false").lower() == 'true':
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except Exception as e:
|
||||
|
||||
@ -49,7 +49,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
|
||||
from events.message_event import message_was_created
|
||||
@ -74,7 +74,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
# Deprecated
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_iteration_nested_relations: dict[str, list[str]]
|
||||
|
||||
def __init__(
|
||||
@ -108,10 +108,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
self._message = message
|
||||
# Deprecated
|
||||
self._workflow_system_variables = {
|
||||
SystemVariable.QUERY: message.query,
|
||||
SystemVariable.FILES: application_generate_entity.files,
|
||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||
SystemVariable.USER_ID: user_id,
|
||||
SystemVariableKey.QUERY: message.query,
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
}
|
||||
|
||||
self._task_state = AdvancedChatTaskState(
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.app_config.entities import AppConfig, VariableEntity
|
||||
from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
|
||||
|
||||
|
||||
class BaseAppGenerator:
|
||||
@ -9,29 +9,29 @@ class BaseAppGenerator:
|
||||
user_inputs = user_inputs or {}
|
||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||
variables = app_config.variables
|
||||
filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
||||
filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
||||
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
|
||||
return filtered_inputs
|
||||
|
||||
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
|
||||
user_input_value = inputs.get(var.name)
|
||||
user_input_value = inputs.get(var.variable)
|
||||
if var.required and not user_input_value:
|
||||
raise ValueError(f'{var.name} is required in input form')
|
||||
raise ValueError(f'{var.variable} is required in input form')
|
||||
if not var.required and not user_input_value:
|
||||
# TODO: should we return None here if the default value is None?
|
||||
return var.default or ''
|
||||
if (
|
||||
var.type
|
||||
in (
|
||||
VariableEntity.Type.TEXT_INPUT,
|
||||
VariableEntity.Type.SELECT,
|
||||
VariableEntity.Type.PARAGRAPH,
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
)
|
||||
and user_input_value
|
||||
and not isinstance(user_input_value, str)
|
||||
):
|
||||
raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
|
||||
if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
|
||||
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
|
||||
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
|
||||
# may raise ValueError if user_input_value is not a valid number
|
||||
try:
|
||||
if '.' in user_input_value:
|
||||
@ -39,14 +39,14 @@ class BaseAppGenerator:
|
||||
else:
|
||||
return int(user_input_value)
|
||||
except ValueError:
|
||||
raise ValueError(f"{var.name} in input form must be a valid number")
|
||||
if var.type == VariableEntity.Type.SELECT:
|
||||
raise ValueError(f"{var.variable} in input form must be a valid number")
|
||||
if var.type == VariableEntityType.SELECT:
|
||||
options = var.options or []
|
||||
if user_input_value not in options:
|
||||
raise ValueError(f'{var.name} in input form must be one of the following: {options}')
|
||||
elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
|
||||
raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
|
||||
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
||||
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
|
||||
raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
|
||||
raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
|
||||
|
||||
return user_input_value
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy import and_
|
||||
@ -36,17 +37,17 @@ logger = logging.getLogger(__name__)
|
||||
class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
|
||||
def _handle_response(
|
||||
self, application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity
|
||||
],
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool = False,
|
||||
self, application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity
|
||||
],
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool = False,
|
||||
) -> Union[
|
||||
ChatbotAppBlockingResponse,
|
||||
CompletionAppBlockingResponse,
|
||||
@ -193,6 +194,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
else:
|
||||
conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
message = Message(
|
||||
app_id=app_config.app_id,
|
||||
|
||||
@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
@ -67,8 +67,8 @@ class WorkflowAppRunner:
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.USER_ID: user_id,
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
}
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
|
||||
@ -43,7 +43,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
|
||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
@ -67,7 +67,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: WorkflowTaskState
|
||||
_application_generate_entity: WorkflowAppGenerateEntity
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_iteration_nested_relations: dict[str, list[str]]
|
||||
|
||||
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
@ -92,8 +92,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
|
||||
self._workflow = workflow
|
||||
self._workflow_system_variables = {
|
||||
SystemVariable.FILES: application_generate_entity.files,
|
||||
SystemVariable.USER_ID: user_id
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.USER_ID: user_id
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState(
|
||||
|
||||
@ -2,7 +2,7 @@ from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow
|
||||
@ -13,4 +13,4 @@ class WorkflowCycleStateManager:
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
|
||||
@ -1,15 +1,13 @@
|
||||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
from threading import Lock
|
||||
from typing import Literal, Optional
|
||||
from typing import Optional
|
||||
|
||||
from httpx import get, post
|
||||
from httpx import Timeout, post
|
||||
from pydantic import BaseModel
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.entities import CodeDependency
|
||||
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
|
||||
from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
|
||||
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
|
||||
@ -21,7 +19,7 @@ logger = logging.getLogger(__name__)
|
||||
CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT
|
||||
CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY
|
||||
|
||||
CODE_EXECUTION_TIMEOUT = (10, 60)
|
||||
CODE_EXECUTION_TIMEOUT = Timeout(connect=10, write=10, read=60, pool=None)
|
||||
|
||||
class CodeExecutionException(Exception):
|
||||
pass
|
||||
@ -66,8 +64,7 @@ class CodeExecutor:
|
||||
def execute_code(cls,
|
||||
language: CodeLanguage,
|
||||
preload: str,
|
||||
code: str,
|
||||
dependencies: Optional[list[CodeDependency]] = None) -> str:
|
||||
code: str) -> str:
|
||||
"""
|
||||
Execute code
|
||||
:param language: code language
|
||||
@ -87,9 +84,6 @@ class CodeExecutor:
|
||||
'enable_network': True
|
||||
}
|
||||
|
||||
if dependencies:
|
||||
data['dependencies'] = [dependency.model_dump() for dependency in dependencies]
|
||||
|
||||
try:
|
||||
response = post(str(url), json=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT)
|
||||
if response.status_code == 503:
|
||||
@ -116,10 +110,10 @@ class CodeExecutor:
|
||||
if response.data.error:
|
||||
raise CodeExecutionException(response.data.error)
|
||||
|
||||
return response.data.stdout
|
||||
return response.data.stdout or ''
|
||||
|
||||
@classmethod
|
||||
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
|
||||
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict:
|
||||
"""
|
||||
Execute code
|
||||
:param language: code language
|
||||
@ -131,67 +125,12 @@ class CodeExecutor:
|
||||
if not template_transformer:
|
||||
raise CodeExecutionException(f'Unsupported language {language}')
|
||||
|
||||
runner, preload, dependencies = template_transformer.transform_caller(code, inputs, dependencies)
|
||||
runner, preload = template_transformer.transform_caller(code, inputs)
|
||||
|
||||
try:
|
||||
response = cls.execute_code(language, preload, runner, dependencies)
|
||||
response = cls.execute_code(language, preload, runner)
|
||||
except CodeExecutionException as e:
|
||||
raise e
|
||||
|
||||
return template_transformer.transform_response(response)
|
||||
|
||||
@classmethod
|
||||
def list_dependencies(cls, language: str) -> list[CodeDependency]:
|
||||
if language not in cls.supported_dependencies_languages:
|
||||
return []
|
||||
|
||||
with cls.dependencies_cache_lock:
|
||||
if language in cls.dependencies_cache:
|
||||
# check expiration
|
||||
dependencies = cls.dependencies_cache[language]
|
||||
if dependencies['expiration'] > time.time():
|
||||
return dependencies['data']
|
||||
# remove expired cache
|
||||
del cls.dependencies_cache[language]
|
||||
|
||||
dependencies = cls._get_dependencies(language)
|
||||
with cls.dependencies_cache_lock:
|
||||
cls.dependencies_cache[language] = {
|
||||
'data': dependencies,
|
||||
'expiration': time.time() + 60
|
||||
}
|
||||
|
||||
return dependencies
|
||||
|
||||
@classmethod
|
||||
def _get_dependencies(cls, language: Literal['python3']) -> list[CodeDependency]:
|
||||
"""
|
||||
List dependencies
|
||||
"""
|
||||
url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'dependencies'
|
||||
|
||||
headers = {
|
||||
'X-Api-Key': CODE_EXECUTION_API_KEY
|
||||
}
|
||||
|
||||
running_language = cls.code_language_to_running_language.get(language)
|
||||
if isinstance(running_language, Enum):
|
||||
running_language = running_language.value
|
||||
|
||||
data = {
|
||||
'language': running_language,
|
||||
}
|
||||
|
||||
try:
|
||||
response = get(str(url), params=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to list dependencies, got status code {response.status_code}, please check if the sandbox service is running')
|
||||
response = response.json()
|
||||
dependencies = response.get('data', {}).get('dependencies', [])
|
||||
return [
|
||||
CodeDependency(**dependency) for dependency in dependencies
|
||||
if dependency.get('name') not in Python3TemplateTransformer.get_standard_packages()
|
||||
]
|
||||
except Exception as e:
|
||||
logger.exception(f'Failed to list dependencies: {e}')
|
||||
return []
|
||||
|
||||
@ -2,8 +2,6 @@ from abc import abstractmethod
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutor
|
||||
|
||||
|
||||
class CodeNodeProvider(BaseModel):
|
||||
@staticmethod
|
||||
@ -23,10 +21,6 @@ class CodeNodeProvider(BaseModel):
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_default_available_packages(cls) -> list[dict]:
|
||||
return [p.model_dump() for p in CodeExecutor.list_dependencies(cls.get_language())]
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls) -> dict:
|
||||
return {
|
||||
@ -50,6 +44,5 @@ class CodeNodeProvider(BaseModel):
|
||||
"children": None
|
||||
}
|
||||
}
|
||||
},
|
||||
"available_dependencies": cls.get_default_available_packages(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeDependency(BaseModel):
|
||||
name: str
|
||||
version: str
|
||||
@ -3,7 +3,7 @@ from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
|
||||
|
||||
class Jinja2Formatter:
|
||||
@classmethod
|
||||
def format(cls, template: str, inputs: str) -> str:
|
||||
def format(cls, template: str, inputs: dict) -> str:
|
||||
"""
|
||||
Format template
|
||||
:param template: template
|
||||
|
||||
@ -1,14 +1,9 @@
|
||||
from textwrap import dedent
|
||||
|
||||
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
|
||||
class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
def get_standard_packages(cls) -> set[str]:
|
||||
return {'jinja2'} | Python3TemplateTransformer.get_standard_packages()
|
||||
|
||||
@classmethod
|
||||
def transform_response(cls, response: str) -> dict:
|
||||
"""
|
||||
|
||||
@ -13,7 +13,7 @@ class Python3CodeProvider(CodeNodeProvider):
|
||||
def get_default_code(cls) -> str:
|
||||
return dedent(
|
||||
"""
|
||||
def main(arg1: int, arg2: int) -> dict:
|
||||
def main(arg1: str, arg2: str) -> dict:
|
||||
return {
|
||||
"result": arg1 + arg2,
|
||||
}
|
||||
|
||||
@ -4,30 +4,6 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
|
||||
class Python3TemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
def get_standard_packages(cls) -> set[str]:
|
||||
return {
|
||||
'base64',
|
||||
'binascii',
|
||||
'collections',
|
||||
'datetime',
|
||||
'functools',
|
||||
'hashlib',
|
||||
'hmac',
|
||||
'itertools',
|
||||
'json',
|
||||
'math',
|
||||
'operator',
|
||||
'os',
|
||||
'random',
|
||||
're',
|
||||
'string',
|
||||
'sys',
|
||||
'time',
|
||||
'traceback',
|
||||
'uuid',
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(f"""
|
||||
|
||||
@ -2,9 +2,6 @@ import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from base64 import b64encode
|
||||
from typing import Optional
|
||||
|
||||
from core.helper.code_executor.entities import CodeDependency
|
||||
|
||||
|
||||
class TemplateTransformer(ABC):
|
||||
@ -13,12 +10,7 @@ class TemplateTransformer(ABC):
|
||||
_result_tag: str = '<<RESULT>>'
|
||||
|
||||
@classmethod
|
||||
def get_standard_packages(cls) -> set[str]:
|
||||
return set()
|
||||
|
||||
@classmethod
|
||||
def transform_caller(cls, code: str, inputs: dict,
|
||||
dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]:
|
||||
def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
|
||||
"""
|
||||
Transform code to python runner
|
||||
:param code: code
|
||||
@ -28,14 +20,7 @@ class TemplateTransformer(ABC):
|
||||
runner_script = cls.assemble_runner_script(code, inputs)
|
||||
preload_script = cls.get_preload_script()
|
||||
|
||||
packages = dependencies or []
|
||||
standard_packages = cls.get_standard_packages()
|
||||
for package in standard_packages:
|
||||
if package not in packages:
|
||||
packages.append(CodeDependency(name=package, version=''))
|
||||
packages = list({dep.name: dep for dep in packages if dep.name}.values())
|
||||
|
||||
return runner_script, preload_script, packages
|
||||
return runner_script, preload_script
|
||||
|
||||
@classmethod
|
||||
def extract_result_str_from_response(cls, response: str) -> str:
|
||||
|
||||
@ -3,6 +3,7 @@ from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
|
||||
@ -19,6 +20,87 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
|
||||
return {name: index for index, name in enumerate(positions)}
|
||||
|
||||
|
||||
def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
|
||||
"""
|
||||
Get the mapping for tools from name to index from a YAML file.
|
||||
:param folder_path:
|
||||
:param file_name: the YAML file name, default to '_position.yaml'
|
||||
:return: a dict with name as key and index as value
|
||||
"""
|
||||
position_map = get_position_map(folder_path, file_name=file_name)
|
||||
|
||||
return pin_position_map(
|
||||
position_map,
|
||||
pin_list=dify_config.POSITION_TOOL_PINS_LIST,
|
||||
)
|
||||
|
||||
|
||||
def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
|
||||
"""
|
||||
Get the mapping for providers from name to index from a YAML file.
|
||||
:param folder_path:
|
||||
:param file_name: the YAML file name, default to '_position.yaml'
|
||||
:return: a dict with name as key and index as value
|
||||
"""
|
||||
position_map = get_position_map(folder_path, file_name=file_name)
|
||||
return pin_position_map(
|
||||
position_map,
|
||||
pin_list=dify_config.POSITION_PROVIDER_PINS_LIST,
|
||||
)
|
||||
|
||||
|
||||
def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
|
||||
"""
|
||||
Pin the items in the pin list to the beginning of the position map.
|
||||
Overall logic: exclude > include > pin
|
||||
:param position_map: the position map to be sorted and filtered
|
||||
:param pin_list: the list of pins to be put at the beginning
|
||||
:return: the sorted position map
|
||||
"""
|
||||
positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x])
|
||||
|
||||
# Add pins to position map
|
||||
position_map = {name: idx for idx, name in enumerate(pin_list)}
|
||||
|
||||
# Add remaining positions to position map
|
||||
start_idx = len(position_map)
|
||||
for name in positions:
|
||||
if name not in position_map:
|
||||
position_map[name] = start_idx
|
||||
start_idx += 1
|
||||
|
||||
return position_map
|
||||
|
||||
|
||||
def is_filtered(
|
||||
include_set: set[str],
|
||||
exclude_set: set[str],
|
||||
data: Any,
|
||||
name_func: Callable[[Any], str],
|
||||
) -> bool:
|
||||
"""
|
||||
Chcek if the object should be filtered out.
|
||||
Overall logic: exclude > include > pin
|
||||
:param include_set: the set of names to be included
|
||||
:param exclude_set: the set of names to be excluded
|
||||
:param name_func: the function to get the name of the object
|
||||
:param data: the data to be filtered
|
||||
:return: True if the object should be filtered out, False otherwise
|
||||
"""
|
||||
if not data:
|
||||
return False
|
||||
if not include_set and not exclude_set:
|
||||
return False
|
||||
|
||||
name = name_func(data)
|
||||
|
||||
if name in exclude_set: # exclude_set is prioritized
|
||||
return True
|
||||
if include_set and name not in include_set: # filter out only if include_set is not empty
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def sort_by_position_map(
|
||||
position_map: dict[str, int],
|
||||
data: list[Any],
|
||||
|
||||
@ -700,6 +700,7 @@ class IndexingRunner:
|
||||
DatasetDocument.tokens: tokens,
|
||||
DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
|
||||
DatasetDocument.error: None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@ -368,6 +368,15 @@ class ModelManager:
|
||||
|
||||
return ModelInstance(provider_model_bundle, model)
|
||||
|
||||
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
|
||||
"""
|
||||
Return first provider and the first model in the provider
|
||||
:param tenant_id: tenant id
|
||||
:param model_type: model type
|
||||
:return: provider name, model name
|
||||
"""
|
||||
return self._provider_manager.get_first_provider_first_model(tenant_id, model_type)
|
||||
|
||||
def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
|
||||
"""
|
||||
Get default model instance
|
||||
@ -502,7 +511,6 @@ class LBModelManager:
|
||||
config.id
|
||||
)
|
||||
|
||||
|
||||
res = redis_client.exists(cooldown_cache_key)
|
||||
res = cast(bool, res)
|
||||
return res
|
||||
|
||||
@ -151,9 +151,9 @@ class AIModel(ABC):
|
||||
os.path.join(provider_model_type_path, model_schema_yaml)
|
||||
for model_schema_yaml in os.listdir(provider_model_type_path)
|
||||
if not model_schema_yaml.startswith('__')
|
||||
and not model_schema_yaml.startswith('_')
|
||||
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
|
||||
and model_schema_yaml.endswith('.yaml')
|
||||
and not model_schema_yaml.startswith('_')
|
||||
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
|
||||
and model_schema_yaml.endswith('.yaml')
|
||||
]
|
||||
|
||||
# get _position.yaml file path
|
||||
|
||||
@ -185,7 +185,7 @@ if you are not sure about the structure.
|
||||
stream=stream,
|
||||
user=user
|
||||
)
|
||||
|
||||
|
||||
model_parameters.pop("response_format")
|
||||
stop = stop or []
|
||||
stop.extend(["\n```", "```\n"])
|
||||
@ -249,10 +249,10 @@ if you are not sure about the structure.
|
||||
prompt_messages=prompt_messages,
|
||||
input_generator=new_generator()
|
||||
)
|
||||
|
||||
|
||||
return response
|
||||
|
||||
def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
|
||||
def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
|
||||
input_generator: Generator[LLMResultChunk, None, None]
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
@ -310,7 +310,7 @@ if you are not sure about the structure.
|
||||
)
|
||||
)
|
||||
|
||||
def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
|
||||
def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
|
||||
input_generator: Generator[LLMResultChunk, None, None]) \
|
||||
-> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
@ -470,7 +470,7 @@ if you are not sure about the structure.
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
@ -792,6 +792,13 @@ if you are not sure about the structure.
|
||||
if not isinstance(parameter_value, str):
|
||||
raise ValueError(f"Model Parameter {parameter_name} should be string.")
|
||||
|
||||
# validate options
|
||||
if parameter_rule.options and parameter_value not in parameter_rule.options:
|
||||
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
|
||||
elif parameter_rule.type == ParameterType.TEXT:
|
||||
if not isinstance(parameter_value, str):
|
||||
raise ValueError(f"Model Parameter {parameter_name} should be text.")
|
||||
|
||||
# validate options
|
||||
if parameter_rule.options and parameter_value not in parameter_rule.options:
|
||||
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
|
||||
|
||||
@ -70,7 +70,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
# doc: https://platform.openai.com/docs/guides/text-to-speech
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = AzureOpenAI(**credentials_kwargs)
|
||||
# max font is 4096,there is 3500 limit for each request
|
||||
# max length is 4096 characters, there is 3500 limit for each request
|
||||
max_length = 3500
|
||||
if len(content_text) > max_length:
|
||||
sentences = self._split_text_into_sentences(content_text, max_length=max_length)
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import Optional
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map
|
||||
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
@ -234,7 +234,7 @@ class ModelProviderFactory:
|
||||
]
|
||||
|
||||
# get _position.yaml file path
|
||||
position_map = get_position_map(model_providers_path)
|
||||
position_map = get_provider_position_map(model_providers_path)
|
||||
|
||||
# traverse all model_provider_dir_paths
|
||||
model_providers: list[ModelProviderExtension] = []
|
||||
|
||||
@ -37,6 +37,9 @@ parameter_rules:
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
- json_schema
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.15'
|
||||
output: '0.60'
|
||||
|
||||
@ -428,7 +428,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
if new_tool_call.function.arguments:
|
||||
tool_call.function.arguments += new_tool_call.function.arguments
|
||||
|
||||
finish_reason = 'Unknown'
|
||||
finish_reason = None # The default value of finish_reason is None
|
||||
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
|
||||
chunk = chunk.strip()
|
||||
@ -437,6 +437,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
if chunk.startswith(':'):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
|
||||
if decoded_chunk == '[DONE]': # Some provider returns "data: [DONE]"
|
||||
continue
|
||||
|
||||
try:
|
||||
chunk_json = json.loads(decoded_chunk)
|
||||
|
||||
@ -0,0 +1,44 @@
|
||||
model: gpt-4o-2024-08-06
|
||||
label:
|
||||
zh_Hans: gpt-4o-2024-08-06
|
||||
en_US: gpt-4o-2024-08-06
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 16384
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '2.50'
|
||||
output: '10.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -118,6 +118,9 @@ class _CommonWenxin:
|
||||
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
|
||||
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
|
||||
'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1',
|
||||
'bge-large-en': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en',
|
||||
'bge-large-zh': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh',
|
||||
'tao-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k',
|
||||
}
|
||||
|
||||
function_calling_supports = [
|
||||
|
||||
@ -0,0 +1,9 @@
|
||||
model: bge-large-en
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 512
|
||||
max_chunks: 16
|
||||
pricing:
|
||||
input: '0.0005'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@ -0,0 +1,9 @@
|
||||
model: bge-large-zh
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 512
|
||||
max_chunks: 16
|
||||
pricing:
|
||||
input: '0.0005'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@ -0,0 +1,9 @@
|
||||
model: tao-8k
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 8192
|
||||
max_chunks: 1
|
||||
pricing:
|
||||
input: '0.0005'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@ -5,6 +5,7 @@ from typing import Optional
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
||||
from core.entities.provider_entities import (
|
||||
@ -18,12 +19,9 @@ from core.entities.provider_entities import (
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import (
|
||||
CredentialFormSchema,
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from extensions import ext_hosting_provider
|
||||
from extensions.ext_database import db
|
||||
@ -45,6 +43,7 @@ class ProviderManager:
|
||||
"""
|
||||
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.decoding_rsa_key = None
|
||||
self.decoding_cipher_rsa = None
|
||||
@ -117,6 +116,16 @@ class ProviderManager:
|
||||
|
||||
# Construct ProviderConfiguration objects for each provider
|
||||
for provider_entity in provider_entities:
|
||||
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
|
||||
data=provider_entity,
|
||||
name_func=lambda x: x.provider,
|
||||
):
|
||||
continue
|
||||
|
||||
provider_name = provider_entity.provider
|
||||
provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
|
||||
provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
|
||||
@ -271,6 +280,24 @@ class ProviderManager:
|
||||
)
|
||||
)
|
||||
|
||||
def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
|
||||
"""
|
||||
Get names of first model and its provider
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param model_type: model type
|
||||
:return: provider name, model name
|
||||
"""
|
||||
provider_configurations = self.get_configurations(tenant_id)
|
||||
|
||||
# get available models from provider_configurations
|
||||
all_models = provider_configurations.get_models(
|
||||
model_type=model_type,
|
||||
only_active=False
|
||||
)
|
||||
|
||||
return all_models[0].provider.provider, all_models[0].model
|
||||
|
||||
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
|
||||
-> TenantDefaultModel:
|
||||
"""
|
||||
|
||||
@ -152,8 +152,27 @@ class PGVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# do not support bm25 search
|
||||
return []
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), to_tsquery(%s)) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
|
||||
ORDER BY score DESC
|
||||
LIMIT {top_k}""",
|
||||
# f"'{query}'" is required in order to account for whitespace in query
|
||||
(f"'{query}'", f"'{query}'"),
|
||||
)
|
||||
|
||||
docs = []
|
||||
|
||||
for record in cur:
|
||||
metadata, text, score = record
|
||||
metadata["score"] = score
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
return docs
|
||||
|
||||
def delete(self) -> None:
|
||||
with self._get_cursor() as cur:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import os.path
|
||||
|
||||
from core.helper.position_helper import get_position_map, sort_by_position_map
|
||||
from core.helper.position_helper import get_tool_position_map, sort_by_position_map
|
||||
from core.tools.entities.api_entities import UserToolProvider
|
||||
|
||||
|
||||
@ -10,11 +10,11 @@ class BuiltinToolProviderSort:
|
||||
@classmethod
|
||||
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
|
||||
if not cls._position:
|
||||
cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))
|
||||
cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
def name_func(provider: UserToolProvider) -> str:
|
||||
return provider.name
|
||||
|
||||
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
|
||||
|
||||
return sorted_providers
|
||||
return sorted_providers
|
||||
|
||||
49
api/core/tools/provider/builtin/crossref/_assets/icon.svg
Normal file
49
api/core/tools/provider/builtin/crossref/_assets/icon.svg
Normal file
@ -0,0 +1,49 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<!-- Generator: Adobe Illustrator 19.2.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
|
||||
viewBox="0 0 200 130.2" style="enable-background:new 0 0 200 130.2;" xml:space="preserve">
|
||||
<style type="text/css">
|
||||
.st0{fill:#3EB1C8;}
|
||||
.st1{fill:#D8D2C4;}
|
||||
.st2{fill:#4F5858;}
|
||||
.st3{fill:#FFC72C;}
|
||||
.st4{fill:#EF3340;}
|
||||
</style>
|
||||
<g>
|
||||
<polygon class="st0" points="111.8,95.5 111.8,66.8 135.4,59 177.2,73.3 "/>
|
||||
<polygon class="st1" points="153.6,36.8 111.8,51.2 135.4,59 177.2,44.6 "/>
|
||||
<polygon class="st2" points="135.4,59 177.2,44.6 177.2,73.3 "/>
|
||||
<polygon class="st3" points="177.2,0.3 177.2,29 153.6,36.8 111.8,22.5 "/>
|
||||
<polygon class="st4" points="153.6,36.8 111.8,51.2 111.8,22.5 "/>
|
||||
<g>
|
||||
<g>
|
||||
<g>
|
||||
<g>
|
||||
<path class="st2" d="M26.3,104.8c-0.5-3.7-4.1-6.5-8.1-6.5c-7.3,0-10.1,6.2-10.1,12.7c0,6.2,2.8,12.4,10.1,12.4
|
||||
c5,0,7.8-3.4,8.4-8.3h7.9c-0.8,9.2-7.2,15.2-16.3,15.2C6.8,130.2,0,121.7,0,111c0-11,6.8-19.6,18.2-19.6c8.2,0,15,4.8,16,13.3
|
||||
H26.3z"/>
|
||||
<path class="st2" d="M37.4,102.5h7v5h0.1c1.4-3.4,5-5.7,8.6-5.7c0.5,0,1.1,0.1,1.6,0.3v6.9c-0.7-0.2-1.8-0.3-2.6-0.3
|
||||
c-5.4,0-7.3,3.9-7.3,8.6v12.1h-7.4V102.5z"/>
|
||||
<path class="st2" d="M68.7,101.8c8.5,0,13.9,5.6,13.9,14.2c0,8.5-5.5,14.1-13.9,14.1c-8.4,0-13.9-5.6-13.9-14.1
|
||||
C54.9,107.4,60.3,101.8,68.7,101.8z M68.7,124.5c5,0,6.5-4.3,6.5-8.6c0-4.3-1.5-8.6-6.5-8.6c-5,0-6.5,4.3-6.5,8.6
|
||||
C62.2,120.2,63.8,124.5,68.7,124.5z"/>
|
||||
<path class="st2" d="M91.2,120.6c0.1,3.2,2.8,4.5,5.7,4.5c2.1,0,4.8-0.8,4.8-3.4c0-2.2-3.1-3-8.4-4.2c-4.3-0.9-8.5-2.4-8.5-7.2
|
||||
c0-6.9,5.9-8.6,11.7-8.6c5.9,0,11.3,2,11.8,8.6h-7c-0.2-2.9-2.4-3.6-5-3.6c-1.7,0-4.1,0.3-4.1,2.5c0,2.6,4.2,3,8.4,4
|
||||
c4.3,1,8.5,2.5,8.5,7.5c0,7.1-6.1,9.3-12.3,9.3c-6.2,0-12.3-2.3-12.6-9.5H91.2z"/>
|
||||
<path class="st2" d="M118.1,120.6c0.1,3.2,2.8,4.5,5.7,4.5c2.1,0,4.8-0.8,4.8-3.4c0-2.2-3.1-3-8.4-4.2
|
||||
c-4.3-0.9-8.5-2.4-8.5-7.2c0-6.9,5.9-8.6,11.7-8.6c5.9,0,11.3,2,11.8,8.6h-7c-0.2-2.9-2.4-3.6-5-3.6c-1.7,0-4.1,0.3-4.1,2.5
|
||||
c0,2.6,4.2,3,8.4,4c4.3,1,8.5,2.5,8.5,7.5c0,7.1-6.1,9.3-12.3,9.3c-6.2,0-12.3-2.3-12.6-9.5H118.1z"/>
|
||||
<path class="st2" d="M138.4,102.5h7v5h0.1c1.4-3.4,5-5.7,8.6-5.7c0.5,0,1.1,0.1,1.6,0.3v6.9c-0.7-0.2-1.8-0.3-2.6-0.3
|
||||
c-5.4,0-7.3,3.9-7.3,8.6v12.1h-7.4V102.5z"/>
|
||||
<path class="st2" d="M163.7,117.7c0.2,4.7,2.5,6.8,6.6,6.8c3,0,5.3-1.8,5.8-3.5h6.5c-2.1,6.3-6.5,9-12.6,9
|
||||
c-8.5,0-13.7-5.8-13.7-14.1c0-8,5.6-14.2,13.7-14.2c9.1,0,13.6,7.7,13,15.9H163.7z M175.7,113.1c-0.7-3.7-2.3-5.7-5.9-5.7
|
||||
c-4.7,0-6,3.6-6.1,5.7H175.7z"/>
|
||||
<path class="st2" d="M187.2,107.5h-4.4v-4.9h4.4v-2.1c0-4.7,3-8.2,9-8.2c1.3,0,2.6,0.2,3.9,0.2V98c-0.9-0.1-1.8-0.2-2.7-0.2
|
||||
c-2,0-2.8,0.8-2.8,3.1v1.6h5.1v4.9h-5.1v21.9h-7.4V107.5z"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 3.0 KiB |
20
api/core/tools/provider/builtin/crossref/crossref.py
Normal file
20
api/core/tools/provider/builtin/crossref/crossref.py
Normal file
@ -0,0 +1,20 @@
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.crossref.tools.query_doi import CrossRefQueryDOITool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class CrossRefProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
CrossRefQueryDOITool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
"doi": '10.1007/s00894-022-05373-8',
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
29
api/core/tools/provider/builtin/crossref/crossref.yaml
Normal file
29
api/core/tools/provider/builtin/crossref/crossref.yaml
Normal file
@ -0,0 +1,29 @@
|
||||
identity:
|
||||
author: Sakura4036
|
||||
name: crossref
|
||||
label:
|
||||
en_US: CrossRef
|
||||
zh_Hans: CrossRef
|
||||
description:
|
||||
en_US: Crossref is a cross-publisher reference linking registration query system using DOI technology created in 2000. Crossref establishes cross-database links between the reference list and citation full text of papers, making it very convenient for readers to access the full text of papers.
|
||||
zh_Hans: Crossref是于2000年创建的使用DOI技术的跨出版商参考文献链接注册查询系统。Crossref建立了在论文的参考文献列表和引文全文之间的跨数据库链接,使得读者能够非常便捷地获取文献全文。
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- search
|
||||
credentials_for_provider:
|
||||
mailto:
|
||||
type: text-input
|
||||
required: true
|
||||
label:
|
||||
en_US: email address
|
||||
zh_Hans: email地址
|
||||
pt_BR: email address
|
||||
placeholder:
|
||||
en_US: Please input your email address
|
||||
zh_Hans: 请输入你的email地址
|
||||
pt_BR: Please input your email address
|
||||
help:
|
||||
en_US: According to the requirements of Crossref, an email address is required
|
||||
zh_Hans: 根据Crossref的要求,需要提供一个邮箱地址
|
||||
pt_BR: According to the requirements of Crossref, an email address is required
|
||||
url: https://api.crossref.org/swagger-ui/index.html
|
||||
25
api/core/tools/provider/builtin/crossref/tools/query_doi.py
Normal file
25
api/core/tools/provider/builtin/crossref/tools/query_doi.py
Normal file
@ -0,0 +1,25 @@
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolParameterValidationError
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class CrossRefQueryDOITool(BuiltinTool):
|
||||
"""
|
||||
Tool for querying the metadata of a publication using its DOI.
|
||||
"""
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
doi = tool_parameters.get('doi')
|
||||
if not doi:
|
||||
raise ToolParameterValidationError('doi is required.')
|
||||
# doc: https://github.com/CrossRef/rest-api-doc
|
||||
url = f"https://api.crossref.org/works/{doi}"
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
response = response.json()
|
||||
message = response.get('message', {})
|
||||
|
||||
return self.create_json_message(message)
|
||||
@ -0,0 +1,23 @@
|
||||
identity:
|
||||
name: crossref_query_doi
|
||||
author: Sakura4036
|
||||
label:
|
||||
en_US: CrossRef Query DOI
|
||||
zh_Hans: CrossRef DOI 查询
|
||||
pt_BR: CrossRef Query DOI
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for searching literature information using CrossRef by DOI.
|
||||
zh_Hans: 一个使用CrossRef通过DOI获取文献信息的工具。
|
||||
pt_BR: A tool for searching literature information using CrossRef by DOI.
|
||||
llm: A tool for searching literature information using CrossRef by DOI.
|
||||
parameters:
|
||||
- name: doi
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: DOI
|
||||
zh_Hans: DOI
|
||||
pt_BR: DOI
|
||||
llm_description: DOI for searching in CrossRef
|
||||
form: llm
|
||||
120
api/core/tools/provider/builtin/crossref/tools/query_title.py
Normal file
120
api/core/tools/provider/builtin/crossref/tools/query_title.py
Normal file
@ -0,0 +1,120 @@
|
||||
import time
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
def convert_time_str_to_seconds(time_str: str) -> int:
|
||||
"""
|
||||
Convert a time string to seconds.
|
||||
example: 1s -> 1, 1m30s -> 90, 1h30m -> 5400, 1h30m30s -> 5430
|
||||
"""
|
||||
time_str = time_str.lower().strip().replace(' ', '')
|
||||
seconds = 0
|
||||
if 'h' in time_str:
|
||||
hours, time_str = time_str.split('h')
|
||||
seconds += int(hours) * 3600
|
||||
if 'm' in time_str:
|
||||
minutes, time_str = time_str.split('m')
|
||||
seconds += int(minutes) * 60
|
||||
if 's' in time_str:
|
||||
seconds += int(time_str.replace('s', ''))
|
||||
return seconds
|
||||
|
||||
|
||||
class CrossRefQueryTitleAPI:
|
||||
"""
|
||||
Tool for querying the metadata of a publication using its title.
|
||||
Crossref API doc: https://github.com/CrossRef/rest-api-doc
|
||||
"""
|
||||
query_url_template: str = "https://api.crossref.org/works?query.bibliographic={query}&rows={rows}&offset={offset}&sort={sort}&order={order}&mailto={mailto}"
|
||||
rate_limit: int = 50
|
||||
rate_interval: float = 1
|
||||
max_limit: int = 1000
|
||||
|
||||
def __init__(self, mailto: str):
|
||||
self.mailto = mailto
|
||||
|
||||
def _query(self, query: str, rows: int = 5, offset: int = 0, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]:
|
||||
"""
|
||||
Query the metadata of a publication using its title.
|
||||
:param query: the title of the publication
|
||||
:param rows: the number of results to return
|
||||
:param sort: the sort field
|
||||
:param order: the sort order
|
||||
:param fuzzy_query: whether to return all items that match the query
|
||||
"""
|
||||
url = self.query_url_template.format(query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto)
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
rate_limit = int(response.headers['x-ratelimit-limit'])
|
||||
# convert time string to seconds
|
||||
rate_interval = convert_time_str_to_seconds(response.headers['x-ratelimit-interval'])
|
||||
|
||||
self.rate_limit = rate_limit
|
||||
self.rate_interval = rate_interval
|
||||
|
||||
response = response.json()
|
||||
if response['status'] != 'ok':
|
||||
return []
|
||||
|
||||
message = response['message']
|
||||
if fuzzy_query:
|
||||
# fuzzy query return all items
|
||||
return message['items']
|
||||
else:
|
||||
for paper in message['items']:
|
||||
title = paper['title'][0]
|
||||
if title.lower() != query.lower():
|
||||
continue
|
||||
return [paper]
|
||||
return []
|
||||
|
||||
def query(self, query: str, rows: int = 5, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]:
|
||||
"""
|
||||
Query the metadata of a publication using its title.
|
||||
:param query: the title of the publication
|
||||
:param rows: the number of results to return
|
||||
:param sort: the sort field
|
||||
:param order: the sort order
|
||||
:param fuzzy_query: whether to return all items that match the query
|
||||
"""
|
||||
rows = min(rows, self.max_limit)
|
||||
if rows > self.rate_limit:
|
||||
# query multiple times
|
||||
query_times = rows // self.rate_limit + 1
|
||||
results = []
|
||||
|
||||
for i in range(query_times):
|
||||
result = self._query(query, rows=self.rate_limit, offset=i * self.rate_limit, sort=sort, order=order, fuzzy_query=fuzzy_query)
|
||||
if fuzzy_query:
|
||||
results.extend(result)
|
||||
else:
|
||||
# fuzzy_query=False, only one result
|
||||
if result:
|
||||
return result
|
||||
time.sleep(self.rate_interval)
|
||||
return results
|
||||
else:
|
||||
# query once
|
||||
return self._query(query, rows, sort=sort, order=order, fuzzy_query=fuzzy_query)
|
||||
|
||||
|
||||
class CrossRefQueryTitleTool(BuiltinTool):
|
||||
"""
|
||||
Tool for querying the metadata of a publication using its title.
|
||||
"""
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
query = tool_parameters.get('query')
|
||||
fuzzy_query = tool_parameters.get('fuzzy_query', False)
|
||||
rows = tool_parameters.get('rows', 3)
|
||||
sort = tool_parameters.get('sort', 'relevance')
|
||||
order = tool_parameters.get('order', 'desc')
|
||||
mailto = self.runtime.credentials['mailto']
|
||||
|
||||
result = CrossRefQueryTitleAPI(mailto).query(query, rows, sort, order, fuzzy_query)
|
||||
|
||||
return [self.create_json_message(r) for r in result]
|
||||
105
api/core/tools/provider/builtin/crossref/tools/query_title.yaml
Normal file
105
api/core/tools/provider/builtin/crossref/tools/query_title.yaml
Normal file
@ -0,0 +1,105 @@
|
||||
identity:
|
||||
name: crossref_query_title
|
||||
author: Sakura4036
|
||||
label:
|
||||
en_US: CrossRef Title Query
|
||||
zh_Hans: CrossRef 标题查询
|
||||
pt_BR: CrossRef Title Query
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for querying literature information using CrossRef by title.
|
||||
zh_Hans: 一个使用CrossRef通过标题搜索文献信息的工具。
|
||||
pt_BR: A tool for querying literature information using CrossRef by title.
|
||||
llm: A tool for querying literature information using CrossRef by title.
|
||||
parameters:
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: 标题
|
||||
zh_Hans: 查询语句
|
||||
pt_BR: 标题
|
||||
human_description:
|
||||
en_US: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years
|
||||
zh_Hans: 用于搜索文献信息,有助于查找引用。包括标题,作者,ISSN和出版年份
|
||||
pt_BR: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years
|
||||
llm_description: key words for querying in Web of Science
|
||||
form: llm
|
||||
- name: fuzzy_query
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: Whether to fuzzy search
|
||||
zh_Hans: 是否模糊搜索
|
||||
pt_BR: Whether to fuzzy search
|
||||
human_description:
|
||||
en_US: used for selecting the query type, fuzzy query returns more results, precise query returns 1 or none
|
||||
zh_Hans: 用于选择搜索类型,模糊搜索返回更多结果,精确搜索返回1条结果或无
|
||||
pt_BR: used for selecting the query type, fuzzy query returns more results, precise query returns 1 or none
|
||||
form: form
|
||||
- name: limit
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: max query number
|
||||
zh_Hans: 最大搜索数
|
||||
pt_BR: max query number
|
||||
human_description:
|
||||
en_US: max query number(fuzzy search returns the maximum number of results or precise search the maximum number of matches)
|
||||
zh_Hans: 最大搜索数(模糊搜索返回的最大结果数或精确搜索最大匹配数)
|
||||
pt_BR: max query number(fuzzy search returns the maximum number of results or precise search the maximum number of matches)
|
||||
form: llm
|
||||
default: 50
|
||||
- name: sort
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- value: relevance
|
||||
label:
|
||||
en_US: relevance
|
||||
zh_Hans: 相关性
|
||||
pt_BR: relevance
|
||||
- value: published
|
||||
label:
|
||||
en_US: publication date
|
||||
zh_Hans: 出版日期
|
||||
pt_BR: publication date
|
||||
- value: references-count
|
||||
label:
|
||||
en_US: references-count
|
||||
zh_Hans: 引用次数
|
||||
pt_BR: references-count
|
||||
default: relevance
|
||||
label:
|
||||
en_US: sorting field
|
||||
zh_Hans: 排序字段
|
||||
pt_BR: sorting field
|
||||
human_description:
|
||||
en_US: Sorting of query results
|
||||
zh_Hans: 检索结果的排序字段
|
||||
pt_BR: Sorting of query results
|
||||
form: form
|
||||
- name: order
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- value: desc
|
||||
label:
|
||||
en_US: descending
|
||||
zh_Hans: 降序
|
||||
pt_BR: descending
|
||||
- value: asc
|
||||
label:
|
||||
en_US: ascending
|
||||
zh_Hans: 升序
|
||||
pt_BR: ascending
|
||||
default: desc
|
||||
label:
|
||||
en_US: Order
|
||||
zh_Hans: 排序
|
||||
pt_BR: Order
|
||||
human_description:
|
||||
en_US: Order of query results
|
||||
zh_Hans: 检索结果的排序方式
|
||||
pt_BR: Order of query results
|
||||
form: form
|
||||
@ -60,5 +60,11 @@ parameters:
|
||||
label:
|
||||
en_US: Tokenizer
|
||||
human_description:
|
||||
en_US: cl100k_base - gpt-4,gpt-3.5-turbo,gpt-3.5; o200k_base - gpt-4o,gpt-4o-mini; p50k_base - text-davinci-003,text-davinci-002
|
||||
en_US: |
|
||||
· cl100k_base --- gpt-4, gpt-3.5-turbo, gpt-3.5
|
||||
· o200k_base --- gpt-4o, gpt-4o-mini
|
||||
· p50k_base --- text-davinci-003, text-davinci-002
|
||||
· r50k_base --- text-davinci-001, text-curie-001
|
||||
· p50k_edit --- text-davinci-edit-001, code-davinci-edit-001
|
||||
· gpt2 --- gpt-2
|
||||
form: form
|
||||
|
||||
@ -0,0 +1,73 @@
|
||||
from novita_client import (
|
||||
Txt2ImgV3Embedding,
|
||||
Txt2ImgV3HiresFix,
|
||||
Txt2ImgV3LoRA,
|
||||
Txt2ImgV3Refiner,
|
||||
V3TaskImage,
|
||||
)
|
||||
|
||||
|
||||
class NovitaAiToolBase:
|
||||
def _extract_loras(self, loras_str: str):
|
||||
if not loras_str:
|
||||
return []
|
||||
|
||||
loras_ori_list = lora_str.strip().split(';')
|
||||
result_list = []
|
||||
for lora_str in loras_ori_list:
|
||||
lora_info = lora_str.strip().split(',')
|
||||
lora = Txt2ImgV3LoRA(
|
||||
model_name=lora_info[0].strip(),
|
||||
strength=float(lora_info[1]),
|
||||
)
|
||||
result_list.append(lora)
|
||||
|
||||
return result_list
|
||||
|
||||
def _extract_embeddings(self, embeddings_str: str):
|
||||
if not embeddings_str:
|
||||
return []
|
||||
|
||||
embeddings_ori_list = embeddings_str.strip().split(';')
|
||||
result_list = []
|
||||
for embedding_str in embeddings_ori_list:
|
||||
embedding = Txt2ImgV3Embedding(
|
||||
model_name=embedding_str.strip()
|
||||
)
|
||||
result_list.append(embedding)
|
||||
|
||||
return result_list
|
||||
|
||||
def _extract_hires_fix(self, hires_fix_str: str):
|
||||
hires_fix_info = hires_fix_str.strip().split(',')
|
||||
if 'upscaler' in hires_fix_info:
|
||||
hires_fix = Txt2ImgV3HiresFix(
|
||||
target_width=int(hires_fix_info[0]),
|
||||
target_height=int(hires_fix_info[1]),
|
||||
strength=float(hires_fix_info[2]),
|
||||
upscaler=hires_fix_info[3].strip()
|
||||
)
|
||||
else:
|
||||
hires_fix = Txt2ImgV3HiresFix(
|
||||
target_width=int(hires_fix_info[0]),
|
||||
target_height=int(hires_fix_info[1]),
|
||||
strength=float(hires_fix_info[2])
|
||||
)
|
||||
|
||||
return hires_fix
|
||||
|
||||
def _extract_refiner(self, switch_at: str):
|
||||
refiner = Txt2ImgV3Refiner(
|
||||
switch_at=float(switch_at)
|
||||
)
|
||||
return refiner
|
||||
|
||||
def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
|
||||
"""
|
||||
is hit nsfw
|
||||
"""
|
||||
if image.nsfw_detection_result is None:
|
||||
return False
|
||||
if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
|
||||
return True
|
||||
return False
|
||||
@ -4,19 +4,15 @@ from typing import Any, Union
|
||||
|
||||
from novita_client import (
|
||||
NovitaClient,
|
||||
Txt2ImgV3Embedding,
|
||||
Txt2ImgV3HiresFix,
|
||||
Txt2ImgV3LoRA,
|
||||
Txt2ImgV3Refiner,
|
||||
V3TaskImage,
|
||||
)
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.novitaai._novita_tool_base import NovitaAiToolBase
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class NovitaAiTxt2ImgTool(BuiltinTool):
|
||||
class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
@ -73,65 +69,19 @@ class NovitaAiTxt2ImgTool(BuiltinTool):
|
||||
|
||||
# process loras
|
||||
if 'loras' in res_parameters:
|
||||
loras_ori_list = res_parameters.get('loras').strip().split(';')
|
||||
locals_list = []
|
||||
for lora_str in loras_ori_list:
|
||||
lora_info = lora_str.strip().split(',')
|
||||
lora = Txt2ImgV3LoRA(
|
||||
model_name=lora_info[0].strip(),
|
||||
strength=float(lora_info[1]),
|
||||
)
|
||||
locals_list.append(lora)
|
||||
|
||||
res_parameters['loras'] = locals_list
|
||||
res_parameters['loras'] = self._extract_loras(res_parameters.get('loras'))
|
||||
|
||||
# process embeddings
|
||||
if 'embeddings' in res_parameters:
|
||||
embeddings_ori_list = res_parameters.get('embeddings').strip().split(';')
|
||||
locals_list = []
|
||||
for embedding_str in embeddings_ori_list:
|
||||
embedding = Txt2ImgV3Embedding(
|
||||
model_name=embedding_str.strip()
|
||||
)
|
||||
locals_list.append(embedding)
|
||||
|
||||
res_parameters['embeddings'] = locals_list
|
||||
res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings'))
|
||||
|
||||
# process hires_fix
|
||||
if 'hires_fix' in res_parameters:
|
||||
hires_fix_ori = res_parameters.get('hires_fix')
|
||||
hires_fix_info = hires_fix_ori.strip().split(',')
|
||||
if 'upscaler' in hires_fix_info:
|
||||
hires_fix = Txt2ImgV3HiresFix(
|
||||
target_width=int(hires_fix_info[0]),
|
||||
target_height=int(hires_fix_info[1]),
|
||||
strength=float(hires_fix_info[2]),
|
||||
upscaler=hires_fix_info[3].strip()
|
||||
)
|
||||
else:
|
||||
hires_fix = Txt2ImgV3HiresFix(
|
||||
target_width=int(hires_fix_info[0]),
|
||||
target_height=int(hires_fix_info[1]),
|
||||
strength=float(hires_fix_info[2])
|
||||
)
|
||||
res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix'))
|
||||
|
||||
res_parameters['hires_fix'] = hires_fix
|
||||
|
||||
if 'refiner_switch_at' in res_parameters:
|
||||
refiner = Txt2ImgV3Refiner(
|
||||
switch_at=float(res_parameters.get('refiner_switch_at'))
|
||||
)
|
||||
del res_parameters['refiner_switch_at']
|
||||
res_parameters['refiner'] = refiner
|
||||
# process refiner
|
||||
if 'refiner_switch_at' in res_parameters:
|
||||
res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at'))
|
||||
del res_parameters['refiner_switch_at']
|
||||
|
||||
return res_parameters
|
||||
|
||||
def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
|
||||
"""
|
||||
is hit nsfw
|
||||
"""
|
||||
if image.nsfw_detection_result is None:
|
||||
return False
|
||||
if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
@ -18,6 +18,13 @@ from models.model import App, AppMode
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow
|
||||
|
||||
VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
|
||||
VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
|
||||
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
|
||||
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
|
||||
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
@ -28,7 +35,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
|
||||
if not app:
|
||||
raise ValueError('app not found')
|
||||
|
||||
|
||||
controller = WorkflowToolProviderController(**{
|
||||
'identity': {
|
||||
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
|
||||
@ -46,7 +53,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
'credentials_schema': {},
|
||||
'provider_id': db_provider.id or '',
|
||||
})
|
||||
|
||||
|
||||
# init tools
|
||||
|
||||
controller.tools = [controller._get_db_provider_tool(db_provider, app)]
|
||||
@ -56,7 +63,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.WORKFLOW
|
||||
|
||||
|
||||
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
|
||||
"""
|
||||
get db provider tool
|
||||
@ -93,23 +100,11 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
if variable:
|
||||
parameter_type = None
|
||||
options = None
|
||||
if variable.type in [
|
||||
VariableEntity.Type.TEXT_INPUT,
|
||||
VariableEntity.Type.PARAGRAPH,
|
||||
]:
|
||||
parameter_type = ToolParameter.ToolParameterType.STRING
|
||||
elif variable.type in [
|
||||
VariableEntity.Type.SELECT
|
||||
]:
|
||||
parameter_type = ToolParameter.ToolParameterType.SELECT
|
||||
elif variable.type in [
|
||||
VariableEntity.Type.NUMBER
|
||||
]:
|
||||
parameter_type = ToolParameter.ToolParameterType.NUMBER
|
||||
else:
|
||||
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
|
||||
raise ValueError(f'unsupported variable type {variable.type}')
|
||||
|
||||
if variable.type == VariableEntity.Type.SELECT and variable.options:
|
||||
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
|
||||
|
||||
if variable.type == VariableEntityType.SELECT and variable.options:
|
||||
options = [
|
||||
ToolParameterOption(
|
||||
value=option,
|
||||
@ -200,7 +195,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
"""
|
||||
if self.tools is not None:
|
||||
return self.tools
|
||||
|
||||
|
||||
db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.app_id == self.provider_id,
|
||||
@ -208,11 +203,11 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
|
||||
if not db_providers:
|
||||
return []
|
||||
|
||||
|
||||
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
|
||||
|
||||
return self.tools
|
||||
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
|
||||
"""
|
||||
get tool by name
|
||||
@ -226,5 +221,5 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
for tool in self.tools:
|
||||
if tool.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
|
||||
return None
|
||||
|
||||
@ -10,14 +10,11 @@ from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
|
||||
@ -26,10 +23,7 @@ from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import (
|
||||
ToolConfigurationManager,
|
||||
ToolParameterConfigurationManager,
|
||||
)
|
||||
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
from extensions.ext_database import db
|
||||
@ -38,6 +32,7 @@ from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_builtin_providers = {}
|
||||
@ -107,7 +102,7 @@ class ToolManager:
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
|
||||
@ -346,7 +341,7 @@ class ToolManager:
|
||||
provider_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'provider', 'builtin', provider, f'{provider}.py'),
|
||||
'provider', 'builtin', provider, f'{provider}.py'),
|
||||
parent_type=BuiltinToolProviderController)
|
||||
provider: BuiltinToolProviderController = provider_class()
|
||||
cls._builtin_providers[provider.identity.name] = provider
|
||||
@ -414,6 +409,15 @@ class ToolManager:
|
||||
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
data=provider,
|
||||
name_func=lambda x: x.identity.name
|
||||
):
|
||||
continue
|
||||
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.identity.name),
|
||||
@ -473,7 +477,7 @@ class ToolManager:
|
||||
|
||||
@classmethod
|
||||
def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
|
||||
ApiToolProviderController, dict[str, Any]]:
|
||||
ApiToolProviderController, dict[str, Any]]:
|
||||
"""
|
||||
get the api provider
|
||||
|
||||
@ -593,4 +597,5 @@ class ToolManager:
|
||||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
|
||||
ToolManager.load_builtin_providers_cache()
|
||||
|
||||
@ -6,20 +6,20 @@ from typing_extensions import deprecated
|
||||
|
||||
from core.app.segments import Segment, Variable, factory
|
||||
from core.file.file_obj import FileVar
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list, FileVar]
|
||||
|
||||
|
||||
SYSTEM_VARIABLE_NODE_ID = 'sys'
|
||||
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
|
||||
CONVERSATION_VARIABLE_NODE_ID = 'conversation'
|
||||
SYSTEM_VARIABLE_NODE_ID = "sys"
|
||||
ENVIRONMENT_VARIABLE_NODE_ID = "env"
|
||||
CONVERSATION_VARIABLE_NODE_ID = "conversation"
|
||||
|
||||
|
||||
class VariablePool:
|
||||
def __init__(
|
||||
self,
|
||||
system_variables: Mapping[SystemVariable, Any],
|
||||
system_variables: Mapping[SystemVariableKey, Any],
|
||||
user_inputs: Mapping[str, Any],
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable] | None = None,
|
||||
@ -68,7 +68,7 @@ class VariablePool:
|
||||
None
|
||||
"""
|
||||
if len(selector) < 2:
|
||||
raise ValueError('Invalid selector')
|
||||
raise ValueError("Invalid selector")
|
||||
|
||||
if value is None:
|
||||
return
|
||||
@ -95,13 +95,13 @@ class VariablePool:
|
||||
ValueError: If the selector is invalid.
|
||||
"""
|
||||
if len(selector) < 2:
|
||||
raise ValueError('Invalid selector')
|
||||
raise ValueError("Invalid selector")
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
||||
|
||||
return value
|
||||
|
||||
@deprecated('This method is deprecated, use `get` instead.')
|
||||
@deprecated("This method is deprecated, use `get` instead.")
|
||||
def get_any(self, selector: Sequence[str], /) -> Any | None:
|
||||
"""
|
||||
Retrieves the value from the variable pool based on the given selector.
|
||||
@ -116,7 +116,7 @@ class VariablePool:
|
||||
ValueError: If the selector is invalid.
|
||||
"""
|
||||
if len(selector) < 2:
|
||||
raise ValueError('Invalid selector')
|
||||
raise ValueError("Invalid selector")
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
||||
return value.to_object() if value else None
|
||||
|
||||
@ -1,25 +1,13 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SystemVariable(str, Enum):
|
||||
class SystemVariableKey(str, Enum):
|
||||
"""
|
||||
System Variables.
|
||||
"""
|
||||
QUERY = 'query'
|
||||
FILES = 'files'
|
||||
CONVERSATION_ID = 'conversation_id'
|
||||
USER_ID = 'user_id'
|
||||
DIALOGUE_COUNT = 'dialogue_count'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
"""
|
||||
Get value of given system variable.
|
||||
|
||||
:param value: system variable value
|
||||
:return: system variable
|
||||
"""
|
||||
for system_variable in cls:
|
||||
if system_variable.value == value:
|
||||
return system_variable
|
||||
raise ValueError(f'invalid system variable value {value}')
|
||||
QUERY = "query"
|
||||
FILES = "files"
|
||||
CONVERSATION_ID = "conversation_id"
|
||||
USER_ID = "user_id"
|
||||
DIALOGUE_COUNT = "dialogue_count"
|
||||
|
||||
@ -13,8 +13,8 @@ from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
MAX_NUMBER = dify_config.CODE_MAX_NUMBER
|
||||
MIN_NUMBER = dify_config.CODE_MIN_NUMBER
|
||||
MAX_PRECISION = 20
|
||||
MAX_DEPTH = 5
|
||||
MAX_PRECISION = dify_config.CODE_MAX_PRECISION
|
||||
MAX_DEPTH = dify_config.CODE_MAX_DEPTH
|
||||
MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH
|
||||
MAX_STRING_ARRAY_LENGTH = dify_config.CODE_MAX_STRING_ARRAY_LENGTH
|
||||
MAX_OBJECT_ARRAY_LENGTH = dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH
|
||||
@ -23,7 +23,7 @@ MAX_NUMBER_ARRAY_LENGTH = dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH
|
||||
|
||||
class CodeNode(BaseNode):
|
||||
_node_data_cls = CodeNodeData
|
||||
node_type = NodeType.CODE
|
||||
_node_type = NodeType.CODE
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
@ -48,8 +48,7 @@ class CodeNode(BaseNode):
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data: CodeNodeData = cast(self._node_data_cls, node_data)
|
||||
node_data = cast(CodeNodeData, self.node_data)
|
||||
|
||||
# Get code language
|
||||
code_language = node_data.code_language
|
||||
@ -68,7 +67,6 @@ class CodeNode(BaseNode):
|
||||
language=code_language,
|
||||
code=code,
|
||||
inputs=variables,
|
||||
dependencies=node_data.dependencies
|
||||
)
|
||||
|
||||
# Transform result
|
||||
|
||||
@ -3,7 +3,6 @@ from typing import Literal, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.helper.code_executor.entities import CodeDependency
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
|
||||
@ -16,8 +15,12 @@ class CodeNodeData(BaseNodeData):
|
||||
type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]']
|
||||
children: Optional[dict[str, 'Output']] = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
name: str
|
||||
version: str
|
||||
|
||||
variables: list[VariableSelector]
|
||||
code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
|
||||
code: str
|
||||
outputs: dict[str, Output]
|
||||
dependencies: Optional[list[CodeDependency]] = None
|
||||
dependencies: Optional[list[Dependency]] = None
|
||||
@ -24,7 +24,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
@ -94,7 +94,7 @@ class LLMNode(BaseNode):
|
||||
# fetch prompt messages
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
node_data=node_data,
|
||||
query=variable_pool.get_any(['sys', SystemVariable.QUERY.value])
|
||||
query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value])
|
||||
if node_data.memory else None,
|
||||
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
|
||||
inputs=inputs,
|
||||
@ -113,7 +113,7 @@ class LLMNode(BaseNode):
|
||||
}
|
||||
|
||||
# handle invoke result
|
||||
result_text, usage = self._invoke_llm(
|
||||
result_text, usage, finish_reason = self._invoke_llm(
|
||||
node_data_model=node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
@ -129,7 +129,8 @@ class LLMNode(BaseNode):
|
||||
|
||||
outputs = {
|
||||
'text': result_text,
|
||||
'usage': jsonable_encoder(usage)
|
||||
'usage': jsonable_encoder(usage),
|
||||
'finish_reason': finish_reason
|
||||
}
|
||||
|
||||
return NodeRunResult(
|
||||
@ -167,14 +168,14 @@ class LLMNode(BaseNode):
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
text, usage = self._handle_invoke_result(
|
||||
text, usage, finish_reason = self._handle_invoke_result(
|
||||
invoke_result=invoke_result
|
||||
)
|
||||
|
||||
# deduct quota
|
||||
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
return text, usage
|
||||
return text, usage, finish_reason
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
@ -186,6 +187,7 @@ class LLMNode(BaseNode):
|
||||
prompt_messages = []
|
||||
full_text = ''
|
||||
usage = None
|
||||
finish_reason = None
|
||||
for result in invoke_result:
|
||||
text = result.delta.message.content
|
||||
full_text += text
|
||||
@ -201,10 +203,13 @@ class LLMNode(BaseNode):
|
||||
if not usage and result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
|
||||
if not finish_reason and result.delta.finish_reason:
|
||||
finish_reason = result.delta.finish_reason
|
||||
|
||||
if not usage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
return full_text, usage
|
||||
return full_text, usage, finish_reason
|
||||
|
||||
def _transform_chat_messages(self,
|
||||
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
@ -335,7 +340,7 @@ class LLMNode(BaseNode):
|
||||
if not node_data.vision.enabled:
|
||||
return []
|
||||
|
||||
files = variable_pool.get_any(['sys', SystemVariable.FILES.value])
|
||||
files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value])
|
||||
if not files:
|
||||
return []
|
||||
|
||||
@ -500,7 +505,7 @@ class LLMNode(BaseNode):
|
||||
return None
|
||||
|
||||
# get conversation id
|
||||
conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value])
|
||||
conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value])
|
||||
if conversation_id is None:
|
||||
return None
|
||||
|
||||
@ -672,10 +677,10 @@ class LLMNode(BaseNode):
|
||||
variable_mapping['#context#'] = node_data.context.variable_selector
|
||||
|
||||
if node_data.vision.enabled:
|
||||
variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]
|
||||
variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value]
|
||||
|
||||
if node_data.memory:
|
||||
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
|
||||
variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value]
|
||||
|
||||
if node_data.prompt_config:
|
||||
enable_jinja = False
|
||||
|
||||
@ -63,7 +63,7 @@ class QuestionClassifierNode(LLMNode):
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
result_text, usage = self._invoke_llm(
|
||||
result_text, usage, finish_reason = self._invoke_llm(
|
||||
node_data_model=node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
@ -93,6 +93,7 @@ class QuestionClassifierNode(LLMNode):
|
||||
prompt_messages=prompt_messages
|
||||
),
|
||||
'usage': jsonable_encoder(usage),
|
||||
'finish_reason': finish_reason
|
||||
}
|
||||
outputs = {
|
||||
'class_name': category_name
|
||||
|
||||
@ -1,3 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
@ -6,4 +10,4 @@ class StartNodeData(BaseNodeData):
|
||||
"""
|
||||
Start Node Data
|
||||
"""
|
||||
variables: list[VariableEntity] = []
|
||||
variables: Sequence[VariableEntity] = Field(default_factory=list)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
@ -17,16 +17,16 @@ class StartNode(BaseNode):
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
# Get cleaned inputs
|
||||
cleaned_inputs = dict(variable_pool.user_inputs)
|
||||
node_inputs = dict(variable_pool.user_inputs)
|
||||
system_inputs = variable_pool.system_variables
|
||||
|
||||
for var in variable_pool.system_variables:
|
||||
cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var]
|
||||
for var in system_inputs:
|
||||
node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=cleaned_inputs,
|
||||
outputs=cleaned_inputs
|
||||
inputs=node_inputs,
|
||||
outputs=node_inputs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence
|
||||
from os import path
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.segments import ArrayAnyVariable, parser
|
||||
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
@ -141,8 +141,8 @@ class ToolNode(BaseNode):
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
|
||||
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
|
||||
assert isinstance(variable, ArrayAnyVariable)
|
||||
variable = variable_pool.get(['sys', SystemVariableKey.FILES.value])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
|
||||
|
||||
@ -1,109 +1,8 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Optional, cast
|
||||
from .node import VariableAssignerNode
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.segments import SegmentType, Variable, factory
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class VariableAssignerNodeError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class WriteMode(str, Enum):
|
||||
OVER_WRITE = 'over-write'
|
||||
APPEND = 'append'
|
||||
CLEAR = 'clear'
|
||||
|
||||
|
||||
class VariableAssignerData(BaseNodeData):
|
||||
title: str = 'Variable Assigner'
|
||||
desc: Optional[str] = 'Assign a value to a variable'
|
||||
assigned_variable_selector: Sequence[str]
|
||||
write_mode: WriteMode
|
||||
input_variable_selector: Sequence[str]
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode):
|
||||
_node_data_cls: type[BaseNodeData] = VariableAssignerData
|
||||
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
data = cast(VariableAssignerData, self.node_data)
|
||||
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = variable_pool.get(data.assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
raise VariableAssignerNodeError('assigned variable not found')
|
||||
|
||||
match data.write_mode:
|
||||
case WriteMode.OVER_WRITE:
|
||||
income_value = variable_pool.get(data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableAssignerNodeError('input value not found')
|
||||
updated_variable = original_variable.model_copy(update={'value': income_value.value})
|
||||
|
||||
case WriteMode.APPEND:
|
||||
income_value = variable_pool.get(data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableAssignerNodeError('input value not found')
|
||||
updated_value = original_variable.value + [income_value.value]
|
||||
updated_variable = original_variable.model_copy(update={'value': updated_value})
|
||||
|
||||
case WriteMode.CLEAR:
|
||||
income_value = get_zero_value(original_variable.value_type)
|
||||
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
|
||||
|
||||
case _:
|
||||
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
|
||||
|
||||
# Over write the variable.
|
||||
variable_pool.add(data.assigned_variable_selector, updated_variable)
|
||||
|
||||
# Update conversation variable.
|
||||
# TODO: Find a better way to use the database.
|
||||
conversation_id = variable_pool.get(['sys', 'conversation_id'])
|
||||
if not conversation_id:
|
||||
raise VariableAssignerNodeError('conversation_id not found')
|
||||
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={
|
||||
'value': income_value.to_object(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def update_conversation_variable(conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableAssignerNodeError('conversation variable not found in the database')
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_zero_value(t: SegmentType):
|
||||
match t:
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
||||
return factory.build_segment([])
|
||||
case SegmentType.OBJECT:
|
||||
return factory.build_segment({})
|
||||
case SegmentType.STRING:
|
||||
return factory.build_segment('')
|
||||
case SegmentType.NUMBER:
|
||||
return factory.build_segment(0)
|
||||
case _:
|
||||
raise VariableAssignerNodeError(f'unsupported variable type: {t}')
|
||||
__all__ = [
|
||||
'VariableAssignerNode',
|
||||
'VariableAssignerData',
|
||||
'WriteMode',
|
||||
]
|
||||
|
||||
2
api/core/workflow/nodes/variable_assigner/exc.py
Normal file
2
api/core/workflow/nodes/variable_assigner/exc.py
Normal file
@ -0,0 +1,2 @@
|
||||
class VariableAssignerNodeError(Exception):
|
||||
pass
|
||||
92
api/core/workflow/nodes/variable_assigner/node.py
Normal file
92
api/core/workflow/nodes/variable_assigner/node.py
Normal file
@ -0,0 +1,92 @@
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.segments import SegmentType, Variable, factory
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable, WorkflowNodeExecutionStatus
|
||||
|
||||
from .exc import VariableAssignerNodeError
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode):
|
||||
_node_data_cls: type[BaseNodeData] = VariableAssignerData
|
||||
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
data = cast(VariableAssignerData, self.node_data)
|
||||
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = variable_pool.get(data.assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
raise VariableAssignerNodeError('assigned variable not found')
|
||||
|
||||
match data.write_mode:
|
||||
case WriteMode.OVER_WRITE:
|
||||
income_value = variable_pool.get(data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableAssignerNodeError('input value not found')
|
||||
updated_variable = original_variable.model_copy(update={'value': income_value.value})
|
||||
|
||||
case WriteMode.APPEND:
|
||||
income_value = variable_pool.get(data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableAssignerNodeError('input value not found')
|
||||
updated_value = original_variable.value + [income_value.value]
|
||||
updated_variable = original_variable.model_copy(update={'value': updated_value})
|
||||
|
||||
case WriteMode.CLEAR:
|
||||
income_value = get_zero_value(original_variable.value_type)
|
||||
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
|
||||
|
||||
case _:
|
||||
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
|
||||
|
||||
# Over write the variable.
|
||||
variable_pool.add(data.assigned_variable_selector, updated_variable)
|
||||
|
||||
# TODO: Move database operation to the pipeline.
|
||||
# Update conversation variable.
|
||||
conversation_id = variable_pool.get(['sys', 'conversation_id'])
|
||||
if not conversation_id:
|
||||
raise VariableAssignerNodeError('conversation_id not found')
|
||||
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={
|
||||
'value': income_value.to_object(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def update_conversation_variable(conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableAssignerNodeError('conversation variable not found in the database')
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_zero_value(t: SegmentType):
|
||||
match t:
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
||||
return factory.build_segment([])
|
||||
case SegmentType.OBJECT:
|
||||
return factory.build_segment({})
|
||||
case SegmentType.STRING:
|
||||
return factory.build_segment('')
|
||||
case SegmentType.NUMBER:
|
||||
return factory.build_segment(0)
|
||||
case _:
|
||||
raise VariableAssignerNodeError(f'unsupported variable type: {t}')
|
||||
19
api/core/workflow/nodes/variable_assigner/node_data.py
Normal file
19
api/core/workflow/nodes/variable_assigner/node_data.py
Normal file
@ -0,0 +1,19 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class WriteMode(str, Enum):
|
||||
OVER_WRITE = 'over-write'
|
||||
APPEND = 'append'
|
||||
CLEAR = 'clear'
|
||||
|
||||
|
||||
class VariableAssignerData(BaseNodeData):
|
||||
title: str = 'Variable Assigner'
|
||||
desc: Optional[str] = 'Assign a value to a variable'
|
||||
assigned_variable_selector: Sequence[str]
|
||||
write_mode: WriteMode
|
||||
input_variable_selector: Sequence[str]
|
||||
@ -28,6 +28,16 @@ class S3Storage(BaseStorage):
|
||||
region_name=app_config.get("S3_REGION"),
|
||||
config=Config(s3={"addressing_style": app_config.get("S3_ADDRESS_STYLE")}),
|
||||
)
|
||||
# create bucket
|
||||
try:
|
||||
self.client.head_bucket(Bucket=self.bucket_name)
|
||||
except ClientError as e:
|
||||
# if bucket not exists, create it
|
||||
if e.response["Error"]["Code"] == "404":
|
||||
self.client.create_bucket(Bucket=self.bucket_name)
|
||||
else:
|
||||
# other error, raise exception
|
||||
raise
|
||||
|
||||
def save(self, filename, data):
|
||||
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
|
||||
|
||||
@ -150,6 +150,7 @@ conversation_with_summary_fields = {
|
||||
"summary": fields.String(attribute="summary_or_query"),
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"model_config": fields.Nested(simple_model_config_fields),
|
||||
"message_count": fields.Integer,
|
||||
|
||||
@ -0,0 +1,28 @@
|
||||
"""rename workflow__conversation_variables to workflow_conversation_variables
|
||||
|
||||
Revision ID: 2dbe42621d96
|
||||
Revises: a6be81136580
|
||||
Create Date: 2024-08-20 04:55:38.160010
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
import models as models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '2dbe42621d96'
|
||||
down_revision = 'a6be81136580'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.rename_table('workflow__conversation_variables', 'workflow_conversation_variables')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.rename_table('workflow_conversation_variables', 'workflow__conversation_variables')
|
||||
# ### end Alembic commands ###
|
||||
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import enum
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
@ -22,6 +23,11 @@ from .model import App, Tag, TagBinding, UploadFile
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class DatasetPermissionEnum(str, enum.Enum):
|
||||
ONLY_ME = 'only_me'
|
||||
ALL_TEAM = 'all_team_members'
|
||||
PARTIAL_TEAM = 'partial_members'
|
||||
|
||||
class Dataset(db.Model):
|
||||
__tablename__ = 'datasets'
|
||||
__table_args__ = (
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
@ -110,19 +111,32 @@ class Workflow(db.Model):
|
||||
db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
app_id = db.Column(StringUUID, nullable=False)
|
||||
type = db.Column(db.String(255), nullable=False)
|
||||
version = db.Column(db.String(255), nullable=False)
|
||||
graph = db.Column(db.Text)
|
||||
features = db.Column(db.Text)
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_by = db.Column(StringUUID)
|
||||
updated_at = db.Column(db.DateTime)
|
||||
_environment_variables = db.Column('environment_variables', db.Text, nullable=False, server_default='{}')
|
||||
_conversation_variables = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}')
|
||||
id: Mapped[str] = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
version: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
graph: Mapped[str] = db.Column(db.Text)
|
||||
features: Mapped[str] = db.Column(db.Text)
|
||||
created_by: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_by: Mapped[str] = db.Column(StringUUID)
|
||||
updated_at: Mapped[datetime] = db.Column(db.DateTime)
|
||||
_environment_variables: Mapped[str] = db.Column('environment_variables', db.Text, nullable=False, server_default='{}')
|
||||
_conversation_variables: Mapped[str] = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}')
|
||||
|
||||
def __init__(self, *, tenant_id: str, app_id: str, type: str, version: str, graph: str,
|
||||
features: str, created_by: str, environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable]):
|
||||
self.tenant_id = tenant_id
|
||||
self.app_id = app_id
|
||||
self.type = type
|
||||
self.version = version
|
||||
self.graph = graph
|
||||
self.features = features
|
||||
self.created_by = created_by
|
||||
self.environment_variables = environment_variables or []
|
||||
self.conversation_variables = conversation_variables or []
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
@ -724,7 +738,7 @@ class WorkflowAppLog(db.Model):
|
||||
|
||||
|
||||
class ConversationVariable(db.Model):
|
||||
__tablename__ = 'workflow__conversation_variables'
|
||||
__tablename__ = 'workflow_conversation_variables'
|
||||
|
||||
id: Mapped[str] = db.Column(StringUUID, primary_key=True)
|
||||
conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True)
|
||||
|
||||
1556
api/poetry.lock
generated
1556
api/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -153,7 +153,7 @@ langfuse = "^2.36.1"
|
||||
langsmith = "^0.1.77"
|
||||
mailchimp-transactional = "~1.0.50"
|
||||
markdown = "~3.5.1"
|
||||
novita-client = "^0.5.6"
|
||||
novita-client = "^0.5.7"
|
||||
numpy = "~1.26.4"
|
||||
openai = "~1.29.0"
|
||||
openpyxl = "~3.1.5"
|
||||
|
||||
@ -111,6 +111,12 @@ class AppService:
|
||||
'completion_params': {}
|
||||
}
|
||||
else:
|
||||
provider, model = model_manager.get_default_provider_model_name(
|
||||
tenant_id=account.current_tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
default_model_config['model']['provider'] = provider
|
||||
default_model_config['model']['name'] = model
|
||||
default_model_dict = default_model_config['model']
|
||||
|
||||
default_model_config['model'] = json.dumps(default_model_dict)
|
||||
@ -190,13 +196,14 @@ class AppService:
|
||||
"""
|
||||
Modified App class
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
self.__dict__.update(app.__dict__)
|
||||
|
||||
@property
|
||||
def app_model_config(self):
|
||||
return model_config
|
||||
|
||||
|
||||
app = ModifiedApp(app)
|
||||
|
||||
return app
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import asc, desc, or_
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
@ -18,7 +19,8 @@ class ConversationService:
|
||||
last_id: Optional[str], limit: int,
|
||||
invoke_from: InvokeFrom,
|
||||
include_ids: Optional[list] = None,
|
||||
exclude_ids: Optional[list] = None) -> InfiniteScrollPagination:
|
||||
exclude_ids: Optional[list] = None,
|
||||
sort_by: str = '-updated_at') -> InfiniteScrollPagination:
|
||||
if not user:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
|
||||
@ -37,28 +39,28 @@ class ConversationService:
|
||||
if exclude_ids is not None:
|
||||
base_query = base_query.filter(~Conversation.id.in_(exclude_ids))
|
||||
|
||||
if last_id:
|
||||
last_conversation = base_query.filter(
|
||||
Conversation.id == last_id,
|
||||
).first()
|
||||
# define sort fields and directions
|
||||
sort_field, sort_direction = cls._get_sort_params(sort_by)
|
||||
|
||||
if last_id:
|
||||
last_conversation = base_query.filter(Conversation.id == last_id).first()
|
||||
if not last_conversation:
|
||||
raise LastConversationNotExistsError()
|
||||
|
||||
conversations = base_query.filter(
|
||||
Conversation.created_at < last_conversation.created_at,
|
||||
Conversation.id != last_conversation.id
|
||||
).order_by(Conversation.created_at.desc()).limit(limit).all()
|
||||
else:
|
||||
conversations = base_query.order_by(Conversation.created_at.desc()).limit(limit).all()
|
||||
# build filters based on sorting
|
||||
filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation)
|
||||
base_query = base_query.filter(filter_condition)
|
||||
|
||||
base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field)))
|
||||
|
||||
conversations = base_query.limit(limit).all()
|
||||
|
||||
has_more = False
|
||||
if len(conversations) == limit:
|
||||
current_page_first_conversation = conversations[-1]
|
||||
rest_count = base_query.filter(
|
||||
Conversation.created_at < current_page_first_conversation.created_at,
|
||||
Conversation.id != current_page_first_conversation.id
|
||||
).count()
|
||||
current_page_last_conversation = conversations[-1]
|
||||
rest_filter_condition = cls._build_filter_condition(sort_field, sort_direction,
|
||||
current_page_last_conversation, is_next_page=True)
|
||||
rest_count = base_query.filter(rest_filter_condition).count()
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
@ -69,6 +71,21 @@ class ConversationService:
|
||||
has_more=has_more
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_sort_params(cls, sort_by: str) -> tuple[str, callable]:
|
||||
if sort_by.startswith('-'):
|
||||
return sort_by[1:], desc
|
||||
return sort_by, asc
|
||||
|
||||
@classmethod
|
||||
def _build_filter_condition(cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation,
|
||||
is_next_page: bool = False):
|
||||
field_value = getattr(reference_conversation, sort_field)
|
||||
if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page):
|
||||
return getattr(Conversation, sort_field) < field_value
|
||||
else:
|
||||
return getattr(Conversation, sort_field) > field_value
|
||||
|
||||
@classmethod
|
||||
def rename(cls, app_model: App, conversation_id: str,
|
||||
user: Optional[Union[Account, EndUser]], name: str, auto_generate: bool):
|
||||
@ -78,6 +95,7 @@ class ConversationService:
|
||||
return cls.auto_generate_name(app_model, conversation)
|
||||
else:
|
||||
conversation.name = name
|
||||
conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
return conversation
|
||||
@ -87,9 +105,9 @@ class ConversationService:
|
||||
# get conversation first message
|
||||
message = db.session.query(Message) \
|
||||
.filter(
|
||||
Message.app_id == app_model.id,
|
||||
Message.conversation_id == conversation.id
|
||||
).order_by(Message.created_at.asc()).first()
|
||||
Message.app_id == app_model.id,
|
||||
Message.conversation_id == conversation.id
|
||||
).order_by(Message.created_at.asc()).first()
|
||||
|
||||
if not message:
|
||||
raise MessageNotExistsError()
|
||||
|
||||
@ -27,6 +27,7 @@ from models.dataset import (
|
||||
Dataset,
|
||||
DatasetCollectionBinding,
|
||||
DatasetPermission,
|
||||
DatasetPermissionEnum,
|
||||
DatasetProcessRule,
|
||||
DatasetQuery,
|
||||
Document,
|
||||
@ -80,21 +81,21 @@ class DatasetService:
|
||||
if permitted_dataset_ids:
|
||||
query = query.filter(
|
||||
db.or_(
|
||||
Dataset.permission == 'all_team_members',
|
||||
db.and_(Dataset.permission == 'only_me', Dataset.created_by == user.id),
|
||||
db.and_(Dataset.permission == 'partial_members', Dataset.id.in_(permitted_dataset_ids))
|
||||
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
|
||||
db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id),
|
||||
db.and_(Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM, Dataset.id.in_(permitted_dataset_ids))
|
||||
)
|
||||
)
|
||||
else:
|
||||
query = query.filter(
|
||||
db.or_(
|
||||
Dataset.permission == 'all_team_members',
|
||||
db.and_(Dataset.permission == 'only_me', Dataset.created_by == user.id)
|
||||
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
|
||||
db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id)
|
||||
)
|
||||
)
|
||||
else:
|
||||
# if no user, only show datasets that are shared with all team members
|
||||
query = query.filter(Dataset.permission == 'all_team_members')
|
||||
query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
|
||||
|
||||
if search:
|
||||
query = query.filter(Dataset.name.ilike(f'%{search}%'))
|
||||
@ -330,7 +331,7 @@ class DatasetService:
|
||||
raise NoPermissionError(
|
||||
'You do not have permission to access this dataset.'
|
||||
)
|
||||
if dataset.permission == 'only_me' and dataset.created_by != user.id:
|
||||
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
|
||||
logging.debug(
|
||||
f'User {user.id} does not have permission to access dataset {dataset.id}'
|
||||
)
|
||||
@ -351,11 +352,11 @@ class DatasetService:
|
||||
|
||||
@staticmethod
|
||||
def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None):
|
||||
if dataset.permission == 'only_me':
|
||||
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
|
||||
if dataset.created_by != user.id:
|
||||
raise NoPermissionError('You do not have permission to access this dataset.')
|
||||
|
||||
elif dataset.permission == 'partial_members':
|
||||
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||
if not any(
|
||||
dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
|
||||
):
|
||||
|
||||
@ -30,6 +30,7 @@ class ModelProviderService:
|
||||
"""
|
||||
Model Provider Service
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.provider_manager = ProviderManager()
|
||||
|
||||
@ -387,18 +388,21 @@ class ModelProviderService:
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type_enum
|
||||
)
|
||||
|
||||
return DefaultModelResponse(
|
||||
model=result.model,
|
||||
model_type=result.model_type,
|
||||
provider=SimpleProviderEntityResponse(
|
||||
provider=result.provider.provider,
|
||||
label=result.provider.label,
|
||||
icon_small=result.provider.icon_small,
|
||||
icon_large=result.provider.icon_large,
|
||||
supported_model_types=result.provider.supported_model_types
|
||||
)
|
||||
) if result else None
|
||||
try:
|
||||
return DefaultModelResponse(
|
||||
model=result.model,
|
||||
model_type=result.model_type,
|
||||
provider=SimpleProviderEntityResponse(
|
||||
provider=result.provider.provider,
|
||||
label=result.provider.label,
|
||||
icon_small=result.provider.icon_small,
|
||||
icon_large=result.provider.icon_large,
|
||||
supported_model_types=result.provider.supported_model_types
|
||||
)
|
||||
) if result else None
|
||||
except Exception as e:
|
||||
logger.info(f"get_default_model_of_model_type error: {e}")
|
||||
return None
|
||||
|
||||
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
|
||||
"""
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||
@ -43,14 +45,14 @@ class BuiltinToolManageService:
|
||||
result = []
|
||||
for tool in tools:
|
||||
result.append(ToolTransformService.tool_to_user_tool(
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller)
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(
|
||||
provider_name
|
||||
@ -78,7 +80,7 @@ class BuiltinToolManageService:
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
).first()
|
||||
|
||||
try:
|
||||
try:
|
||||
# get provider
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
if not provider_controller.need_credentials:
|
||||
@ -119,8 +121,8 @@ class BuiltinToolManageService:
|
||||
# delete cache
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_credentials(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
@ -135,7 +137,7 @@ class BuiltinToolManageService:
|
||||
|
||||
if provider is None:
|
||||
return {}
|
||||
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider.provider)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
@ -156,7 +158,7 @@ class BuiltinToolManageService:
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider_name}')
|
||||
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.commit()
|
||||
|
||||
@ -165,8 +167,8 @@ class BuiltinToolManageService:
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_icon(
|
||||
provider: str
|
||||
@ -179,7 +181,7 @@ class BuiltinToolManageService:
|
||||
icon_bytes = f.read()
|
||||
|
||||
return icon_bytes, mime_type
|
||||
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tools(
|
||||
user_id: str, tenant_id: str
|
||||
@ -202,6 +204,15 @@ class BuiltinToolManageService:
|
||||
|
||||
for provider_controller in provider_controllers:
|
||||
try:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.identity.name
|
||||
):
|
||||
continue
|
||||
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
@ -226,4 +237,3 @@ class BuiltinToolManageService:
|
||||
raise e
|
||||
|
||||
return BuiltinToolProviderSort.sort(result)
|
||||
|
||||
@ -13,7 +13,8 @@ class WebConversationService:
|
||||
@classmethod
|
||||
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str], limit: int, invoke_from: InvokeFrom,
|
||||
pinned: Optional[bool] = None) -> InfiniteScrollPagination:
|
||||
pinned: Optional[bool] = None,
|
||||
sort_by='-updated_at') -> InfiniteScrollPagination:
|
||||
include_ids = None
|
||||
exclude_ids = None
|
||||
if pinned is not None:
|
||||
@ -36,6 +37,7 @@ class WebConversationService:
|
||||
invoke_from=invoke_from,
|
||||
include_ids=include_ids,
|
||||
exclude_ids=exclude_ids,
|
||||
sort_by=sort_by
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -32,12 +32,9 @@ class WorkflowConverter:
|
||||
App Convert to Workflow Mode
|
||||
"""
|
||||
|
||||
def convert_to_workflow(self, app_model: App,
|
||||
account: Account,
|
||||
name: str,
|
||||
icon_type: str,
|
||||
icon: str,
|
||||
icon_background: str) -> App:
|
||||
def convert_to_workflow(
|
||||
self, app_model: App, account: Account, name: str, icon_type: str, icon: str, icon_background: str
|
||||
):
|
||||
"""
|
||||
Convert app to workflow
|
||||
|
||||
@ -56,18 +53,18 @@ class WorkflowConverter:
|
||||
:return: new App instance
|
||||
"""
|
||||
# convert app model config
|
||||
if not app_model.app_model_config:
|
||||
raise ValueError("App model config is required")
|
||||
|
||||
workflow = self.convert_app_model_config_to_workflow(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model.app_model_config,
|
||||
account_id=account.id
|
||||
app_model=app_model, app_model_config=app_model.app_model_config, account_id=account.id
|
||||
)
|
||||
|
||||
# create new app
|
||||
new_app = App()
|
||||
new_app.tenant_id = app_model.tenant_id
|
||||
new_app.name = name if name else app_model.name + '(workflow)'
|
||||
new_app.mode = AppMode.ADVANCED_CHAT.value \
|
||||
if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
|
||||
new_app.name = name if name else app_model.name + "(workflow)"
|
||||
new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
|
||||
new_app.icon_type = icon_type if icon_type else app_model.icon_type
|
||||
new_app.icon = icon if icon else app_model.icon
|
||||
new_app.icon_background = icon_background if icon_background else app_model.icon_background
|
||||
@ -88,30 +85,21 @@ class WorkflowConverter:
|
||||
|
||||
return new_app
|
||||
|
||||
def convert_app_model_config_to_workflow(self, app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
account_id: str) -> Workflow:
|
||||
def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: AppModelConfig, account_id: str):
|
||||
"""
|
||||
Convert app model config to workflow mode
|
||||
:param app_model: App instance
|
||||
:param app_model_config: AppModelConfig instance
|
||||
:param account_id: Account ID
|
||||
:return:
|
||||
"""
|
||||
# get new app mode
|
||||
new_app_mode = self._get_new_app_mode(app_model)
|
||||
|
||||
# convert app model config
|
||||
app_config = self._convert_to_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config)
|
||||
|
||||
# init workflow graph
|
||||
graph = {
|
||||
"nodes": [],
|
||||
"edges": []
|
||||
}
|
||||
graph = {"nodes": [], "edges": []}
|
||||
|
||||
# Convert list:
|
||||
# - variables -> start
|
||||
@ -123,11 +111,9 @@ class WorkflowConverter:
|
||||
# - show_retrieve_source -> knowledge-retrieval
|
||||
|
||||
# convert to start node
|
||||
start_node = self._convert_to_start_node(
|
||||
variables=app_config.variables
|
||||
)
|
||||
start_node = self._convert_to_start_node(variables=app_config.variables)
|
||||
|
||||
graph['nodes'].append(start_node)
|
||||
graph["nodes"].append(start_node)
|
||||
|
||||
# convert to http request node
|
||||
external_data_variable_node_mapping = {}
|
||||
@ -135,7 +121,7 @@ class WorkflowConverter:
|
||||
http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node(
|
||||
app_model=app_model,
|
||||
variables=app_config.variables,
|
||||
external_data_variables=app_config.external_data_variables
|
||||
external_data_variables=app_config.external_data_variables,
|
||||
)
|
||||
|
||||
for http_request_node in http_request_nodes:
|
||||
@ -144,9 +130,7 @@ class WorkflowConverter:
|
||||
# convert to knowledge retrieval node
|
||||
if app_config.dataset:
|
||||
knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node(
|
||||
new_app_mode=new_app_mode,
|
||||
dataset_config=app_config.dataset,
|
||||
model_config=app_config.model
|
||||
new_app_mode=new_app_mode, dataset_config=app_config.dataset, model_config=app_config.model
|
||||
)
|
||||
|
||||
if knowledge_retrieval_node:
|
||||
@ -160,7 +144,7 @@ class WorkflowConverter:
|
||||
model_config=app_config.model,
|
||||
prompt_template=app_config.prompt_template,
|
||||
file_upload=app_config.additional_features.file_upload,
|
||||
external_data_variable_node_mapping=external_data_variable_node_mapping
|
||||
external_data_variable_node_mapping=external_data_variable_node_mapping,
|
||||
)
|
||||
|
||||
graph = self._append_node(graph, llm_node)
|
||||
@ -199,7 +183,7 @@ class WorkflowConverter:
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type=WorkflowType.from_app_mode(new_app_mode).value,
|
||||
version='draft',
|
||||
version="draft",
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps(features),
|
||||
created_by=account_id,
|
||||
@ -212,24 +196,18 @@ class WorkflowConverter:
|
||||
|
||||
return workflow
|
||||
|
||||
def _convert_to_app_config(self, app_model: App,
|
||||
app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
|
||||
def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
app_model.mode = AppMode.AGENT_CHAT.value
|
||||
app_config = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config
|
||||
app_model=app_model, app_model_config=app_model_config
|
||||
)
|
||||
elif app_mode == AppMode.CHAT:
|
||||
app_config = ChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
|
||||
elif app_mode == AppMode.COMPLETION:
|
||||
app_config = CompletionAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config
|
||||
app_model=app_model, app_model_config=app_model_config
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
@ -248,14 +226,13 @@ class WorkflowConverter:
|
||||
"data": {
|
||||
"title": "START",
|
||||
"type": NodeType.START.value,
|
||||
"variables": [jsonable_encoder(v) for v in variables]
|
||||
}
|
||||
"variables": [jsonable_encoder(v) for v in variables],
|
||||
},
|
||||
}
|
||||
|
||||
def _convert_to_http_request_node(self, app_model: App,
|
||||
variables: list[VariableEntity],
|
||||
external_data_variables: list[ExternalDataVariableEntity]) \
|
||||
-> tuple[list[dict], dict[str, str]]:
|
||||
def _convert_to_http_request_node(
|
||||
self, app_model: App, variables: list[VariableEntity], external_data_variables: list[ExternalDataVariableEntity]
|
||||
) -> tuple[list[dict], dict[str, str]]:
|
||||
"""
|
||||
Convert API Based Extension to HTTP Request Node
|
||||
:param app_model: App instance
|
||||
@ -277,40 +254,33 @@ class WorkflowConverter:
|
||||
|
||||
# get params from config
|
||||
api_based_extension_id = tool_config.get("api_based_extension_id")
|
||||
if not api_based_extension_id:
|
||||
continue
|
||||
|
||||
# get api_based_extension
|
||||
api_based_extension = self._get_api_based_extension(
|
||||
tenant_id=tenant_id,
|
||||
api_based_extension_id=api_based_extension_id
|
||||
tenant_id=tenant_id, api_based_extension_id=api_based_extension_id
|
||||
)
|
||||
|
||||
if not api_based_extension:
|
||||
raise ValueError("[External data tool] API query failed, variable: {}, "
|
||||
"error: api_based_extension_id is invalid"
|
||||
.format(tool_variable))
|
||||
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=tenant_id,
|
||||
token=api_based_extension.api_key
|
||||
)
|
||||
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=api_based_extension.api_key)
|
||||
|
||||
inputs = {}
|
||||
for v in variables:
|
||||
inputs[v.variable] = '{{#start.' + v.variable + '#}}'
|
||||
inputs[v.variable] = "{{#start." + v.variable + "#}}"
|
||||
|
||||
request_body = {
|
||||
'point': APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value,
|
||||
'params': {
|
||||
'app_id': app_model.id,
|
||||
'tool_variable': tool_variable,
|
||||
'inputs': inputs,
|
||||
'query': '{{#sys.query#}}' if app_model.mode == AppMode.CHAT.value else ''
|
||||
}
|
||||
"point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value,
|
||||
"params": {
|
||||
"app_id": app_model.id,
|
||||
"tool_variable": tool_variable,
|
||||
"inputs": inputs,
|
||||
"query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "",
|
||||
},
|
||||
}
|
||||
|
||||
request_body_json = json.dumps(request_body)
|
||||
request_body_json = request_body_json.replace(r'\{\{', '{{').replace(r'\}\}', '}}')
|
||||
request_body_json = request_body_json.replace(r"\{\{", "{{").replace(r"\}\}", "}}")
|
||||
|
||||
http_request_node = {
|
||||
"id": f"http_request_{index}",
|
||||
@ -320,20 +290,11 @@ class WorkflowConverter:
|
||||
"type": NodeType.HTTP_REQUEST.value,
|
||||
"method": "post",
|
||||
"url": api_based_extension.api_endpoint,
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "bearer",
|
||||
"api_key": api_key
|
||||
}
|
||||
},
|
||||
"authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}},
|
||||
"headers": "",
|
||||
"params": "",
|
||||
"body": {
|
||||
"type": "json",
|
||||
"data": request_body_json
|
||||
}
|
||||
}
|
||||
"body": {"type": "json", "data": request_body_json},
|
||||
},
|
||||
}
|
||||
|
||||
nodes.append(http_request_node)
|
||||
@ -345,32 +306,24 @@ class WorkflowConverter:
|
||||
"data": {
|
||||
"title": f"Parse {api_based_extension.name} Response",
|
||||
"type": NodeType.CODE.value,
|
||||
"variables": [{
|
||||
"variable": "response_json",
|
||||
"value_selector": [http_request_node['id'], "body"]
|
||||
}],
|
||||
"variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}],
|
||||
"code_language": "python3",
|
||||
"code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads("
|
||||
"response_json)\n return {\n \"result\": response_body[\"result\"]\n }",
|
||||
"outputs": {
|
||||
"result": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
'response_json)\n return {\n "result": response_body["result"]\n }',
|
||||
"outputs": {"result": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
|
||||
nodes.append(code_node)
|
||||
|
||||
external_data_variable_node_mapping[external_data_variable.variable] = code_node['id']
|
||||
external_data_variable_node_mapping[external_data_variable.variable] = code_node["id"]
|
||||
index += 1
|
||||
|
||||
return nodes, external_data_variable_node_mapping
|
||||
|
||||
def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode,
|
||||
dataset_config: DatasetEntity,
|
||||
model_config: ModelConfigEntity) \
|
||||
-> Optional[dict]:
|
||||
def _convert_to_knowledge_retrieval_node(
|
||||
self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Convert datasets to Knowledge Retrieval Node
|
||||
:param new_app_mode: new app mode
|
||||
@ -404,7 +357,7 @@ class WorkflowConverter:
|
||||
"completion_params": {
|
||||
**model_config.parameters,
|
||||
"stop": model_config.stop,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
|
||||
@ -412,20 +365,23 @@ class WorkflowConverter:
|
||||
"multiple_retrieval_config": {
|
||||
"top_k": retrieve_config.top_k,
|
||||
"score_threshold": retrieve_config.score_threshold,
|
||||
"reranking_model": retrieve_config.reranking_model
|
||||
"reranking_model": retrieve_config.reranking_model,
|
||||
}
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE
|
||||
else None,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def _convert_to_llm_node(self, original_app_mode: AppMode,
|
||||
new_app_mode: AppMode,
|
||||
graph: dict,
|
||||
model_config: ModelConfigEntity,
|
||||
prompt_template: PromptTemplateEntity,
|
||||
file_upload: Optional[FileExtraConfig] = None,
|
||||
external_data_variable_node_mapping: dict[str, str] = None) -> dict:
|
||||
def _convert_to_llm_node(
|
||||
self,
|
||||
original_app_mode: AppMode,
|
||||
new_app_mode: AppMode,
|
||||
graph: dict,
|
||||
model_config: ModelConfigEntity,
|
||||
prompt_template: PromptTemplateEntity,
|
||||
file_upload: Optional[FileExtraConfig] = None,
|
||||
external_data_variable_node_mapping: dict[str, str] | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert to LLM Node
|
||||
:param original_app_mode: original app mode
|
||||
@ -437,17 +393,18 @@ class WorkflowConverter:
|
||||
:param external_data_variable_node_mapping: external data variable node mapping
|
||||
"""
|
||||
# fetch start and knowledge retrieval node
|
||||
start_node = next(filter(lambda n: n['data']['type'] == NodeType.START.value, graph['nodes']))
|
||||
knowledge_retrieval_node = next(filter(
|
||||
lambda n: n['data']['type'] == NodeType.KNOWLEDGE_RETRIEVAL.value,
|
||||
graph['nodes']
|
||||
), None)
|
||||
start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START.value, graph["nodes"]))
|
||||
knowledge_retrieval_node = next(
|
||||
filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL.value, graph["nodes"]), None
|
||||
)
|
||||
|
||||
role_prefix = None
|
||||
|
||||
# Chat Model
|
||||
if model_config.mode == LLMMode.CHAT.value:
|
||||
if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
if not prompt_template.simple_prompt_template:
|
||||
raise ValueError("Simple prompt template is required")
|
||||
# get prompt template
|
||||
prompt_transform = SimplePromptTransform()
|
||||
prompt_template_config = prompt_transform.get_prompt_template(
|
||||
@ -456,45 +413,35 @@ class WorkflowConverter:
|
||||
model=model_config.model,
|
||||
pre_prompt=prompt_template.simple_prompt_template,
|
||||
has_context=knowledge_retrieval_node is not None,
|
||||
query_in_prompt=False
|
||||
query_in_prompt=False,
|
||||
)
|
||||
|
||||
template = prompt_template_config['prompt_template'].template
|
||||
template = prompt_template_config["prompt_template"].template
|
||||
if not template:
|
||||
prompts = []
|
||||
else:
|
||||
template = self._replace_template_variables(
|
||||
template,
|
||||
start_node['data']['variables'],
|
||||
external_data_variable_node_mapping
|
||||
template, start_node["data"]["variables"], external_data_variable_node_mapping
|
||||
)
|
||||
|
||||
prompts = [
|
||||
{
|
||||
"role": 'user',
|
||||
"text": template
|
||||
}
|
||||
]
|
||||
prompts = [{"role": "user", "text": template}]
|
||||
else:
|
||||
advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template
|
||||
|
||||
prompts = []
|
||||
for m in advanced_chat_prompt_template.messages:
|
||||
if advanced_chat_prompt_template:
|
||||
if advanced_chat_prompt_template:
|
||||
for m in advanced_chat_prompt_template.messages:
|
||||
text = m.text
|
||||
text = self._replace_template_variables(
|
||||
text,
|
||||
start_node['data']['variables'],
|
||||
external_data_variable_node_mapping
|
||||
text, start_node["data"]["variables"], external_data_variable_node_mapping
|
||||
)
|
||||
|
||||
prompts.append({
|
||||
"role": m.role.value,
|
||||
"text": text
|
||||
})
|
||||
prompts.append({"role": m.role.value, "text": text})
|
||||
# Completion Model
|
||||
else:
|
||||
if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
if not prompt_template.simple_prompt_template:
|
||||
raise ValueError("Simple prompt template is required")
|
||||
# get prompt template
|
||||
prompt_transform = SimplePromptTransform()
|
||||
prompt_template_config = prompt_transform.get_prompt_template(
|
||||
@ -503,57 +450,50 @@ class WorkflowConverter:
|
||||
model=model_config.model,
|
||||
pre_prompt=prompt_template.simple_prompt_template,
|
||||
has_context=knowledge_retrieval_node is not None,
|
||||
query_in_prompt=False
|
||||
query_in_prompt=False,
|
||||
)
|
||||
|
||||
template = prompt_template_config['prompt_template'].template
|
||||
template = prompt_template_config["prompt_template"].template
|
||||
template = self._replace_template_variables(
|
||||
template,
|
||||
start_node['data']['variables'],
|
||||
external_data_variable_node_mapping
|
||||
template=template,
|
||||
variables=start_node["data"]["variables"],
|
||||
external_data_variable_node_mapping=external_data_variable_node_mapping,
|
||||
)
|
||||
|
||||
prompts = {
|
||||
"text": template
|
||||
}
|
||||
prompts = {"text": template}
|
||||
|
||||
prompt_rules = prompt_template_config['prompt_rules']
|
||||
prompt_rules = prompt_template_config["prompt_rules"]
|
||||
role_prefix = {
|
||||
"user": prompt_rules.get('human_prefix', 'Human'),
|
||||
"assistant": prompt_rules.get('assistant_prefix', 'Assistant')
|
||||
"user": prompt_rules.get("human_prefix", "Human"),
|
||||
"assistant": prompt_rules.get("assistant_prefix", "Assistant"),
|
||||
}
|
||||
else:
|
||||
advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template
|
||||
if advanced_completion_prompt_template:
|
||||
text = advanced_completion_prompt_template.prompt
|
||||
text = self._replace_template_variables(
|
||||
text,
|
||||
start_node['data']['variables'],
|
||||
external_data_variable_node_mapping
|
||||
template=text,
|
||||
variables=start_node["data"]["variables"],
|
||||
external_data_variable_node_mapping=external_data_variable_node_mapping,
|
||||
)
|
||||
else:
|
||||
text = ""
|
||||
|
||||
text = text.replace('{{#query#}}', '{{#sys.query#}}')
|
||||
text = text.replace("{{#query#}}", "{{#sys.query#}}")
|
||||
|
||||
prompts = {
|
||||
"text": text,
|
||||
}
|
||||
|
||||
if advanced_completion_prompt_template.role_prefix:
|
||||
if advanced_completion_prompt_template and advanced_completion_prompt_template.role_prefix:
|
||||
role_prefix = {
|
||||
"user": advanced_completion_prompt_template.role_prefix.user,
|
||||
"assistant": advanced_completion_prompt_template.role_prefix.assistant
|
||||
"assistant": advanced_completion_prompt_template.role_prefix.assistant,
|
||||
}
|
||||
|
||||
memory = None
|
||||
if new_app_mode == AppMode.ADVANCED_CHAT:
|
||||
memory = {
|
||||
"role_prefix": role_prefix,
|
||||
"window": {
|
||||
"enabled": False
|
||||
}
|
||||
}
|
||||
memory = {"role_prefix": role_prefix, "window": {"enabled": False}}
|
||||
|
||||
completion_params = model_config.parameters
|
||||
completion_params.update({"stop": model_config.stop})
|
||||
@ -567,28 +507,29 @@ class WorkflowConverter:
|
||||
"provider": model_config.provider,
|
||||
"name": model_config.model,
|
||||
"mode": model_config.mode,
|
||||
"completion_params": completion_params
|
||||
"completion_params": completion_params,
|
||||
},
|
||||
"prompt_template": prompts,
|
||||
"memory": memory,
|
||||
"context": {
|
||||
"enabled": knowledge_retrieval_node is not None,
|
||||
"variable_selector": ["knowledge_retrieval", "result"]
|
||||
if knowledge_retrieval_node is not None else None
|
||||
if knowledge_retrieval_node is not None
|
||||
else None,
|
||||
},
|
||||
"vision": {
|
||||
"enabled": file_upload is not None,
|
||||
"variable_selector": ["sys", "files"] if file_upload is not None else None,
|
||||
"configs": {
|
||||
"detail": file_upload.image_config['detail']
|
||||
} if file_upload is not None else None
|
||||
}
|
||||
}
|
||||
"configs": {"detail": file_upload.image_config["detail"]}
|
||||
if file_upload is not None and file_upload.image_config is not None
|
||||
else None,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def _replace_template_variables(self, template: str,
|
||||
variables: list[dict],
|
||||
external_data_variable_node_mapping: dict[str, str] = None) -> str:
|
||||
def _replace_template_variables(
|
||||
self, template: str, variables: list[dict], external_data_variable_node_mapping: dict[str, str] | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Replace Template Variables
|
||||
:param template: template
|
||||
@ -597,12 +538,11 @@ class WorkflowConverter:
|
||||
:return:
|
||||
"""
|
||||
for v in variables:
|
||||
template = template.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}')
|
||||
template = template.replace("{{" + v["variable"] + "}}", "{{#start." + v["variable"] + "#}}")
|
||||
|
||||
if external_data_variable_node_mapping:
|
||||
for variable, code_node_id in external_data_variable_node_mapping.items():
|
||||
template = template.replace('{{' + variable + '}}',
|
||||
'{{#' + code_node_id + '.result#}}')
|
||||
template = template.replace("{{" + variable + "}}", "{{#" + code_node_id + ".result#}}")
|
||||
|
||||
return template
|
||||
|
||||
@ -618,11 +558,8 @@ class WorkflowConverter:
|
||||
"data": {
|
||||
"title": "END",
|
||||
"type": NodeType.END.value,
|
||||
"outputs": [{
|
||||
"variable": "result",
|
||||
"value_selector": ["llm", "text"]
|
||||
}]
|
||||
}
|
||||
"outputs": [{"variable": "result", "value_selector": ["llm", "text"]}],
|
||||
},
|
||||
}
|
||||
|
||||
def _convert_to_answer_node(self) -> dict:
|
||||
@ -634,11 +571,7 @@ class WorkflowConverter:
|
||||
return {
|
||||
"id": "answer",
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": "ANSWER",
|
||||
"type": NodeType.ANSWER.value,
|
||||
"answer": "{{#llm.text#}}"
|
||||
}
|
||||
"data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"},
|
||||
}
|
||||
|
||||
def _create_edge(self, source: str, target: str) -> dict:
|
||||
@ -648,11 +581,7 @@ class WorkflowConverter:
|
||||
:param target: target node id
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
"id": f"{source}-{target}",
|
||||
"source": source,
|
||||
"target": target
|
||||
}
|
||||
return {"id": f"{source}-{target}", "source": source, "target": target}
|
||||
|
||||
def _append_node(self, graph: dict, node: dict) -> dict:
|
||||
"""
|
||||
@ -662,9 +591,9 @@ class WorkflowConverter:
|
||||
:param node: Node to append
|
||||
:return:
|
||||
"""
|
||||
previous_node = graph['nodes'][-1]
|
||||
graph['nodes'].append(node)
|
||||
graph['edges'].append(self._create_edge(previous_node['id'], node['id']))
|
||||
previous_node = graph["nodes"][-1]
|
||||
graph["nodes"].append(node)
|
||||
graph["edges"].append(self._create_edge(previous_node["id"], node["id"]))
|
||||
return graph
|
||||
|
||||
def _get_new_app_mode(self, app_model: App) -> AppMode:
|
||||
@ -678,14 +607,20 @@ class WorkflowConverter:
|
||||
else:
|
||||
return AppMode.ADVANCED_CHAT
|
||||
|
||||
def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
|
||||
def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str):
|
||||
"""
|
||||
Get API Based Extension
|
||||
:param tenant_id: tenant id
|
||||
:param api_based_extension_id: api based extension id
|
||||
:return:
|
||||
"""
|
||||
return db.session.query(APIBasedExtension).filter(
|
||||
APIBasedExtension.tenant_id == tenant_id,
|
||||
APIBasedExtension.id == api_based_extension_id
|
||||
).first()
|
||||
api_based_extension = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not api_based_extension:
|
||||
raise ValueError(f"API Based Extension not found, id: {api_based_extension_id}")
|
||||
|
||||
return api_based_extension
|
||||
|
||||
@ -5,7 +5,7 @@ from core.model_runtime.entities.text_embedding_entities import TextEmbeddingRes
|
||||
from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import WenxinTextEmbeddingModel
|
||||
|
||||
|
||||
def test_invoke_embedding_model():
|
||||
def test_invoke_embedding_v1():
|
||||
sleep(3)
|
||||
model = WenxinTextEmbeddingModel()
|
||||
|
||||
@ -21,4 +21,61 @@ def test_invoke_embedding_model():
|
||||
|
||||
assert isinstance(response, TextEmbeddingResult)
|
||||
assert len(response.embeddings) == 3
|
||||
assert isinstance(response.embeddings[0], list)
|
||||
assert isinstance(response.embeddings[0], list)
|
||||
|
||||
|
||||
def test_invoke_embedding_bge_large_en():
|
||||
sleep(3)
|
||||
model = WenxinTextEmbeddingModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='bge-large-en',
|
||||
credentials={
|
||||
'api_key': os.environ.get('WENXIN_API_KEY'),
|
||||
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
|
||||
},
|
||||
texts=['hello', '你好', 'xxxxx'],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, TextEmbeddingResult)
|
||||
assert len(response.embeddings) == 3
|
||||
assert isinstance(response.embeddings[0], list)
|
||||
|
||||
|
||||
def test_invoke_embedding_bge_large_zh():
|
||||
sleep(3)
|
||||
model = WenxinTextEmbeddingModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='bge-large-zh',
|
||||
credentials={
|
||||
'api_key': os.environ.get('WENXIN_API_KEY'),
|
||||
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
|
||||
},
|
||||
texts=['hello', '你好', 'xxxxx'],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, TextEmbeddingResult)
|
||||
assert len(response.embeddings) == 3
|
||||
assert isinstance(response.embeddings[0], list)
|
||||
|
||||
|
||||
def test_invoke_embedding_tao_8k():
|
||||
sleep(3)
|
||||
model = WenxinTextEmbeddingModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='tao-8k',
|
||||
credentials={
|
||||
'api_key': os.environ.get('WENXIN_API_KEY'),
|
||||
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
|
||||
},
|
||||
texts=['hello', '你好', 'xxxxx'],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, TextEmbeddingResult)
|
||||
assert len(response.embeddings) == 3
|
||||
assert isinstance(response.embeddings[0], list)
|
||||
|
||||
@ -21,10 +21,6 @@ class PGVectorTest(AbstractVectorTest):
|
||||
),
|
||||
)
|
||||
|
||||
def search_by_full_text(self):
|
||||
hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
|
||||
assert len(hits_by_full_text) == 0
|
||||
|
||||
|
||||
def test_pgvector(setup_mock_redis):
|
||||
PGVectorTest().run_all_tests()
|
||||
|
||||
@ -6,14 +6,13 @@ from _pytest.monkeypatch import MonkeyPatch
|
||||
from jinja2 import Template
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.entities import CodeDependency
|
||||
|
||||
MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
|
||||
|
||||
class MockedCodeExecutor:
|
||||
@classmethod
|
||||
def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'],
|
||||
code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
|
||||
code: str, inputs: dict) -> dict:
|
||||
# invoke directly
|
||||
match language:
|
||||
case CodeLanguage.PYTHON3:
|
||||
@ -24,6 +23,8 @@ class MockedCodeExecutor:
|
||||
return {
|
||||
"result": Template(code).render(inputs)
|
||||
}
|
||||
case _:
|
||||
raise Exception("Language not supported")
|
||||
|
||||
@pytest.fixture
|
||||
def setup_code_executor_mock(request, monkeypatch: MonkeyPatch):
|
||||
|
||||
@ -28,14 +28,6 @@ def test_javascript_with_code_template():
|
||||
inputs={'arg1': 'Hello', 'arg2': 'World'})
|
||||
assert result == {'result': 'HelloWorld'}
|
||||
|
||||
|
||||
def test_javascript_list_default_available_packages():
|
||||
packages = JavascriptCodeProvider.get_default_available_packages()
|
||||
|
||||
# no default packages available for javascript
|
||||
assert len(packages) == 0
|
||||
|
||||
|
||||
def test_javascript_get_runner_script():
|
||||
runner_script = NodeJsTemplateTransformer.get_runner_script()
|
||||
assert runner_script.count(NodeJsTemplateTransformer._code_placeholder) == 1
|
||||
|
||||
@ -29,15 +29,6 @@ def test_python3_with_code_template():
|
||||
assert result == {'result': 'HelloWorld'}
|
||||
|
||||
|
||||
def test_python3_list_default_available_packages():
|
||||
packages = Python3CodeProvider.get_default_available_packages()
|
||||
assert len(packages) > 0
|
||||
assert {'requests', 'httpx'}.issubset(p['name'] for p in packages)
|
||||
|
||||
# check JSON serializable
|
||||
assert len(str(json.dumps(packages))) > 0
|
||||
|
||||
|
||||
def test_python3_get_runner_script():
|
||||
runner_script = Python3TemplateTransformer.get_runner_script()
|
||||
assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1
|
||||
|
||||
@ -11,7 +11,7 @@ from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers import ModelProviderFactory
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from extensions.ext_database import db
|
||||
@ -66,10 +66,10 @@ def test_execute_llm(setup_openai_mock):
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.QUERY: 'what\'s the weather today?',
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.QUERY: 'what\'s the weather today?',
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['abc', 'output'], 'sunny')
|
||||
|
||||
@ -181,10 +181,10 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.QUERY: 'what\'s the weather today?',
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.QUERY: 'what\'s the weather today?',
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['abc', 'output'], 'sunny')
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from extensions.ext_database import db
|
||||
@ -119,10 +119,10 @@ def test_function_calling_parameter_extractor(setup_openai_mock):
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
|
||||
result = node.run(pool)
|
||||
@ -177,10 +177,10 @@ def test_instructions(setup_openai_mock):
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
|
||||
result = node.run(pool)
|
||||
@ -243,10 +243,10 @@ def test_chat_parameter_extractor(setup_anthropic_mock):
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
|
||||
result = node.run(pool)
|
||||
@ -307,10 +307,10 @@ def test_completion_parameter_extractor(setup_openai_mock):
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
|
||||
result = node.run(pool)
|
||||
@ -420,10 +420,10 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
|
||||
result = node.run(pool)
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
from core.app.segments import SecretVariable, StringSegment, parser
|
||||
from core.helper import encrypter
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
|
||||
|
||||
def test_segment_group_to_text():
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariable('user_id'): 'fake-user-id',
|
||||
SystemVariableKey('user_id'): 'fake-user-id',
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[
|
||||
@ -42,7 +42,7 @@ def test_convert_constant_to_segment_group():
|
||||
def test_convert_variable_to_segment_group():
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariable('user_id'): 'fake-user-id',
|
||||
SystemVariableKey('user_id'): 'fake-user-id',
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
|
||||
@ -2,7 +2,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from extensions.ext_database import db
|
||||
@ -29,8 +29,8 @@ def test_execute_answer():
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['start', 'weather'], 'sunny')
|
||||
pool.add(['llm', 'text'], 'You are a helpful AI.')
|
||||
|
||||
@ -2,7 +2,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from extensions.ext_database import db
|
||||
@ -119,8 +119,8 @@ def test_execute_if_else_result_true():
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['start', 'array_contains'], ['ab', 'def'])
|
||||
pool.add(['start', 'array_not_contains'], ['ac', 'def'])
|
||||
@ -182,8 +182,8 @@ def test_execute_if_else_result_false():
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['start', 'array_contains'], ['1ab', 'def'])
|
||||
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
|
||||
|
||||
@ -4,7 +4,7 @@ from uuid import uuid4
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.segments import ArrayStringVariable, StringVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
|
||||
|
||||
@ -42,7 +42,7 @@ def test_overwrite_string_variable():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
@ -52,7 +52,7 @@ def test_overwrite_string_variable():
|
||||
input_variable,
|
||||
)
|
||||
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
|
||||
node.run(variable_pool)
|
||||
mock_run.assert_called_once()
|
||||
|
||||
@ -93,7 +93,7 @@ def test_append_variable_to_array():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
@ -103,7 +103,7 @@ def test_append_variable_to_array():
|
||||
input_variable,
|
||||
)
|
||||
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
|
||||
node.run(variable_pool)
|
||||
mock_run.assert_called_once()
|
||||
|
||||
@ -137,7 +137,7 @@ def test_clear_array():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
|
||||
@ -11,7 +11,17 @@ def test_environment_variables():
|
||||
contexts.tenant_id.set('tenant_id')
|
||||
|
||||
# Create a Workflow instance
|
||||
workflow = Workflow()
|
||||
workflow = Workflow(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
type='workflow',
|
||||
version='draft',
|
||||
graph='{}',
|
||||
features='{}',
|
||||
created_by='account_id',
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# Create some EnvironmentVariable instances
|
||||
variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())})
|
||||
@ -35,7 +45,17 @@ def test_update_environment_variables():
|
||||
contexts.tenant_id.set('tenant_id')
|
||||
|
||||
# Create a Workflow instance
|
||||
workflow = Workflow()
|
||||
workflow = Workflow(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
type='workflow',
|
||||
version='draft',
|
||||
graph='{}',
|
||||
features='{}',
|
||||
created_by='account_id',
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# Create some EnvironmentVariable instances
|
||||
variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())})
|
||||
@ -70,9 +90,17 @@ def test_to_dict():
|
||||
contexts.tenant_id.set('tenant_id')
|
||||
|
||||
# Create a Workflow instance
|
||||
workflow = Workflow()
|
||||
workflow.graph = '{}'
|
||||
workflow.features = '{}'
|
||||
workflow = Workflow(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
type='workflow',
|
||||
version='draft',
|
||||
graph='{}',
|
||||
features='{}',
|
||||
created_by='account_id',
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# Create some EnvironmentVariable instances
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user