Compare commits

..

2 Commits

SHA1 Message Date
5f7771bc47 fix: iteration node use the main thread pool 2024-12-02 21:13:47 +08:00
286741e139 fix: iteration node use the main thread pool 2024-12-02 21:13:39 +08:00
261 changed files with 2390 additions and 5992 deletions

View File

@@ -8,9 +8,16 @@ Please include a summary of the change and which issue is fixed. Please also inc
 # Screenshots
-| Before | After |
-|--------|-------|
-| ... | ... |
+<table>
+<tr>
+<td>Before: </td>
+<td>After: </td>
+</tr>
+<tr>
+<td>...</td>
+<td>...</td>
+</tr>
+</table>
 # Checklist

View File

@@ -413,3 +413,4 @@ RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
 CREATE_TIDB_SERVICE_JOB_ENABLED=false
+RETRIEVAL_TOP_N=0

View File

@@ -20,8 +20,6 @@ select = [
     "PLC0208",  # iteration-over-set
     "PLC2801",  # unnecessary-dunder-call
     "PLC0414",  # useless-import-alias
-    "PLE0604",  # invalid-all-object
-    "PLE0605",  # invalid-all-format
    "PLR0402",  # manual-from-import
     "PLR1711",  # useless-return
     "PLR1714",  # repeated-equality-comparison
@@ -30,7 +28,6 @@ select = [
     "RUF100",  # unused-noqa
     "RUF101",  # redirected-noqa
     "RUF200",  # invalid-pyproject-toml
-    "RUF022",  # unsorted-dunder-all
     "S506",  # unsafe-yaml-load
     "SIM",  # flake8-simplify rules
     "TRY400",  # error-instead-of-exception

View File

@@ -259,7 +259,7 @@ def migrate_knowledge_vector_database():
     skipped_count = 0
     total_count = 0
     vector_type = dify_config.VECTOR_STORE
-    upper_collection_vector_types = {
+    upper_colletion_vector_types = {
         VectorType.MILVUS,
         VectorType.PGVECTOR,
         VectorType.RELYT,
@@ -267,7 +267,7 @@ def migrate_knowledge_vector_database():
         VectorType.ORACLE,
         VectorType.ELASTICSEARCH,
     }
-    lower_collection_vector_types = {
+    lower_colletion_vector_types = {
        VectorType.ANALYTICDB,
         VectorType.CHROMA,
         VectorType.MYSCALE,
@@ -307,7 +307,7 @@ def migrate_knowledge_vector_database():
                 continue
             collection_name = ""
             dataset_id = dataset.id
-            if vector_type in upper_collection_vector_types:
+            if vector_type in upper_colletion_vector_types:
                 collection_name = Dataset.gen_collection_name_by_id(dataset_id)
             elif vector_type == VectorType.QDRANT:
                 if dataset.collection_binding_id:
@@ -323,7 +323,7 @@ def migrate_knowledge_vector_database():
                 else:
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-            elif vector_type in lower_collection_vector_types:
+            elif vector_type in lower_colletion_vector_types:
                 collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
             else:
                 raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@@ -626,6 +626,8 @@ class DataSetConfig(BaseSettings):
         default=30,
     )
+    RETRIEVAL_TOP_N: int = Field(description="number of retrieval top_n", default=0)
+
 class WorkspaceConfig(BaseSettings):
     """

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.13.2",
+        default="0.12.1",
     )
     COMMIT_SHA: str = Field(

View File

@@ -100,11 +100,11 @@ class DraftWorkflowApi(Resource):
         try:
             environment_variables_list = args.get("environment_variables") or []
             environment_variables = [
-                variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
+                variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list
             ]
             conversation_variables_list = args.get("conversation_variables") or []
             conversation_variables = [
-                variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
+                variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list
             ]
             workflow = workflow_service.sync_draft_workflow(
                 app_model=app_model,
@@ -382,7 +382,7 @@ class DefaultBlockConfigApi(Resource):
         filters = None
         if args.get("q"):
             try:
-                filters = json.loads(args.get("q", ""))
+                filters = json.loads(args.get("q"))
             except json.JSONDecodeError:
                 raise ValueError("Invalid filters")

View File

@@ -1,6 +1,5 @@
 from datetime import UTC, datetime
-from flask import request
 from flask_login import current_user
 from flask_restful import Resource, inputs, marshal_with, reqparse
 from sqlalchemy import and_
@@ -21,16 +20,7 @@ class InstalledAppsListApi(Resource):
     @account_initialization_required
     @marshal_with(installed_app_list_fields)
     def get(self):
-        app_id = request.args.get("app_id", default=None, type=str)
         current_tenant_id = current_user.current_tenant_id
-        if app_id:
-            installed_apps = (
-                db.session.query(InstalledApp)
-                .filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
-                .all()
-            )
-        else:
-            installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
+        installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
         current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)

View File

@@ -368,7 +368,6 @@ class ToolWorkflowProviderCreateApi(Resource):
                 description=args["description"],
                 parameters=args["parameters"],
                 privacy_policy=args["privacy_policy"],
-                labels=args["labels"],
             )

View File

@@ -2,7 +2,7 @@
 Due to the presence of tasks in App Runner that require long execution times, such as LLM generation and external requests, Flask-Sqlalchemy's strategy for database connection pooling is to allocate one connection (transaction) per request. This approach keeps a connection occupied even during non-DB tasks, leading to the inability to acquire new connections during high concurrency requests due to multiple long-running tasks.
-Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid detach errors.
+Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid deattach errors.
 Examples:
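Note: a minimal sketch of the pattern this document prescribes, using plain SQLAlchemy (the in-memory engine and the `time.sleep` stand-in for LLM generation are illustrative assumptions, not Dify's actual wiring):

```python
import time

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")  # stand-in for the app's configured database


def run_task(record_id: int) -> None:
    # Load plain values up front; keep the ID, not a live ORM object.
    with Session(engine) as session:
        value = session.execute(text("SELECT :rid"), {"rid": record_id}).scalar_one()
    # The session, and its pooled connection, is released here -- before the slow work.
    time.sleep(2)  # stand-in for LLM generation or an external request
    print(f"processed record {value}")


run_task(42)
```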

View File

@@ -82,7 +82,7 @@ class AppGenerateResponseConverter(ABC):
             for resource in metadata["retriever_resources"]:
                 updated_resources.append(
                     {
-                        "segment_id": resource.get("segment_id", ""),
+                        "segment_id": resource["segment_id"],
                         "position": resource["position"],
                         "document_name": resource["document_name"],
                         "score": resource["score"],

View File

@@ -43,7 +43,7 @@ from core.workflow.graph_engine.entities.event import (
 )
 from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.nodes import NodeType
-from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
+from core.workflow.nodes.node_mapping import node_type_classes_mapping
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from models.model import App
@@ -138,8 +138,7 @@ class WorkflowBasedAppRunner(AppRunner):
         # Get node class
         node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
-        node_version = iteration_node_config.get("data", {}).get("version", "1")
-        node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
+        node_cls = node_type_classes_mapping[node_type]
         # init variable pool
         variable_pool = VariablePool(

View File

@@ -2,7 +2,7 @@ from datetime import datetime
 from enum import Enum, StrEnum
 from typing import Any, Optional
-from pydantic import BaseModel
+from pydantic import BaseModel, field_validator
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from core.workflow.entities.node_entities import NodeRunMetadataKey
@@ -113,6 +113,18 @@ class QueueIterationNextEvent(AppQueueEvent):
     output: Optional[Any] = None  # output for the current iteration
     duration: Optional[float] = None
+    @field_validator("output", mode="before")
+    @classmethod
+    def set_output(cls, v):
+        """
+        Set output
+        """
+        if v is None:
+            return None
+        if isinstance(v, int | float | str | bool | dict | list):
+            return v
+        raise ValueError("output must be a valid type")
 class QueueIterationCompletedEvent(AppQueueEvent):
     """

View File

@@ -7,13 +7,13 @@ from .models import (
 )
 __all__ = [
-    "FILE_MODEL_IDENTITY",
-    "ArrayFileAttribute",
-    "File",
-    "FileAttribute",
-    "FileBelongsTo",
-    "FileTransferMethod",
     "FileType",
     "FileUploadConfig",
+    "FileTransferMethod",
+    "FileBelongsTo",
+    "File",
     "ImageConfig",
+    "FileAttribute",
+    "ArrayFileAttribute",
+    "FILE_MODEL_IDENTITY",
 ]

View File

@@ -91,7 +91,7 @@ class XinferenceProvider(Provider):
     """
 ```
-也可以直接抛出对应 Errors并做如下定义这样在之后的调用中可以直接抛出`InvokeConnectionError`等异常。
+也可以直接抛出对应Erros并做如下定义这样在之后的调用中可以直接抛出`InvokeConnectionError`等异常。
 ```python
 @property

View File

@@ -18,25 +18,25 @@ from .message_entities import (
 from .model_entities import ModelPropertyKey
 __all__ = [
-    "AssistantPromptMessage",
-    "AudioPromptMessageContent",
-    "DocumentPromptMessageContent",
     "ImagePromptMessageContent",
+    "VideoPromptMessageContent",
+    "PromptMessage",
+    "PromptMessageRole",
+    "LLMUsage",
+    "ModelPropertyKey",
+    "AssistantPromptMessage",
+    "PromptMessage",
+    "PromptMessageContent",
+    "PromptMessageRole",
+    "SystemPromptMessage",
+    "TextPromptMessageContent",
+    "UserPromptMessage",
+    "PromptMessageTool",
+    "ToolPromptMessage",
+    "PromptMessageContentType",
     "LLMResult",
     "LLMResultChunk",
     "LLMResultChunkDelta",
-    "LLMUsage",
-    "ModelPropertyKey",
-    "PromptMessage",
-    "PromptMessage",
-    "PromptMessageContent",
-    "PromptMessageContentType",
-    "PromptMessageRole",
-    "PromptMessageRole",
-    "PromptMessageTool",
-    "SystemPromptMessage",
-    "TextPromptMessageContent",
-    "ToolPromptMessage",
-    "UserPromptMessage",
-    "VideoPromptMessageContent",
+    "AudioPromptMessageContent",
+    "DocumentPromptMessageContent",
 ]

View File

@@ -16,7 +16,6 @@ help:
 supported_model_types:
 - llm
 - text-embedding
-- rerank
 configurate_methods:
 - predefined-model
 provider_credential_schema:

View File

@@ -1,53 +0,0 @@
model: amazon.nova-lite-v1:0
label:
en_US: Nova Lite V1
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 300000
parameter_rules:
- name: max_new_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 5000
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.00006'
output: '0.00024'
unit: '0.001'
currency: USD

View File

@@ -1,52 +0,0 @@
model: amazon.nova-micro-v1:0
label:
en_US: Nova Micro V1
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: max_new_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 5000
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.000035'
output: '0.00014'
unit: '0.001'
currency: USD

View File

@@ -1,53 +0,0 @@
model: amazon.nova-pro-v1:0
label:
en_US: Nova Pro V1
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 300000
parameter_rules:
- name: max_new_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 5000
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.0008'
output: '0.0032'
unit: '0.001'
currency: USD

View File

@@ -70,8 +70,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         {"prefix": "cohere.command-r", "support_system_prompts": True, "support_tool_use": True},
         {"prefix": "amazon.titan", "support_system_prompts": False, "support_tool_use": False},
         {"prefix": "ai21.jamba-1-5", "support_system_prompts": True, "support_tool_use": False},
-        {"prefix": "amazon.nova", "support_system_prompts": True, "support_tool_use": False},
-        {"prefix": "us.amazon.nova", "support_system_prompts": True, "support_tool_use": False},
     ]
     @staticmethod
@@ -196,13 +194,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         if model_info["support_tool_use"] and tools:
             parameters["toolConfig"] = self._convert_converse_tool_config(tools=tools)
         try:
-            # for issue #10976
-            conversations_list = parameters["messages"]
-            # if two consecutive user messages found, combine them into one message
-            for i in range(len(conversations_list) - 2, -1, -1):
-                if conversations_list[i]["role"] == conversations_list[i + 1]["role"]:
-                    conversations_list[i]["content"].extend(conversations_list.pop(i + 1)["content"])
             if stream:
                 response = bedrock_client.converse_stream(**parameters)
                 return self._handle_converse_stream_response(

View File

@@ -1,53 +0,0 @@
model: us.amazon.nova-lite-v1:0
label:
en_US: Nova Lite V1 (US.Cross Region Inference)
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 300000
parameter_rules:
- name: max_new_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 5000
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.00006'
output: '0.00024'
unit: '0.001'
currency: USD

View File

@@ -1,52 +0,0 @@
model: us.amazon.nova-micro-v1:0
label:
en_US: Nova Micro V1 (US.Cross Region Inference)
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: max_new_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 5000
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.000035'
output: '0.00014'
unit: '0.001'
currency: USD

View File

@@ -1,53 +0,0 @@
model: us.amazon.nova-pro-v1:0
label:
en_US: Nova Pro V1 (US.Cross Region Inference)
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 300000
parameter_rules:
- name: max_new_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 5000
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.0008'
output: '0.0032'
unit: '0.001'
currency: USD

View File

@@ -1,2 +0,0 @@
- amazon.rerank-v1
- cohere.rerank-v3-5

View File

@@ -1,4 +0,0 @@
model: amazon.rerank-v1:0
model_type: rerank
model_properties:
context_size: 5120

View File

@@ -1,4 +0,0 @@
model: cohere.rerank-v3-5:0
model_type: rerank
model_properties:
context_size: 5120

View File

@@ -1,147 +0,0 @@
from typing import Optional
import boto3
from botocore.config import Config
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
class BedrockRerankModel(RerankModel):
"""
Model class for Cohere rerank model.
"""
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=docs)
# initialize client
client_config = Config(region_name=credentials["aws_region"])
bedrock_runtime = boto3.client(
service_name="bedrock-agent-runtime",
config=client_config,
aws_access_key_id=credentials.get("aws_access_key_id", ""),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
)
queries = [{"type": "TEXT", "textQuery": {"text": query}}]
text_sources = []
for text in docs:
text_sources.append(
{
"type": "INLINE",
"inlineDocumentSource": {
"type": "TEXT",
"textDocument": {
"text": text,
},
},
}
)
modelId = model
region = credentials["aws_region"]
model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
rerankingConfiguration = {
"type": "BEDROCK_RERANKING_MODEL",
"bedrockRerankingConfiguration": {
"numberOfResults": top_n,
"modelConfiguration": {
"modelArn": model_package_arn,
},
},
}
response = bedrock_runtime.rerank(
queries=queries, sources=text_sources, rerankingConfiguration=rerankingConfiguration
)
rerank_documents = []
for idx, result in enumerate(response["results"]):
# format document
index = result["index"]
rerank_document = RerankDocument(
index=index,
text=docs[index],
score=result["relevanceScore"],
)
# score threshold check
if score_threshold is not None:
if rerank_document.score >= score_threshold:
rerank_documents.append(rerank_document)
else:
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self.invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke emd = genai.GenerativeModel(model) error mapping
"""
return {
InvokeConnectionError: [],
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: [],
}

View File

@@ -2,4 +2,3 @@
 - rerank-english-v3.0
 - rerank-multilingual-v2.0
 - rerank-multilingual-v3.0
-- rerank-v3.5

View File

@@ -1,4 +0,0 @@
model: rerank-v3.5
model_type: rerank
model_properties:
context_size: 5120

View File

@@ -1,38 +0,0 @@
model: gemini-exp-1206
label:
en_US: Gemini exp 1206
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 2097152
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@@ -252,7 +252,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
                 # ignore sse comments
                 if chunk.startswith(":"):
                     continue
-                decoded_chunk = chunk.strip().removeprefix("data: ")
+                decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
                 chunk_json = None
                 try:
                     chunk_json = json.loads(decoded_chunk)
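Note on the two calls above: `str.lstrip("data: ")` treats its argument as a set of characters to strip, not a literal prefix, so it can eat into the payload, while `removeprefix` (Python 3.9+) removes the exact prefix only. A quick illustration:

```python
chunk = "data: dataset loaded"

print(chunk.removeprefix("data: "))     # dataset loaded
print(chunk.lstrip("data: ").lstrip())  # set loaded -- the leading d/a/t/':'/space characters are all stripped
```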

View File

@@ -181,11 +181,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         # prepare the payload for a simple ping to the model
         data = {"model": model, "stream": stream}
-        if format_schema := model_parameters.pop("format", None):
-            try:
-                data["format"] = format_schema if format_schema == "json" else json.loads(format_schema)
-            except json.JSONDecodeError as e:
-                raise InvokeBadRequestError(f"Invalid format schema: {str(e)}")
+        if "format" in model_parameters:
+            data["format"] = model_parameters["format"]
+            del model_parameters["format"]
         if "keep_alive" in model_parameters:
             data["keep_alive"] = model_parameters["keep_alive"]
@@ -735,12 +733,12 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                 ParameterRule(
                     name="format",
                     label=I18nObject(en_US="Format", zh_Hans="返回格式"),
-                    type=ParameterType.TEXT,
+                    type=ParameterType.STRING,
+                    default="json",
                     help=I18nObject(
-                        en_US="the format to return a response in. Format can be `json` or a JSON schema.",
-                        zh_Hans="返回响应的格式。目前接受的值是字符串`json`或JSON schema.",
+                        en_US="the format to return a response in. Currently the only accepted value is json.",
+                        zh_Hans="返回响应的格式。目前唯一接受的值是json。",
                     ),
+                    options=["json"],
                 ),
             ],
             pricing=PriceConfig(

View File

@@ -462,7 +462,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                 # ignore sse comments
                 if chunk.startswith(":"):
                     continue
-                decoded_chunk = chunk.strip().removeprefix("data: ")
+                decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
                 if decoded_chunk == "[DONE]":  # Some provider returns "data: [DONE]"
                     continue

View File

@@ -250,7 +250,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel):
                 # ignore sse comments
                 if chunk.startswith(":"):
                     continue
-                decoded_chunk = chunk.strip().removeprefix("data: ")
+                decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
                 chunk_json = None
                 try:
                     chunk_json = json.loads(decoded_chunk)

View File

@@ -104,14 +104,13 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         """
         # use Anthropic official SDK references
         # - https://github.com/anthropics/anthropic-sdk-python
-        service_account_key = credentials.get("vertex_service_account_key", "")
+        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
         project_id = credentials["vertex_project_id"]
         SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
         token = ""
         # get access token from service account credential
-        if service_account_key:
-            service_account_info = json.loads(base64.b64decode(service_account_key))
+        if service_account_info:
             credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=SCOPES)
             request = google.auth.transport.requests.Request()
             credentials.refresh(request)
@@ -479,11 +478,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         if stop:
             config_kwargs["stop_sequences"] = stop
-        service_account_key = credentials.get("vertex_service_account_key", "")
+        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
         project_id = credentials["vertex_project_id"]
         location = credentials["vertex_location"]
-        if service_account_key:
-            service_account_info = json.loads(base64.b64decode(service_account_key))
+        if service_account_info:
             service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
             aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
         else:
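For context, both sides of this change decode a base64-encoded service-account JSON before building Google credentials; a standalone sketch of just that decode step (the payload is a made-up example, not a real key):

```python
import base64
import json

# Made-up service-account document, encoded the way the credential field expects.
raw = json.dumps({"type": "service_account", "project_id": "demo-project"}).encode()
encoded = base64.b64encode(raw).decode()

service_account_info = json.loads(base64.b64decode(encoded))
print(service_account_info["project_id"])  # demo-project
```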

View File

@@ -48,11 +48,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
         :param input_type: input type
         :return: embeddings result
         """
-        service_account_key = credentials.get("vertex_service_account_key", "")
+        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
         project_id = credentials["vertex_project_id"]
         location = credentials["vertex_location"]
-        if service_account_key:
-            service_account_info = json.loads(base64.b64decode(service_account_key))
+        if service_account_info:
             service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
             aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
         else:
@@ -101,11 +100,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
         :return:
         """
         try:
-            service_account_key = credentials.get("vertex_service_account_key", "")
+            service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
             project_id = credentials["vertex_project_id"]
             location = credentials["vertex_location"]
-            if service_account_key:
-                service_account_info = json.loads(base64.b64decode(service_account_key))
+            if service_account_info:
                 service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
                 aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
             else:

View File

@@ -1,4 +1,4 @@
 from .common import ChatRole
 from .maas import MaasError, MaasService
-__all__ = ["ChatRole", "MaasError", "MaasService"]
+__all__ = ["MaasService", "ChatRole", "MaasError"]

View File

@@ -17,13 +17,7 @@ class WenxinRerank(_CommonWenxin):
     def rerank(self, model: str, query: str, docs: list[str], top_n: Optional[int] = None):
         access_token = self._get_access_token()
         url = f"{self.api_bases[model]}?access_token={access_token}"
-        # For issue #11252
-        # for wenxin Rerank model top_n length should be equal or less than docs length
-        if top_n is not None and top_n > len(docs):
-            top_n = len(docs)
-        # for wenxin Rerank model, query should not be an empty string
-        if query == "":
-            query = " "  # FIXME: this is a workaround for wenxin rerank model for better user experience.
         try:
             response = httpx.post(
                 url,
@@ -31,11 +25,7 @@ class WenxinRerank(_CommonWenxin):
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            data = response.json()
-            # wenxin error handling
-            if "error_code" in data:
-                raise InternalServerError(data["error_msg"])
-            return data
+            return response.json()
         except httpx.HTTPStatusError as e:
             raise InternalServerError(str(e))
@@ -79,9 +69,6 @@ class WenxinRerankModel(RerankModel):
             results = wenxin_rerank.rerank(model, query, docs, top_n)
             rerank_documents = []
-            if "results" not in results:
-                raise ValueError("results key not found in response")
             for result in results["results"]:
                 index = result["index"]
                 if "document" in result:

View File

@@ -8,7 +8,6 @@ features:
 - stream-tool-call
 model_properties:
   mode: chat
-  context_size: 131072
 parameter_rules:
 - name: temperature
   use_template: temperature

View File

@@ -8,7 +8,6 @@ features:
 - stream-tool-call
 model_properties:
   mode: chat
-  context_size: 131072
 parameter_rules:
 - name: temperature
   use_template: temperature

View File

@@ -8,7 +8,6 @@ features:
 - stream-tool-call
 model_properties:
   mode: chat
-  context_size: 8192
 parameter_rules:
 - name: temperature
   use_template: temperature

View File

@@ -8,7 +8,6 @@ features:
 - stream-tool-call
 model_properties:
   mode: chat
-  context_size: 131072
 parameter_rules:
 - name: temperature
   use_template: temperature

View File

@@ -8,7 +8,6 @@ features:
 - stream-tool-call
 model_properties:
   mode: chat
-  context_size: 131072
 parameter_rules:
 - name: temperature
   use_template: temperature

View File

@@ -8,7 +8,6 @@ features:
 - stream-tool-call
 model_properties:
   mode: chat
-  context_size: 131072
 parameter_rules:
 - name: temperature
   use_template: temperature

View File

@@ -8,7 +8,6 @@ features:
 - stream-tool-call
 model_properties:
   mode: chat
-  context_size: 131072
 parameter_rules:
 - name: temperature
   use_template: temperature

View File

@@ -8,7 +8,7 @@ features:
 - stream-tool-call
 model_properties:
   mode: chat
-  context_size: 1048576
+  context_size: 10240
 parameter_rules:
 - name: temperature
   use_template: temperature

View File

@@ -8,7 +8,6 @@ features:
 - stream-tool-call
 model_properties:
   mode: chat
-  context_size: 131072
 parameter_rules:
 - name: temperature
   use_template: temperature

View File

@@ -4,7 +4,6 @@ label:
 model_type: llm
 model_properties:
   mode: chat
-  context_size: 2048
 features:
 - vision
 parameter_rules:

View File

@@ -1,52 +0,0 @@
model: glm-4v-flash
label:
en_US: glm-4v-flash
model_type: llm
model_properties:
mode: chat
context_size: 2048
features:
- vision
parameter_rules:
- name: temperature
use_template: temperature
default: 0.95
min: 0.0
max: 1.0
help:
zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
- name: top_p
use_template: top_p
default: 0.6
help:
zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
- name: do_sample
label:
zh_Hans: 采样策略
en_US: Sampling strategy
type: boolean
help:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 1024
- name: web_search
type: boolean
label:
zh_Hans: 联网搜索
en_US: Web Search
default: false
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: RMB

View File

@@ -4,7 +4,6 @@ label:
 model_type: llm
 model_properties:
   mode: chat
-  context_size: 8192
 features:
 - vision
 - video

View File

@@ -22,6 +22,18 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
 from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
 from core.model_runtime.utils import helper
+GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.
+The structure of the JSON object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+And you should always end the block with a "```" to indicate the end of the JSON object.
+
+<instructions>
+{{instructions}}
+</instructions>
+
+```JSON"""  # noqa: E501
+
 class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
     def _invoke(
@@ -52,8 +64,42 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         credentials_kwargs = self._to_credential_kwargs(credentials)
         # invoke model
+        # stop = stop or []
+        # self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
         return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)
+    # def _transform_json_prompts(self, model: str, credentials: dict,
+    #                             prompt_messages: list[PromptMessage], model_parameters: dict,
+    #                             tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+    #                             stream: bool = True, user: str | None = None) \
+    #         -> None:
+    #     """
+    #     Transform json prompts to model prompts
+    #     """
+    #     if "}\n\n" not in stop:
+    #         stop.append("}\n\n")
+    #     # check if there is a system message
+    #     if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+    #         # override the system message
+    #         prompt_messages[0] = SystemPromptMessage(
+    #             content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content)
+    #         )
+    #     else:
+    #         # insert the system message
+    #         prompt_messages.insert(0, SystemPromptMessage(
+    #             content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", "Please output a valid JSON object.")
+    #         ))
+    #     # check if the last message is a user message
+    #     if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
+    #         # add ```JSON\n to the last message
+    #         prompt_messages[-1].content += "\n```JSON\n"
+    #     else:
+    #         # append a user message
+    #         prompt_messages.append(UserPromptMessage(
+    #             content="```JSON\n"
+    #         ))
     def get_num_tokens(
         self,
         model: str,
@@ -124,7 +170,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         extra_model_kwargs = {}
-        # request to glm-4v-plus with stop words will always respond "finish_reason":"network_error"
+        # request to glm-4v-plus with stop words will always response "finish_reason":"network_error"
         if stop and model != "glm-4v-plus":
             extra_model_kwargs["stop"] = stop
@@ -140,11 +186,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         # resolve zhipuai model not support system message and user message, assistant message must be in sequence
         new_prompt_messages: list[PromptMessage] = []
         for prompt_message in prompt_messages:
-            copy_prompt_message = prompt_message.model_copy()
+            copy_prompt_message = prompt_message.copy()
             if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}:
                 if isinstance(copy_prompt_message.content, list):
                     # check if model is 'glm-4v'
-                    if not model.startswith("glm-4v"):
+                    if model not in {"glm-4v", "glm-4v-plus"}:
                         # not support list message
                         continue
                     # get image and
@@ -188,10 +234,12 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             else:
                 model_parameters["tools"] = [web_search_params]
-        if model.startswith("glm-4v"):
+        if model in {"glm-4v", "glm-4v-plus"}:
             params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
         else:
             params = {"model": model, "messages": [], **model_parameters}
+            # glm model
+            if not model.startswith("chatglm"):
             for prompt_message in new_prompt_messages:
                 if prompt_message.role == PromptMessageRole.TOOL:
                     params["messages"].append(
@@ -223,7 +271,26 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 else:
                     params["messages"].append({"role": "assistant", "content": prompt_message.content})
             else:
-                params["messages"].append({"role": prompt_message.role.value, "content": prompt_message.content})
+                    params["messages"].append(
+                        {"role": prompt_message.role.value, "content": prompt_message.content}
+                    )
+            else:
+                # chatglm model
+                for prompt_message in new_prompt_messages:
+                    # merge system message to user message
+                    if prompt_message.role in {
+                        PromptMessageRole.SYSTEM,
+                        PromptMessageRole.TOOL,
+                        PromptMessageRole.USER,
+                    }:
+                        if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user":
+                            params["messages"][-1]["content"] += "\n\n" + prompt_message.content
+                        else:
+                            params["messages"].append({"role": "user", "content": prompt_message.content})
+                    else:
+                        params["messages"].append(
+                            {"role": prompt_message.role.value, "content": prompt_message.content}
+                        )
         if tools and len(tools) > 0:
             params["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools]
@@ -339,7 +406,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         Handle llm stream response
         :param model: model name
-        :param responses: response
+        :param response: response
         :param prompt_messages: prompt messages
         :return: llm response chunk generator result
         """
@@ -412,8 +479,6 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         human_prompt = "\n\nHuman:"
         ai_prompt = "\n\nAssistant:"
         content = message.content
-        if isinstance(content, list):
-            content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
         if isinstance(message, UserPromptMessage):
             message_text = f"{human_prompt} {content}"
@@ -440,7 +505,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         if tools and len(tools) > 0:
             text += "\n\nTools:"
             for tool in tools:
-                text += f"\n{tool.model_dump_json()}"
+                text += f"\n{tool.json()}"
         # trim off the trailing ' ' that might come from the "Assistant: "
         return text.rstrip()

View File

@@ -5,7 +5,7 @@ BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找
 CHAT_APP_COMPLETION_PROMPT_CONFIG = {
     "completion_prompt_config": {
         "prompt": {
-            "text": "{{#pre_prompt#}}\nHere are the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: "  # noqa: E501
+            "text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: "  # noqa: E501
         },
         "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
     },

View File

@@ -3,6 +3,7 @@ from typing import Optional
 from flask import Flask, current_app
+from configs import DifyConfig
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
@@ -113,7 +114,7 @@ class RetrievalService:
                 query=query,
                 documents=all_documents,
                 score_threshold=score_threshold,
-                top_n=top_k,
+                top_n=DifyConfig.RETRIEVAL_TOP_N or top_k,
             )
         return all_documents
@@ -185,7 +186,7 @@ class RetrievalService:
                         query=query,
                         documents=documents,
                        score_threshold=score_threshold,
-                        top_n=len(documents),
+                        top_n=DifyConfig.RETRIEVAL_TOP_N or len(documents),
                     )
                 )
             else:
@@ -230,7 +231,7 @@ class RetrievalService:
                         query=query,
                        documents=documents,
                         score_threshold=score_threshold,
-                        top_n=len(documents),
+                        top_n=DifyConfig.RETRIEVAL_TOP_N or len(documents),
                    )
                 )
             else:
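A note on the `or` fallback introduced above: `RETRIEVAL_TOP_N` defaults to `0`, which is falsy, so the per-call value wins unless a positive override is configured. In isolation:

```python
def effective_top_n(configured: int, fallback: int) -> int:
    # 0 means "not configured": `or` then yields the fallback; any positive value wins.
    return configured or fallback


print(effective_top_n(0, 4))   # 4  (default config -> per-call value used)
print(effective_top_n(10, 4))  # 10 (explicit override takes precedence)
```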

View File

@@ -104,7 +104,8 @@ class OceanBaseVector(BaseVector):
                 val = int(row[6])
                 vals.append(val)
         if len(vals) == 0:
-            raise ValueError("ob_vector_memory_limit_percentage not found in parameters.")
+            print("ob_vector_memory_limit_percentage not found in parameters.")
+            exit(1)
         if any(val == 0 for val in vals):
             try:
                 self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30")
@@ -199,10 +200,10 @@ class OceanBaseVectorFactory(AbstractVectorFactory):
         return OceanBaseVector(
             collection_name,
             OceanBaseVectorConfig(
-                host=dify_config.OCEANBASE_VECTOR_HOST or "",
-                port=dify_config.OCEANBASE_VECTOR_PORT or 0,
-                user=dify_config.OCEANBASE_VECTOR_USER or "",
+                host=dify_config.OCEANBASE_VECTOR_HOST,
+                port=dify_config.OCEANBASE_VECTOR_PORT,
+                user=dify_config.OCEANBASE_VECTOR_USER,
                 password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""),
-                database=dify_config.OCEANBASE_VECTOR_DATABASE or "",
+                database=dify_config.OCEANBASE_VECTOR_DATABASE,
             ),
         )

View File

@@ -375,6 +375,7 @@ class TidbOnQdrantVector(BaseVector):
         for result in results:
             if result:
                 document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
+                document.metadata["vector"] = result.vector
                 documents.append(document)
         return documents
@@ -393,7 +394,6 @@ class TidbOnQdrantVector(BaseVector):
     ) -> Document:
         return Document(
             page_content=scored_point.payload.get(content_payload_key),
-            vector=scored_point.vector,
             metadata=scored_point.payload.get(metadata_payload_key) or {},
         )

View File

@@ -162,7 +162,7 @@ class TidbService:
         clusters = []
         tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
         cluster_ids = [item.cluster_id for item in tidb_serverless_list]
-        params = {"clusterIds": cluster_ids, "view": "BASIC"}
+        params = {"clusterIds": cluster_ids, "view": "FULL"}
         response = requests.get(
             f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
         )

View File

@ -15,7 +15,7 @@ class ComfyUIProvider(BuiltinToolProviderController):
try: try:
ws.connect(ws_address) ws.connect(ws_address)
except Exception as e: except Exception:
raise ToolProviderCredentialValidationError(f"can not connect to {ws_address}") raise ToolProviderCredentialValidationError(f"can not connect to {ws_address}")
finally: finally:
ws.close() ws.close()
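
Binding the exception (`as e`) makes the original failure available for chaining; `raise ... from e` is a further hardening suggestion sketched here (the diff itself only rebinds the name), since chaining preserves the root cause in the traceback:

class CredentialValidationError(Exception):
    pass

def validate_connection(connect) -> None:
    try:
        connect()
    except Exception as e:
        # Chaining keeps the underlying socket/handshake error visible.
        raise CredentialValidationError("can not connect") from e

def failing_connect() -> None:
    raise OSError("connection refused")

try:
    validate_connection(failing_connect)
except CredentialValidationError as err:
    assert isinstance(err.__cause__, OSError)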

View File

@ -6,9 +6,9 @@ identity:
zh_Hans: GitLab 合并请求查询 zh_Hans: GitLab 合并请求查询
description: description:
human: human:
en_US: A tool for query GitLab merge requests, Input should be a exists repository or branch. en_US: A tool for query GitLab merge requests, Input should be a exists reposity or branch.
zh_Hans: 一个用于查询 GitLab 代码合并请求的工具,输入的内容应该是一个已存在的仓库名或者分支。 zh_Hans: 一个用于查询 GitLab 代码合并请求的工具,输入的内容应该是一个已存在的仓库名或者分支。
llm: A tool for query GitLab merge requests, Input should be a exists repository or branch. llm: A tool for query GitLab merge requests, Input should be a exists reposity or branch.
parameters: parameters:
- name: repository - name: repository
type: string type: string

View File

@ -61,7 +61,7 @@ class WolframAlphaTool(BuiltinTool):
params["input"] = query params["input"] = query
else: else:
finished = True finished = True
if "sources" in response_data["queryresult"]: if "souces" in response_data["queryresult"]:
return self.create_link_message(response_data["queryresult"]["sources"]["url"]) return self.create_link_message(response_data["queryresult"]["sources"]["url"])
elif "pods" in response_data["queryresult"]: elif "pods" in response_data["queryresult"]:
result = response_data["queryresult"]["pods"][0]["subpods"][0]["plaintext"] result = response_data["queryresult"]["pods"][0]["subpods"][0]["plaintext"]
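
The `"souces"` membership test could never match, so the link branch was dead code even when WolframAlpha returned sources; the corrected key matches the one read on the next line. A minimal illustration:

queryresult = {"sources": {"url": "https://example.com/wolfram"}}
assert "souces" not in queryresult   # misspelled check: always False
assert "sources" in queryresult      # fixed check matches the key read below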

View File

@ -32,32 +32,32 @@ from .variables import (
) )
__all__ = [ __all__ = [
"ArrayAnySegment",
"ArrayAnyVariable",
"ArrayFileSegment",
"ArrayFileVariable",
"ArrayNumberSegment",
"ArrayNumberVariable",
"ArrayObjectSegment",
"ArrayObjectVariable",
"ArraySegment",
"ArrayStringSegment",
"ArrayStringVariable",
"FileSegment",
"FileVariable",
"FloatSegment",
"FloatVariable",
"IntegerSegment",
"IntegerVariable", "IntegerVariable",
"NoneSegment", "FloatVariable",
"NoneVariable",
"ObjectSegment",
"ObjectVariable", "ObjectVariable",
"SecretVariable", "SecretVariable",
"Segment",
"SegmentGroup",
"SegmentType",
"StringSegment",
"StringVariable", "StringVariable",
"ArrayAnyVariable",
"Variable", "Variable",
"SegmentType",
"SegmentGroup",
"Segment",
"NoneSegment",
"NoneVariable",
"IntegerSegment",
"FloatSegment",
"ObjectSegment",
"ArrayAnySegment",
"StringSegment",
"ArrayStringVariable",
"ArrayNumberVariable",
"ArrayObjectVariable",
"ArraySegment",
"ArrayFileSegment",
"ArrayNumberSegment",
"ArrayObjectSegment",
"ArrayStringSegment",
"FileSegment",
"FileVariable",
"ArrayFileVariable",
] ]
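
These `__all__` reorderings are mechanical: one side of this compare toggles RUF022 (unsorted-dunder-all) in pyproject, which keeps `__all__` alphabetized so diffs like this stay trivial. A quick self-check in the same spirit (illustrative, not from the codebase):

__all__ = ["ArrayAnySegment", "ArrayAnyVariable", "FileSegment", "Variable"]
assert list(__all__) == sorted(__all__), "__all__ is not alphabetized"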

View File

@ -2,19 +2,16 @@ from enum import StrEnum
class SegmentType(StrEnum): class SegmentType(StrEnum):
NONE = "none"
NUMBER = "number" NUMBER = "number"
STRING = "string" STRING = "string"
OBJECT = "object"
SECRET = "secret" SECRET = "secret"
FILE = "file"
ARRAY_ANY = "array[any]" ARRAY_ANY = "array[any]"
ARRAY_STRING = "array[string]" ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]" ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = "array[object]" ARRAY_OBJECT = "array[object]"
OBJECT = "object"
FILE = "file"
ARRAY_FILE = "array[file]" ARRAY_FILE = "array[file]"
NONE = "none"
GROUP = "group" GROUP = "group"

View File

@ -2,6 +2,6 @@ from .base_workflow_callback import WorkflowCallback
from .workflow_logging_callback import WorkflowLoggingCallback from .workflow_logging_callback import WorkflowLoggingCallback
__all__ = [ __all__ = [
"WorkflowCallback",
"WorkflowLoggingCallback", "WorkflowLoggingCallback",
"WorkflowCallback",
] ]

View File

@ -1,4 +1,3 @@
from collections.abc import Mapping
from datetime import datetime from datetime import datetime
from typing import Any, Optional from typing import Any, Optional
@ -141,8 +140,8 @@ class BaseIterationEvent(GraphEngineEvent):
class IterationRunStartedEvent(BaseIterationEvent): class IterationRunStartedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at") start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None inputs: Optional[dict[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None metadata: Optional[dict[str, Any]] = None
predecessor_node_id: Optional[str] = None predecessor_node_id: Optional[str] = None
@ -154,18 +153,18 @@ class IterationRunNextEvent(BaseIterationEvent):
class IterationRunSucceededEvent(BaseIterationEvent): class IterationRunSucceededEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at") start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None inputs: Optional[dict[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None outputs: Optional[dict[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None metadata: Optional[dict[str, Any]] = None
steps: int = 0 steps: int = 0
iteration_duration_map: Optional[dict[str, float]] = None iteration_duration_map: Optional[dict[str, float]] = None
class IterationRunFailedEvent(BaseIterationEvent): class IterationRunFailedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at") start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None inputs: Optional[dict[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None outputs: Optional[dict[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None metadata: Optional[dict[str, Any]] = None
steps: int = 0 steps: int = 0
error: str = Field(..., description="failed reason") error: str = Field(..., description="failed reason")
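
Typing the payloads as `Mapping` instead of `dict` declares them read-only at the interface: Pydantic still accepts plain dicts, but static checkers reject mutation through the field. A hedged sketch (simplified model, assuming Pydantic v2):

from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel

class IterationEventSketch(BaseModel):
    inputs: Optional[Mapping[str, Any]] = None  # read-only view for consumers

event = IterationEventSketch(inputs={"query": "hi"})
# event.inputs["query"] = "bye"  # flagged by type checkers: Mapping has no __setitem__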

View File

@ -38,7 +38,7 @@ from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProce
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.node_mapping import node_type_classes_mapping
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
@ -227,8 +227,7 @@ class GraphEngine:
# convert to specific node # convert to specific node
node_type = NodeType(node_config.get("data", {}).get("type")) node_type = NodeType(node_config.get("data", {}).get("type"))
node_version = node_config.get("data", {}).get("version", "1") node_cls = node_type_classes_mapping[node_type]
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
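
The versioned lookup reads `data.version` with a `"1"` default, so graphs saved before versioning existed keep resolving to their original implementation. A sketch with stand-in classes (the real mapping appears later in this diff):

class VariableAssignerV1: ...
class VariableAssignerV2: ...

LATEST_VERSION = "latest"
NODE_TYPE_CLASSES_MAPPING = {
    "assigner": {LATEST_VERSION: VariableAssignerV2, "1": VariableAssignerV1, "2": VariableAssignerV2},
}

node_config = {"data": {"type": "assigner"}}  # old graph: no "version" key
node_version = node_config.get("data", {}).get("version", "1")
assert NODE_TYPE_CLASSES_MAPPING["assigner"][node_version] is VariableAssignerV1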

View File

@ -1,4 +1,4 @@
from .answer_node import AnswerNode from .answer_node import AnswerNode
from .entities import AnswerStreamGenerateRoute from .entities import AnswerStreamGenerateRoute
__all__ = ["AnswerNode", "AnswerStreamGenerateRoute"] __all__ = ["AnswerStreamGenerateRoute", "AnswerNode"]

View File

@ -153,7 +153,7 @@ class AnswerStreamGeneratorRouter:
NodeType.IF_ELSE, NodeType.IF_ELSE,
NodeType.QUESTION_CLASSIFIER, NodeType.QUESTION_CLASSIFIER,
NodeType.ITERATION, NodeType.ITERATION,
NodeType.VARIABLE_ASSIGNER, NodeType.CONVERSATION_VARIABLE_ASSIGNER,
}: }:
answer_dependencies[answer_node_id].append(source_node_id) answer_dependencies[answer_node_id].append(source_node_id)
else: else:

View File

@ -1,4 +1,4 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseNodeData from .entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
from .node import BaseNode from .node import BaseNode
__all__ = ["BaseIterationNodeData", "BaseIterationState", "BaseNode", "BaseNodeData"] __all__ = ["BaseNode", "BaseNodeData", "BaseIterationNodeData", "BaseIterationState"]

View File

@ -7,7 +7,6 @@ from pydantic import BaseModel
class BaseNodeData(ABC, BaseModel): class BaseNodeData(ABC, BaseModel):
title: str title: str
desc: Optional[str] = None desc: Optional[str] = None
version: str = "1"
class BaseIterationNodeData(BaseNodeData): class BaseIterationNodeData(BaseNodeData):

View File

@ -55,9 +55,7 @@ class BaseNode(Generic[GenericNodeData]):
raise ValueError("Node ID is required.") raise ValueError("Node ID is required.")
self.node_id = node_id self.node_id = node_id
self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {})))
node_data = self._node_data_cls.model_validate(config.get("data", {}))
self.node_data = cast(GenericNodeData, node_data)
@abstractmethod @abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]: def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
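
`model_validate` is the idiomatic Pydantic v2 entry point for validating a raw mapping; unlike `cls(**data)` it also accepts mappings with non-identifier keys and supports `strict`/`context` options. A small sketch (assuming Pydantic v2 is available):

from pydantic import BaseModel

class NodeDataSketch(BaseModel):
    title: str

config = {"data": {"title": "start"}}
node_data = NodeDataSketch.model_validate(config.get("data", {}))
assert node_data.title == "start"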

View File

@ -1,8 +1,6 @@
import csv import csv
import io import io
import json import json
import os
import tempfile
import docx import docx
import pandas as pd import pandas as pd
@ -266,20 +264,14 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
def _extract_text_from_pptx(file_content: bytes) -> str: def _extract_text_from_pptx(file_content: bytes) -> str:
try: try:
with io.BytesIO(file_content) as file:
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY: if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
temp_file.write(file_content)
temp_file.flush()
with open(temp_file.name, "rb") as file:
elements = partition_via_api( elements = partition_via_api(
file=file, file=file,
metadata_filename=temp_file.name,
api_url=dify_config.UNSTRUCTURED_API_URL, api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY, api_key=dify_config.UNSTRUCTURED_API_KEY,
) )
os.unlink(temp_file.name)
else: else:
with io.BytesIO(file_content) as file:
elements = partition_pptx(file=file) elements = partition_pptx(file=file)
return "\n".join([getattr(element, "text", "") for element in elements]) return "\n".join([getattr(element, "text", "") for element in elements])
except Exception as e: except Exception as e:
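
The tempfile variant writes the bytes to a named temporary file so the upload carries a real `.pptx` filename (used by the unstructured API for type detection), then removes the file. A sketch of the pattern with the partition call stubbed out; wrapping cleanup in `try/finally` is a hardening suggestion, not what the diff shows:

import os
import tempfile

def extract_via_named_tempfile(file_content: bytes) -> int:
    with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
        temp_file.write(file_content)
        temp_file.flush()
    try:
        with open(temp_file.name, "rb") as f:
            return len(f.read())  # stand-in for partition_via_api(file=f, ...)
    finally:
        os.unlink(temp_file.name)  # always remove the temp file

assert extract_via_named_tempfile(b"PK\x03\x04") == 4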

View File

@ -1,4 +1,4 @@
from .end_node import EndNode from .end_node import EndNode
from .entities import EndStreamParam from .entities import EndStreamParam
__all__ = ["EndNode", "EndStreamParam"] __all__ = ["EndStreamParam", "EndNode"]

View File

@ -14,11 +14,11 @@ class NodeType(StrEnum):
HTTP_REQUEST = "http-request" HTTP_REQUEST = "http-request"
TOOL = "tool" TOOL = "tool"
VARIABLE_AGGREGATOR = "variable-aggregator" VARIABLE_AGGREGATOR = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database. VARIABLE_ASSIGNER = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop" LOOP = "loop"
ITERATION = "iteration" ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration. ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor" PARAMETER_EXTRACTOR = "parameter-extractor"
VARIABLE_ASSIGNER = "assigner" CONVERSATION_VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor" DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator" LIST_OPERATOR = "list-operator"
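
Only the Python member names move here; the persisted string values stay identical, so node types already stored in workflow definitions continue to deserialize. A minimal check with a stand-in enum:

from enum import StrEnum

class NodeTypeSketch(StrEnum):
    LEGACY_VARIABLE_AGGREGATOR = "variable-assigner"
    VARIABLE_ASSIGNER = "assigner"

assert NodeTypeSketch("variable-assigner") is NodeTypeSketch.LEGACY_VARIABLE_AGGREGATOR
assert NodeTypeSketch("assigner") is NodeTypeSketch.VARIABLE_ASSIGNER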

View File

@ -2,9 +2,9 @@ from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverRes
from .types import NodeEvent from .types import NodeEvent
__all__ = [ __all__ = [
"ModelInvokeCompletedEvent",
"NodeEvent",
"RunCompletedEvent", "RunCompletedEvent",
"RunRetrieverResourceEvent", "RunRetrieverResourceEvent",
"RunStreamChunkEvent", "RunStreamChunkEvent",
"NodeEvent",
"ModelInvokeCompletedEvent",
] ]

View File

@ -1,4 +1,4 @@
from .entities import BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeData from .entities import BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeData
from .node import HttpRequestNode from .node import HttpRequestNode
__all__ = ["BodyData", "HttpRequestNode", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "HttpRequestNodeData"] __all__ = ["HttpRequestNodeData", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "BodyData", "HttpRequestNode"]

View File

@ -1,9 +1,11 @@
import logging import logging
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from mimetypes import guess_extension
from os import path
from typing import Any from typing import Any
from configs import dify_config from configs import dify_config
from core.file import File, FileTransferMethod from core.file import File, FileTransferMethod, FileType
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_entities import VariableSelector
@ -148,6 +150,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
content = response.content content = response.content
if is_file and content_type: if is_file and content_type:
# extract filename from url
filename = path.basename(url)
# extract extension if possible
extension = guess_extension(content_type) or ".bin"
tool_file = ToolFileManager.create_file_by_raw( tool_file = ToolFileManager.create_file_by_raw(
user_id=self.user_id, user_id=self.user_id,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
@ -158,6 +165,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
mapping = { mapping = {
"tool_file_id": tool_file.id, "tool_file_id": tool_file.id,
"type": FileType.IMAGE.value,
"transfer_method": FileTransferMethod.TOOL_FILE.value, "transfer_method": FileTransferMethod.TOOL_FILE.value,
} }
file = file_factory.build_from_mapping( file = file_factory.build_from_mapping(
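
Instead of hard-coding `FileType.IMAGE`, the response filename is taken from the URL and the extension guessed from the `Content-Type`, with `.bin` as a fallback. The derivation in isolation:

from mimetypes import guess_extension
from os import path

url = "https://example.com/files/report.pdf"
content_type = "application/pdf"

filename = path.basename(url)                        # "report.pdf"
extension = guess_extension(content_type) or ".bin"  # ".pdf"; unknown types -> ".bin"
assert (filename, extension) == ("report.pdf", ".pdf")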

View File

@ -24,7 +24,7 @@ class IfElseNode(BaseNode[IfElseNodeData]):
""" """
node_inputs: dict[str, list] = {"conditions": []} node_inputs: dict[str, list] = {"conditions": []}
process_data: dict[str, list] = {"condition_results": []} process_datas: dict[str, list] = {"condition_results": []}
input_conditions = [] input_conditions = []
final_result = False final_result = False
@ -40,7 +40,7 @@ class IfElseNode(BaseNode[IfElseNodeData]):
operator=case.logical_operator, operator=case.logical_operator,
) )
process_data["condition_results"].append( process_datas["condition_results"].append(
{ {
"group": case.model_dump(), "group": case.model_dump(),
"results": group_result, "results": group_result,
@ -65,7 +65,7 @@ class IfElseNode(BaseNode[IfElseNodeData]):
selected_case_id = "true" if final_result else "false" selected_case_id = "true" if final_result else "false"
process_data["condition_results"].append( process_datas["condition_results"].append(
{"group": "default", "results": group_result, "final_result": final_result} {"group": "default", "results": group_result, "final_result": final_result}
) )
@ -73,7 +73,7 @@ class IfElseNode(BaseNode[IfElseNodeData]):
except Exception as e: except Exception as e:
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_data, error=str(e) status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_datas, error=str(e)
) )
outputs = {"result": final_result, "selected_case_id": selected_case_id} outputs = {"result": final_result, "selected_case_id": selected_case_id}
@ -81,7 +81,7 @@ class IfElseNode(BaseNode[IfElseNodeData]):
data = NodeRunResult( data = NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs, inputs=node_inputs,
process_data=process_data, process_data=process_datas,
edge_source_handle=selected_case_id or "false", # Use case ID or 'default' edge_source_handle=selected_case_id or "false", # Use case ID or 'default'
outputs=outputs, outputs=outputs,
) )

View File

@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
from flask import Flask, current_app from flask import Flask, current_app
from configs import dify_config from configs import dify_config
from core.variables import IntegerVariable from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import ( from core.workflow.entities.node_entities import (
NodeRunMetadataKey, NodeRunMetadataKey,
NodeRunResult, NodeRunResult,
@ -116,7 +116,7 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "item"], iterator_list_value[0]) variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine # init graph engine
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool from core.workflow.graph_engine.graph_engine import GraphEngine
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
@ -155,34 +155,33 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_node_data=self.node_data, iteration_node_data=self.node_data,
index=0, index=0,
pre_iteration_output=None, pre_iteration_output=None,
duration=None,
) )
iter_run_map: dict[str, float] = {} iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value) outputs: list[Any] = [None] * len(iterator_list_value)
try: try:
if self.node_data.is_parallel: if self.node_data.is_parallel:
futures: list[Future] = [] futures: list[Future] = []
q: Queue = Queue() q = Queue()
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100) thread_pool = graph_engine.workflow_thread_pool_mapping[self.thread_pool_id]
thread_pool._max_workers = self.node_data.parallel_nums
for index, item in enumerate(iterator_list_value): for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit( future: Future = thread_pool.submit(
self._run_single_iter_parallel, self._run_single_iter_parallel,
flask_app=current_app._get_current_object(), # type: ignore current_app._get_current_object(),
q=q, q,
iterator_list_value=iterator_list_value, iterator_list_value,
inputs=inputs, inputs,
outputs=outputs, outputs,
start_at=start_at, start_at,
graph_engine=graph_engine, graph_engine,
iteration_graph=iteration_graph, iteration_graph,
index=index, index,
item=item, item,
iter_run_map=iter_run_map, iter_run_map,
) )
future.add_done_callback(thread_pool.task_done_callback) future.add_done_callback(thread_pool.task_done_callback)
futures.append(future) futures.append(future)
succeeded_count = 0 succeeded_count = 0
empty_count = 0
while True: while True:
try: try:
event = q.get(timeout=1) event = q.get(timeout=1)
@ -210,22 +209,17 @@ class IterationNode(BaseNode[IterationNodeData]):
else: else:
for _ in range(len(iterator_list_value)): for _ in range(len(iterator_list_value)):
yield from self._run_single_iter( yield from self._run_single_iter(
iterator_list_value=iterator_list_value, iterator_list_value,
variable_pool=variable_pool, variable_pool,
inputs=inputs, inputs,
outputs=outputs, outputs,
start_at=start_at, start_at,
graph_engine=graph_engine, graph_engine,
iteration_graph=iteration_graph, iteration_graph,
iter_run_map=iter_run_map, iter_run_map,
) )
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None] outputs = [output for output in outputs if output is not None]
# Flatten the list of lists
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
outputs = [item for sublist in outputs for item in sublist]
yield IterationRunSucceededEvent( yield IterationRunSucceededEvent(
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
@ -233,7 +227,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_node_data=self.node_data, iteration_node_data=self.node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs={"output": outputs}, outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value), steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
) )
@ -241,10 +235,10 @@ class IterationNode(BaseNode[IterationNodeData]):
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": outputs}, outputs={"output": jsonable_encoder(outputs)},
metadata={ metadata={
NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map, NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, "total_tokens": graph_engine.graph_runtime_state.total_tokens,
}, },
) )
) )
@ -258,7 +252,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_node_data=self.node_data, iteration_node_data=self.node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs={"output": outputs}, outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value), steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=str(e), error=str(e),
@ -268,6 +262,7 @@ class IterationNode(BaseNode[IterationNodeData]):
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
error=str(e), error=str(e),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
) )
) )
finally: finally:
@ -290,7 +285,7 @@ class IterationNode(BaseNode[IterationNodeData]):
:param node_data: node data :param node_data: node data
:return: :return:
""" """
variable_mapping: dict[str, Sequence[str]] = { variable_mapping = {
f"{node_id}.input_selector": node_data.iterator_selector, f"{node_id}.input_selector": node_data.iterator_selector,
} }
@ -307,18 +302,17 @@ class IterationNode(BaseNode[IterationNodeData]):
# variable selector to variable mapping # variable selector to variable mapping
try: try:
# Get node class # Get node class
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.node_mapping import node_type_classes_mapping
node_type = NodeType(sub_node_config.get("data", {}).get("type")) node_type = NodeType(sub_node_config.get("data", {}).get("type"))
if node_type not in NODE_TYPE_CLASSES_MAPPING: node_cls = node_type_classes_mapping.get(node_type)
if not node_cls:
continue continue
node_version = sub_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config graph_config=graph_config, config=sub_node_config
) )
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
except NotImplementedError: except NotImplementedError:
sub_node_variable_mapping = {} sub_node_variable_mapping = {}
@ -339,12 +333,8 @@ class IterationNode(BaseNode[IterationNodeData]):
return variable_mapping return variable_mapping
def _handle_event_metadata( def _handle_event_metadata(
self, self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str
*, ) -> NodeRunStartedEvent | BaseNodeEvent:
event: BaseNodeEvent | InNodeEvent,
iter_run_index: int,
parallel_mode_run_id: str | None,
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
""" """
add iteration metadata to event. add iteration metadata to event.
""" """
@ -369,10 +359,9 @@ class IterationNode(BaseNode[IterationNodeData]):
def _run_single_iter( def _run_single_iter(
self, self,
*, iterator_list_value: list[str],
iterator_list_value: Sequence[str],
variable_pool: VariablePool, variable_pool: VariablePool,
inputs: Mapping[str, list], inputs: dict[str, list],
outputs: list, outputs: list,
start_at: datetime, start_at: datetime,
graph_engine: "GraphEngine", graph_engine: "GraphEngine",
@ -388,12 +377,12 @@ class IterationNode(BaseNode[IterationNodeData]):
try: try:
rst = graph_engine.run() rst = graph_engine.run()
# get current iteration index # get current iteration index
index_variable = variable_pool.get([self.node_id, "index"]) current_index = variable_pool.get([self.node_id, "index"]).value
if not isinstance(index_variable, IntegerVariable):
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
current_index = index_variable.value
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}" iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
next_index = int(current_index) + 1 next_index = int(current_index) + 1
if current_index is None:
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
for event in rst: for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id event.in_iteration_id = self.node_id
@ -406,9 +395,7 @@ class IterationNode(BaseNode[IterationNodeData]):
continue continue
if isinstance(event, NodeRunSucceededEvent): if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata( yield self._handle_event_metadata(event, current_index, parallel_mode_run_id)
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
elif isinstance(event, BaseGraphEvent): elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent): if isinstance(event, GraphRunFailedEvent):
# iteration run failed # iteration run failed
@ -421,7 +408,7 @@ class IterationNode(BaseNode[IterationNodeData]):
parallel_mode_run_id=parallel_mode_run_id, parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs={"output": outputs}, outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value), steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error, error=event.error,
@ -434,7 +421,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_node_data=self.node_data, iteration_node_data=self.node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs={"output": outputs}, outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value), steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error, error=event.error,
@ -446,11 +433,9 @@ class IterationNode(BaseNode[IterationNodeData]):
) )
) )
return return
elif isinstance(event, InNodeEvent): else:
# event = cast(InNodeEvent, event) event = cast(InNodeEvent, event)
metadata_event = self._handle_event_metadata( metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id)
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
if isinstance(event, NodeRunFailedEvent): if isinstance(event, NodeRunFailedEvent):
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent( yield NodeInIterationFailedEvent(
@ -532,7 +517,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_node_data=self.node_data, iteration_node_data=self.node_data,
index=next_index, index=next_index,
parallel_mode_run_id=parallel_mode_run_id, parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=current_iteration_output or None, pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
duration=duration, duration=duration,
) )
@ -559,11 +544,10 @@ class IterationNode(BaseNode[IterationNodeData]):
def _run_single_iter_parallel( def _run_single_iter_parallel(
self, self,
*,
flask_app: Flask, flask_app: Flask,
q: Queue, q: Queue,
iterator_list_value: Sequence[str], iterator_list_value: list[str],
inputs: Mapping[str, list], inputs: dict[str, list],
outputs: list, outputs: list,
start_at: datetime, start_at: datetime,
graph_engine: "GraphEngine", graph_engine: "GraphEngine",
@ -571,7 +555,7 @@ class IterationNode(BaseNode[IterationNodeData]):
index: int, index: int,
item: Any, item: Any,
iter_run_map: dict[str, float], iter_run_map: dict[str, float],
): ) -> Generator[NodeEvent | InNodeEvent, None, None]:
""" """
run single iteration in parallel mode run single iteration in parallel mode
""" """

View File

@ -815,7 +815,7 @@ class LLMNode(BaseNode[LLMNodeData]):
"completion_model": { "completion_model": {
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
"prompt": { "prompt": {
"text": "Here are the chat histories between human and assistant, inside " "text": "Here is the chat histories between human and assistant, inside "
"<histories></histories> XML tags.\n\n<histories>\n{{" "<histories></histories> XML tags.\n\n<histories>\n{{"
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
"edition_type": "basic", "edition_type": "basic",

View File

@ -1,5 +1,3 @@
from collections.abc import Mapping
from core.workflow.nodes.answer import AnswerNode from core.workflow.nodes.answer import AnswerNode
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code import CodeNode from core.workflow.nodes.code import CodeNode
@ -18,87 +16,26 @@ from core.workflow.nodes.start import StartNode
from core.workflow.nodes.template_transform import TemplateTransformNode from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.nodes.tool import ToolNode from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1 from core.workflow.nodes.variable_assigner import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
LATEST_VERSION = "latest" node_type_classes_mapping: dict[NodeType, type[BaseNode]] = {
NodeType.START: StartNode,
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { NodeType.END: EndNode,
NodeType.START: { NodeType.ANSWER: AnswerNode,
LATEST_VERSION: StartNode, NodeType.LLM: LLMNode,
"1": StartNode, NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
}, NodeType.IF_ELSE: IfElseNode,
NodeType.END: { NodeType.CODE: CodeNode,
LATEST_VERSION: EndNode, NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
"1": EndNode, NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
}, NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.ANSWER: { NodeType.TOOL: ToolNode,
LATEST_VERSION: AnswerNode, NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
"1": AnswerNode, NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
}, NodeType.ITERATION: IterationNode,
NodeType.LLM: { NodeType.ITERATION_START: IterationStartNode,
LATEST_VERSION: LLMNode, NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
"1": LLMNode, NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
}, NodeType.DOCUMENT_EXTRACTOR: DocumentExtractorNode,
NodeType.KNOWLEDGE_RETRIEVAL: { NodeType.LIST_OPERATOR: ListOperatorNode,
LATEST_VERSION: KnowledgeRetrievalNode,
"1": KnowledgeRetrievalNode,
},
NodeType.IF_ELSE: {
LATEST_VERSION: IfElseNode,
"1": IfElseNode,
},
NodeType.CODE: {
LATEST_VERSION: CodeNode,
"1": CodeNode,
},
NodeType.TEMPLATE_TRANSFORM: {
LATEST_VERSION: TemplateTransformNode,
"1": TemplateTransformNode,
},
NodeType.QUESTION_CLASSIFIER: {
LATEST_VERSION: QuestionClassifierNode,
"1": QuestionClassifierNode,
},
NodeType.HTTP_REQUEST: {
LATEST_VERSION: HttpRequestNode,
"1": HttpRequestNode,
},
NodeType.TOOL: {
LATEST_VERSION: ToolNode,
"1": ToolNode,
},
NodeType.VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
},
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
}, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: {
LATEST_VERSION: IterationNode,
"1": IterationNode,
},
NodeType.ITERATION_START: {
LATEST_VERSION: IterationStartNode,
"1": IterationStartNode,
},
NodeType.PARAMETER_EXTRACTOR: {
LATEST_VERSION: ParameterExtractorNode,
"1": ParameterExtractorNode,
},
NodeType.VARIABLE_ASSIGNER: {
LATEST_VERSION: VariableAssignerNodeV2,
"1": VariableAssignerNodeV1,
"2": VariableAssignerNodeV2,
},
NodeType.DOCUMENT_EXTRACTOR: {
LATEST_VERSION: DocumentExtractorNode,
"1": DocumentExtractorNode,
},
NodeType.LIST_OPERATOR: {
LATEST_VERSION: ListOperatorNode,
"1": ListOperatorNode,
},
} }

View File

@ -98,7 +98,7 @@ Step 3: Structure the extracted parameters to JSON object as specified in <struc
Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted. Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted.
### Memory ### Memory
Here are the chat histories between human and assistant, inside <histories></histories> XML tags. Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories> <histories>
{histories} {histories}
</histories> </histories>
@ -125,7 +125,7 @@ CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and out
The structure of the JSON object you can found in the instructions. The structure of the JSON object you can found in the instructions.
### Memory ### Memory
Here are the chat histories between human and assistant, inside <histories></histories> XML tags. Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories> <histories>
{histories} {histories}
</histories> </histories>

View File

@ -1,4 +1,4 @@
from .entities import QuestionClassifierNodeData from .entities import QuestionClassifierNodeData
from .question_classifier_node import QuestionClassifierNode from .question_classifier_node import QuestionClassifierNode
__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeData"] __all__ = ["QuestionClassifierNodeData", "QuestionClassifierNode"]

View File

@ -8,7 +8,7 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
### Constraint ### Constraint
DO NOT include anything other than the JSON array in your response. DO NOT include anything other than the JSON array in your response.
### Memory ### Memory
Here are the chat histories between human and assistant, inside <histories></histories> XML tags. Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories> <histories>
{histories} {histories}
</histories> </histories>
@ -66,7 +66,7 @@ User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"
Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}} Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}}
</example> </example>
### Memory ### Memory
Here are the chat histories between human and assistant, inside <histories></histories> XML tags. Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories> <histories>
{histories} {histories}
</histories> </histories>

View File

@ -0,0 +1,8 @@
from .node import VariableAssignerNode
from .node_data import VariableAssignerData, WriteMode
__all__ = [
"VariableAssignerNode",
"VariableAssignerData",
"WriteMode",
]

View File

@ -1,4 +0,0 @@
class VariableOperatorNodeError(Exception):
"""Base error type, don't use directly."""
pass

View File

@ -1,19 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables import Variable
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from extensions.ext_database import db
from models import ConversationVariable
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()

View File

@ -0,0 +1,2 @@
class VariableAssignerNodeError(Exception):
pass

View File

@ -1,36 +1,40 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables import SegmentType, Variable from core.variables import SegmentType, Variable
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode, BaseNodeData from core.workflow.nodes.base import BaseNode, BaseNodeData
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from extensions.ext_database import db
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory from factories import variable_factory
from models import ConversationVariable
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
from .exc import VariableAssignerNodeError
from .node_data import VariableAssignerData, WriteMode from .node_data import VariableAssignerData, WriteMode
class VariableAssignerNode(BaseNode[VariableAssignerData]): class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls: type[BaseNodeData] = VariableAssignerData _node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type = NodeType.VARIABLE_ASSIGNER _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector) original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
if not isinstance(original_variable, Variable): if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found") raise VariableAssignerNodeError("assigned variable not found")
match self.node_data.write_mode: match self.node_data.write_mode:
case WriteMode.OVER_WRITE: case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value: if not income_value:
raise VariableOperatorNodeError("input value not found") raise VariableAssignerNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value}) updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND: case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value: if not income_value:
raise VariableOperatorNodeError("input value not found") raise VariableAssignerNodeError("input value not found")
updated_value = original_variable.value + [income_value.value] updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={"value": updated_value}) updated_variable = original_variable.model_copy(update={"value": updated_value})
@ -39,7 +43,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _: case _:
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}") raise VariableAssignerNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Over write the variable. # Over write the variable.
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable) self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
@ -48,8 +52,8 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
# Update conversation variable. # Update conversation variable.
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id: if not conversation_id:
raise VariableOperatorNodeError("conversation_id not found") raise VariableAssignerNodeError("conversation_id not found")
common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -59,6 +63,18 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
) )
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableAssignerNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
def get_zero_value(t: SegmentType): def get_zero_value(t: SegmentType):
match t: match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
@ -70,4 +86,4 @@ def get_zero_value(t: SegmentType):
case SegmentType.NUMBER: case SegmentType.NUMBER:
return variable_factory.build_segment(0) return variable_factory.build_segment(0)
case _: case _:
raise VariableOperatorNodeError(f"unsupported variable type: {t}") raise VariableAssignerNodeError(f"unsupported variable type: {t}")

View File

@ -1,5 +1,6 @@
from collections.abc import Sequence from collections.abc import Sequence
from enum import StrEnum from enum import StrEnum
from typing import Optional
from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.base import BaseNodeData
@ -11,6 +12,8 @@ class WriteMode(StrEnum):
class VariableAssignerData(BaseNodeData): class VariableAssignerData(BaseNodeData):
title: str = "Variable Assigner"
desc: Optional[str] = "Assign a value to a variable"
assigned_variable_selector: Sequence[str] assigned_variable_selector: Sequence[str]
write_mode: WriteMode write_mode: WriteMode
input_variable_selector: Sequence[str] input_variable_selector: Sequence[str]

View File

@ -1,3 +0,0 @@
from .node import VariableAssignerNode
__all__ = ["VariableAssignerNode"]

View File

@ -1,3 +0,0 @@
from .node import VariableAssignerNode
__all__ = ["VariableAssignerNode"]

View File

@ -1,11 +0,0 @@
from core.variables import SegmentType
EMPTY_VALUE_MAPPING = {
SegmentType.STRING: "",
SegmentType.NUMBER: 0,
SegmentType.OBJECT: {},
SegmentType.ARRAY_ANY: [],
SegmentType.ARRAY_STRING: [],
SegmentType.ARRAY_NUMBER: [],
SegmentType.ARRAY_OBJECT: [],
}

View File

@ -1,20 +0,0 @@
from collections.abc import Sequence
from typing import Any
from pydantic import BaseModel
from core.workflow.nodes.base import BaseNodeData
from .enums import InputType, Operation
class VariableOperationItem(BaseModel):
variable_selector: Sequence[str]
input_type: InputType
operation: Operation
value: Any | None = None
class VariableAssignerNodeData(BaseNodeData):
version: str = "2"
items: Sequence[VariableOperationItem]

View File

@ -1,18 +0,0 @@
from enum import StrEnum
class Operation(StrEnum):
OVER_WRITE = "over-write"
CLEAR = "clear"
APPEND = "append"
EXTEND = "extend"
SET = "set"
ADD = "+="
SUBTRACT = "-="
MULTIPLY = "*="
DIVIDE = "/="
class InputType(StrEnum):
VARIABLE = "variable"
CONSTANT = "constant"

View File

@ -1,31 +0,0 @@
from collections.abc import Sequence
from typing import Any
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from .enums import InputType, Operation
class OperationNotSupportedError(VariableOperatorNodeError):
def __init__(self, *, operation: Operation, variable_type: str):
super().__init__(f"Operation {operation} is not supported for type {variable_type}")
class InputTypeNotSupportedError(VariableOperatorNodeError):
def __init__(self, *, input_type: InputType, operation: Operation):
super().__init__(f"Input type {input_type} is not supported for operation {operation}")
class VariableNotFoundError(VariableOperatorNodeError):
def __init__(self, *, variable_selector: Sequence[str]):
super().__init__(f"Variable {variable_selector} not found")
class InvalidInputValueError(VariableOperatorNodeError):
def __init__(self, *, value: Any):
super().__init__(f"Invalid input value {value}")
class ConversationIDNotFoundError(VariableOperatorNodeError):
def __init__(self):
super().__init__("conversation_id not found")
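
These subclasses all derive from the shared `VariableOperatorNodeError` base in `variable_assigner/common/exc.py`, so callers can catch the whole family at once regardless of which assigner version raised. A small sketch of the hierarchy in use:

class VariableOperatorNodeError(Exception):
    """Base error type, don't use directly."""

class VariableNotFoundError(VariableOperatorNodeError):
    def __init__(self, *, variable_selector) -> None:
        super().__init__(f"Variable {variable_selector} not found")

try:
    raise VariableNotFoundError(variable_selector=["conversation", "topic"])
except VariableOperatorNodeError as e:
    handled = str(e)  # one except clause covers every operator error
assert "not found" in handled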

View File

@ -1,91 +0,0 @@
from typing import Any
from core.variables import SegmentType
from .enums import Operation
def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
match operation:
case Operation.OVER_WRITE | Operation.CLEAR:
return True
case Operation.SET:
return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER}
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
# Only number variable can be added, subtracted, multiplied or divided
return variable_type == SegmentType.NUMBER
case Operation.APPEND | Operation.EXTEND:
# Only array variable can be appended or extended
return variable_type in {
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_FILE,
}
case _:
return False
def is_variable_input_supported(*, operation: Operation):
if operation in {Operation.SET, Operation.ADD, Operation.SUBTRACT, Operation.MULTIPLY, Operation.DIVIDE}:
return False
return True
def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation):
match variable_type:
case SegmentType.STRING | SegmentType.OBJECT:
return operation in {Operation.OVER_WRITE, Operation.SET}
case SegmentType.NUMBER:
return operation in {
Operation.OVER_WRITE,
Operation.SET,
Operation.ADD,
Operation.SUBTRACT,
Operation.MULTIPLY,
Operation.DIVIDE,
}
case _:
return False
def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any):
if operation == Operation.CLEAR:
return True
match variable_type:
case SegmentType.STRING:
return isinstance(value, str)
case SegmentType.NUMBER:
if not isinstance(value, int | float):
return False
if operation == Operation.DIVIDE and value == 0:
return False
return True
case SegmentType.OBJECT:
return isinstance(value, dict)
# Array & Append
case SegmentType.ARRAY_ANY if operation == Operation.APPEND:
return isinstance(value, str | float | int | dict)
case SegmentType.ARRAY_STRING if operation == Operation.APPEND:
return isinstance(value, str)
case SegmentType.ARRAY_NUMBER if operation == Operation.APPEND:
return isinstance(value, int | float)
case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND:
return isinstance(value, dict)
# Array & Extend / Overwrite
case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, str | float | int | dict) for item in value)
case SegmentType.ARRAY_STRING if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, str) for item in value)
case SegmentType.ARRAY_NUMBER if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, int | float) for item in value)
case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, dict) for item in value)
case _:
return False
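
Note the guard on the number case: a `/=` with a zero constant is rejected during validation, so the arithmetic in `_handle_item` below never divides by zero. A runnable distillation of that one branch (plain strings stand in for the enums):

def is_number_input_valid(operation: str, value) -> bool:
    if not isinstance(value, (int, float)):
        return False
    if operation == "/=" and value == 0:
        return False  # rejected here, before _handle_item ever runs
    return True

assert is_number_input_valid("/=", 0) is False
assert is_number_input_valid("+=", 2) is True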

View File

@ -1,159 +0,0 @@
import json
from typing import Any
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from models.workflow import WorkflowNodeExecutionStatus
from . import helpers
from .constants import EMPTY_VALUE_MAPPING
from .entities import VariableAssignerNodeData
from .enums import InputType, Operation
from .exc import (
ConversationIDNotFoundError,
InputTypeNotSupportedError,
InvalidInputValueError,
OperationNotSupportedError,
VariableNotFoundError,
)
class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_ASSIGNER
def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump()
process_data = {}
# NOTE: This node has no outputs
updated_variables: list[Variable] = []
try:
for item in self.node_data.items:
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
# ==================== Validation Part
# Check if variable exists
if not isinstance(variable, Variable):
raise VariableNotFoundError(variable_selector=item.variable_selector)
# Check if operation is supported
if not helpers.is_operation_supported(variable_type=variable.value_type, operation=item.operation):
raise OperationNotSupportedError(operation=item.operation, variable_type=variable.value_type)
# Check if variable input is supported
if item.input_type == InputType.VARIABLE and not helpers.is_variable_input_supported(
operation=item.operation
):
raise InputTypeNotSupportedError(input_type=InputType.VARIABLE, operation=item.operation)
# Check if constant input is supported
if item.input_type == InputType.CONSTANT and not helpers.is_constant_input_supported(
variable_type=variable.value_type, operation=item.operation
):
raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation)
# Get value from variable pool
if (
item.input_type == InputType.VARIABLE
and item.operation != Operation.CLEAR
and item.value is not None
):
value = self.graph_runtime_state.variable_pool.get(item.value)
if value is None:
raise VariableNotFoundError(variable_selector=item.value)
# Skip if value is NoneSegment
if value.value_type == SegmentType.NONE:
continue
item.value = value.value
# If set string / bytes / bytearray to object, try convert string to object.
if (
item.operation == Operation.SET
and variable.value_type == SegmentType.OBJECT
and isinstance(item.value, str | bytes | bytearray)
):
try:
item.value = json.loads(item.value)
except json.JSONDecodeError:
raise InvalidInputValueError(value=item.value)
# Check if input value is valid
if not helpers.is_input_value_valid(
variable_type=variable.value_type, operation=item.operation, value=item.value
):
raise InvalidInputValueError(value=item.value)
# ==================== Execution Part
updated_value = self._handle_item(
variable=variable,
operation=item.operation,
value=item.value,
)
variable = variable.model_copy(update={"value": updated_value})
updated_variables.append(variable)
except VariableOperatorNodeError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
process_data=process_data,
error=str(e),
)
# Update variables
for variable in updated_variables:
self.graph_runtime_state.variable_pool.add(variable.selector, variable)
process_data[variable.name] = variable.value
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise ConversationIDNotFoundError
else:
conversation_id = conversation_id.value
common_helpers.update_conversation_variable(
conversation_id=conversation_id,
variable=variable,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
)
def _handle_item(
self,
*,
variable: Variable,
operation: Operation,
value: Any,
):
match operation:
case Operation.OVER_WRITE:
return value
case Operation.CLEAR:
return EMPTY_VALUE_MAPPING[variable.value_type]
case Operation.APPEND:
return variable.value + [value]
case Operation.EXTEND:
return variable.value + value
case Operation.SET:
return value
case Operation.ADD:
return variable.value + value
case Operation.SUBTRACT:
return variable.value - value
case Operation.MULTIPLY:
return variable.value * value
case Operation.DIVIDE:
return variable.value / value
case _:
raise OperationNotSupportedError(operation=operation, variable_type=variable.value_type)

View File

@ -2,7 +2,7 @@ import logging
import time import time
import uuid import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional from typing import Any, Optional, cast
from configs import dify_config from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
@ -19,7 +19,7 @@ from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.node_mapping import node_type_classes_mapping
from factories import file_factory from factories import file_factory
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import ( from models.workflow import (
@ -145,8 +145,11 @@ class WorkflowEntry:
# Get node class # Get node class
node_type = NodeType(node_config.get("data", {}).get("type")) node_type = NodeType(node_config.get("data", {}).get("type"))
node_version = node_config.get("data", {}).get("version", "1") node_cls = node_type_classes_mapping.get(node_type)
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] node_cls = cast(type[BaseNode], node_cls)
if not node_cls:
raise ValueError(f"Node class not found for node type {node_type}")
# init variable pool # init variable pool
variable_pool = VariablePool(environment_variables=workflow.environment_variables) variable_pool = VariablePool(environment_variables=workflow.environment_variables)

Some files were not shown because too many files have changed in this diff