mirror of
https://github.com/langgenius/dify.git
synced 2026-04-14 09:35:50 +08:00
Compare commits
2 Commits
0.13.2
...
fix/iterat
| Author | SHA1 | Date | |
|---|---|---|---|
| 5f7771bc47 | |||
| 286741e139 |
13
.github/pull_request_template.md
vendored
13
.github/pull_request_template.md
vendored
@ -8,9 +8,16 @@ Please include a summary of the change and which issue is fixed. Please also inc
|
||||
|
||||
# Screenshots
|
||||
|
||||
| Before | After |
|
||||
|--------|-------|
|
||||
| ... | ... |
|
||||
<table>
|
||||
<tr>
|
||||
<td>Before: </td>
|
||||
<td>After: </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>...</td>
|
||||
<td>...</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
# Checklist
|
||||
|
||||
|
||||
@ -413,3 +413,4 @@ RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
|
||||
|
||||
CREATE_TIDB_SERVICE_JOB_ENABLED=false
|
||||
|
||||
RETRIEVAL_TOP_N=0
|
||||
|
||||
@ -20,8 +20,6 @@ select = [
|
||||
"PLC0208", # iteration-over-set
|
||||
"PLC2801", # unnecessary-dunder-call
|
||||
"PLC0414", # useless-import-alias
|
||||
"PLE0604", # invalid-all-object
|
||||
"PLE0605", # invalid-all-format
|
||||
"PLR0402", # manual-from-import
|
||||
"PLR1711", # useless-return
|
||||
"PLR1714", # repeated-equality-comparison
|
||||
@ -30,7 +28,6 @@ select = [
|
||||
"RUF100", # unused-noqa
|
||||
"RUF101", # redirected-noqa
|
||||
"RUF200", # invalid-pyproject-toml
|
||||
"RUF022", # unsorted-dunder-all
|
||||
"S506", # unsafe-yaml-load
|
||||
"SIM", # flake8-simplify rules
|
||||
"TRY400", # error-instead-of-exception
|
||||
|
||||
@ -259,7 +259,7 @@ def migrate_knowledge_vector_database():
|
||||
skipped_count = 0
|
||||
total_count = 0
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
upper_collection_vector_types = {
|
||||
upper_colletion_vector_types = {
|
||||
VectorType.MILVUS,
|
||||
VectorType.PGVECTOR,
|
||||
VectorType.RELYT,
|
||||
@ -267,7 +267,7 @@ def migrate_knowledge_vector_database():
|
||||
VectorType.ORACLE,
|
||||
VectorType.ELASTICSEARCH,
|
||||
}
|
||||
lower_collection_vector_types = {
|
||||
lower_colletion_vector_types = {
|
||||
VectorType.ANALYTICDB,
|
||||
VectorType.CHROMA,
|
||||
VectorType.MYSCALE,
|
||||
@ -307,7 +307,7 @@ def migrate_knowledge_vector_database():
|
||||
continue
|
||||
collection_name = ""
|
||||
dataset_id = dataset.id
|
||||
if vector_type in upper_collection_vector_types:
|
||||
if vector_type in upper_colletion_vector_types:
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
elif vector_type == VectorType.QDRANT:
|
||||
if dataset.collection_binding_id:
|
||||
@ -323,7 +323,7 @@ def migrate_knowledge_vector_database():
|
||||
else:
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
|
||||
elif vector_type in lower_collection_vector_types:
|
||||
elif vector_type in lower_colletion_vector_types:
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
else:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
@ -626,6 +626,8 @@ class DataSetConfig(BaseSettings):
|
||||
default=30,
|
||||
)
|
||||
|
||||
RETRIEVAL_TOP_N: int = Field(description="number of retrieval top_n", default=0)
|
||||
|
||||
|
||||
class WorkspaceConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="0.13.2",
|
||||
default="0.12.1",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
||||
@ -100,11 +100,11 @@ class DraftWorkflowApi(Resource):
|
||||
try:
|
||||
environment_variables_list = args.get("environment_variables") or []
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
conversation_variables_list = args.get("conversation_variables") or []
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app_model,
|
||||
@ -382,7 +382,7 @@ class DefaultBlockConfigApi(Resource):
|
||||
filters = None
|
||||
if args.get("q"):
|
||||
try:
|
||||
filters = json.loads(args.get("q", ""))
|
||||
filters = json.loads(args.get("q"))
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid filters")
|
||||
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, inputs, marshal_with, reqparse
|
||||
from sqlalchemy import and_
|
||||
@ -21,17 +20,8 @@ class InstalledAppsListApi(Resource):
|
||||
@account_initialization_required
|
||||
@marshal_with(installed_app_list_fields)
|
||||
def get(self):
|
||||
app_id = request.args.get("app_id", default=None, type=str)
|
||||
current_tenant_id = current_user.current_tenant_id
|
||||
|
||||
if app_id:
|
||||
installed_apps = (
|
||||
db.session.query(InstalledApp)
|
||||
.filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
|
||||
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
|
||||
|
||||
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
|
||||
installed_apps = [
|
||||
|
||||
@ -368,7 +368,6 @@ class ToolWorkflowProviderCreateApi(Resource):
|
||||
description=args["description"],
|
||||
parameters=args["parameters"],
|
||||
privacy_policy=args["privacy_policy"],
|
||||
labels=args["labels"],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
Due to the presence of tasks in App Runner that require long execution times, such as LLM generation and external requests, Flask-Sqlalchemy's strategy for database connection pooling is to allocate one connection (transaction) per request. This approach keeps a connection occupied even during non-DB tasks, leading to the inability to acquire new connections during high concurrency requests due to multiple long-running tasks.
|
||||
|
||||
Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid detach errors.
|
||||
Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid deattach errors.
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
@ -82,7 +82,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
for resource in metadata["retriever_resources"]:
|
||||
updated_resources.append(
|
||||
{
|
||||
"segment_id": resource.get("segment_id", ""),
|
||||
"segment_id": resource["segment_id"],
|
||||
"position": resource["position"],
|
||||
"document_name": resource["document_name"],
|
||||
"score": resource["score"],
|
||||
|
||||
@ -43,7 +43,7 @@ from core.workflow.graph_engine.entities.event import (
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
@ -138,8 +138,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
|
||||
node_version = iteration_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
node_cls = node_type_classes_mapping[node_type]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
|
||||
@ -2,7 +2,7 @@ from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
@ -113,6 +113,18 @@ class QueueIterationNextEvent(AppQueueEvent):
|
||||
output: Optional[Any] = None # output for the current iteration
|
||||
duration: Optional[float] = None
|
||||
|
||||
@field_validator("output", mode="before")
|
||||
@classmethod
|
||||
def set_output(cls, v):
|
||||
"""
|
||||
Set output
|
||||
"""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, int | float | str | bool | dict | list):
|
||||
return v
|
||||
raise ValueError("output must be a valid type")
|
||||
|
||||
|
||||
class QueueIterationCompletedEvent(AppQueueEvent):
|
||||
"""
|
||||
|
||||
@ -7,13 +7,13 @@ from .models import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FILE_MODEL_IDENTITY",
|
||||
"ArrayFileAttribute",
|
||||
"File",
|
||||
"FileAttribute",
|
||||
"FileBelongsTo",
|
||||
"FileTransferMethod",
|
||||
"FileType",
|
||||
"FileUploadConfig",
|
||||
"FileTransferMethod",
|
||||
"FileBelongsTo",
|
||||
"File",
|
||||
"ImageConfig",
|
||||
"FileAttribute",
|
||||
"ArrayFileAttribute",
|
||||
"FILE_MODEL_IDENTITY",
|
||||
]
|
||||
|
||||
@ -91,7 +91,7 @@ class XinferenceProvider(Provider):
|
||||
"""
|
||||
```
|
||||
|
||||
也可以直接抛出对应 Errors,并做如下定义,这样在之后的调用中可以直接抛出`InvokeConnectionError`等异常。
|
||||
也可以直接抛出对应Erros,并做如下定义,这样在之后的调用中可以直接抛出`InvokeConnectionError`等异常。
|
||||
|
||||
```python
|
||||
@property
|
||||
|
||||
@ -18,25 +18,25 @@ from .message_entities import (
|
||||
from .model_entities import ModelPropertyKey
|
||||
|
||||
__all__ = [
|
||||
"AssistantPromptMessage",
|
||||
"AudioPromptMessageContent",
|
||||
"DocumentPromptMessageContent",
|
||||
"ImagePromptMessageContent",
|
||||
"VideoPromptMessageContent",
|
||||
"PromptMessage",
|
||||
"PromptMessageRole",
|
||||
"LLMUsage",
|
||||
"ModelPropertyKey",
|
||||
"AssistantPromptMessage",
|
||||
"PromptMessage",
|
||||
"PromptMessageContent",
|
||||
"PromptMessageRole",
|
||||
"SystemPromptMessage",
|
||||
"TextPromptMessageContent",
|
||||
"UserPromptMessage",
|
||||
"PromptMessageTool",
|
||||
"ToolPromptMessage",
|
||||
"PromptMessageContentType",
|
||||
"LLMResult",
|
||||
"LLMResultChunk",
|
||||
"LLMResultChunkDelta",
|
||||
"LLMUsage",
|
||||
"ModelPropertyKey",
|
||||
"PromptMessage",
|
||||
"PromptMessage",
|
||||
"PromptMessageContent",
|
||||
"PromptMessageContentType",
|
||||
"PromptMessageRole",
|
||||
"PromptMessageRole",
|
||||
"PromptMessageTool",
|
||||
"SystemPromptMessage",
|
||||
"TextPromptMessageContent",
|
||||
"ToolPromptMessage",
|
||||
"UserPromptMessage",
|
||||
"VideoPromptMessageContent",
|
||||
"AudioPromptMessageContent",
|
||||
"DocumentPromptMessageContent",
|
||||
]
|
||||
|
||||
@ -16,7 +16,6 @@ help:
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
|
||||
@ -1,53 +0,0 @@
|
||||
model: amazon.nova-lite-v1:0
|
||||
label:
|
||||
en_US: Nova Lite V1
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 300000
|
||||
parameter_rules:
|
||||
- name: max_new_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 2048
|
||||
min: 1
|
||||
max: 5000
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.00006'
|
||||
output: '0.00024'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -1,52 +0,0 @@
|
||||
model: amazon.nova-micro-v1:0
|
||||
label:
|
||||
en_US: Nova Micro V1
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: max_new_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 2048
|
||||
min: 1
|
||||
max: 5000
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.000035'
|
||||
output: '0.00014'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -1,53 +0,0 @@
|
||||
model: amazon.nova-pro-v1:0
|
||||
label:
|
||||
en_US: Nova Pro V1
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 300000
|
||||
parameter_rules:
|
||||
- name: max_new_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 2048
|
||||
min: 1
|
||||
max: 5000
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.0008'
|
||||
output: '0.0032'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -70,8 +70,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
{"prefix": "cohere.command-r", "support_system_prompts": True, "support_tool_use": True},
|
||||
{"prefix": "amazon.titan", "support_system_prompts": False, "support_tool_use": False},
|
||||
{"prefix": "ai21.jamba-1-5", "support_system_prompts": True, "support_tool_use": False},
|
||||
{"prefix": "amazon.nova", "support_system_prompts": True, "support_tool_use": False},
|
||||
{"prefix": "us.amazon.nova", "support_system_prompts": True, "support_tool_use": False},
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@ -196,13 +194,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
if model_info["support_tool_use"] and tools:
|
||||
parameters["toolConfig"] = self._convert_converse_tool_config(tools=tools)
|
||||
try:
|
||||
# for issue #10976
|
||||
conversations_list = parameters["messages"]
|
||||
# if two consecutive user messages found, combine them into one message
|
||||
for i in range(len(conversations_list) - 2, -1, -1):
|
||||
if conversations_list[i]["role"] == conversations_list[i + 1]["role"]:
|
||||
conversations_list[i]["content"].extend(conversations_list.pop(i + 1)["content"])
|
||||
|
||||
if stream:
|
||||
response = bedrock_client.converse_stream(**parameters)
|
||||
return self._handle_converse_stream_response(
|
||||
|
||||
@ -1,53 +0,0 @@
|
||||
model: us.amazon.nova-lite-v1:0
|
||||
label:
|
||||
en_US: Nova Lite V1 (US.Cross Region Inference)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 300000
|
||||
parameter_rules:
|
||||
- name: max_new_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 2048
|
||||
min: 1
|
||||
max: 5000
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.00006'
|
||||
output: '0.00024'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -1,52 +0,0 @@
|
||||
model: us.amazon.nova-micro-v1:0
|
||||
label:
|
||||
en_US: Nova Micro V1 (US.Cross Region Inference)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: max_new_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 2048
|
||||
min: 1
|
||||
max: 5000
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.000035'
|
||||
output: '0.00014'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -1,53 +0,0 @@
|
||||
model: us.amazon.nova-pro-v1:0
|
||||
label:
|
||||
en_US: Nova Pro V1 (US.Cross Region Inference)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 300000
|
||||
parameter_rules:
|
||||
- name: max_new_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 2048
|
||||
min: 1
|
||||
max: 5000
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.0008'
|
||||
output: '0.0032'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -1,2 +0,0 @@
|
||||
- amazon.rerank-v1
|
||||
- cohere.rerank-v3-5
|
||||
@ -1,4 +0,0 @@
|
||||
model: amazon.rerank-v1:0
|
||||
model_type: rerank
|
||||
model_properties:
|
||||
context_size: 5120
|
||||
@ -1,4 +0,0 @@
|
||||
model: cohere.rerank-v3-5:0
|
||||
model_type: rerank
|
||||
model_properties:
|
||||
context_size: 5120
|
||||
@ -1,147 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
|
||||
|
||||
class BedrockRerankModel(RerankModel):
|
||||
"""
|
||||
Model class for Cohere rerank model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
|
||||
if len(docs) == 0:
|
||||
return RerankResult(model=model, docs=docs)
|
||||
|
||||
# initialize client
|
||||
client_config = Config(region_name=credentials["aws_region"])
|
||||
bedrock_runtime = boto3.client(
|
||||
service_name="bedrock-agent-runtime",
|
||||
config=client_config,
|
||||
aws_access_key_id=credentials.get("aws_access_key_id", ""),
|
||||
aws_secret_access_key=credentials.get("aws_secret_access_key"),
|
||||
)
|
||||
queries = [{"type": "TEXT", "textQuery": {"text": query}}]
|
||||
text_sources = []
|
||||
for text in docs:
|
||||
text_sources.append(
|
||||
{
|
||||
"type": "INLINE",
|
||||
"inlineDocumentSource": {
|
||||
"type": "TEXT",
|
||||
"textDocument": {
|
||||
"text": text,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
modelId = model
|
||||
region = credentials["aws_region"]
|
||||
model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
|
||||
rerankingConfiguration = {
|
||||
"type": "BEDROCK_RERANKING_MODEL",
|
||||
"bedrockRerankingConfiguration": {
|
||||
"numberOfResults": top_n,
|
||||
"modelConfiguration": {
|
||||
"modelArn": model_package_arn,
|
||||
},
|
||||
},
|
||||
}
|
||||
response = bedrock_runtime.rerank(
|
||||
queries=queries, sources=text_sources, rerankingConfiguration=rerankingConfiguration
|
||||
)
|
||||
|
||||
rerank_documents = []
|
||||
for idx, result in enumerate(response["results"]):
|
||||
# format document
|
||||
index = result["index"]
|
||||
rerank_document = RerankDocument(
|
||||
index=index,
|
||||
text=docs[index],
|
||||
score=result["relevanceScore"],
|
||||
)
|
||||
|
||||
# score threshold check
|
||||
if score_threshold is not None:
|
||||
if rerank_document.score >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
else:
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
return RerankResult(model=model, docs=rerank_documents)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self.invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.8,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke emd = genai.GenerativeModel(model) error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [],
|
||||
InvokeServerUnavailableError: [],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [],
|
||||
InvokeBadRequestError: [],
|
||||
}
|
||||
@ -2,4 +2,3 @@
|
||||
- rerank-english-v3.0
|
||||
- rerank-multilingual-v2.0
|
||||
- rerank-multilingual-v3.0
|
||||
- rerank-v3.5
|
||||
|
||||
@ -1,4 +0,0 @@
|
||||
model: rerank-v3.5
|
||||
model_type: rerank
|
||||
model_properties:
|
||||
context_size: 5120
|
||||
@ -1,38 +0,0 @@
|
||||
model: gemini-exp-1206
|
||||
label:
|
||||
en_US: Gemini exp 1206
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -252,7 +252,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
# ignore sse comments
|
||||
if chunk.startswith(":"):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().removeprefix("data: ")
|
||||
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
|
||||
chunk_json = None
|
||||
try:
|
||||
chunk_json = json.loads(decoded_chunk)
|
||||
|
||||
@ -181,11 +181,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
# prepare the payload for a simple ping to the model
|
||||
data = {"model": model, "stream": stream}
|
||||
|
||||
if format_schema := model_parameters.pop("format", None):
|
||||
try:
|
||||
data["format"] = format_schema if format_schema == "json" else json.loads(format_schema)
|
||||
except json.JSONDecodeError as e:
|
||||
raise InvokeBadRequestError(f"Invalid format schema: {str(e)}")
|
||||
if "format" in model_parameters:
|
||||
data["format"] = model_parameters["format"]
|
||||
del model_parameters["format"]
|
||||
|
||||
if "keep_alive" in model_parameters:
|
||||
data["keep_alive"] = model_parameters["keep_alive"]
|
||||
@ -735,12 +733,12 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
ParameterRule(
|
||||
name="format",
|
||||
label=I18nObject(en_US="Format", zh_Hans="返回格式"),
|
||||
type=ParameterType.TEXT,
|
||||
default="json",
|
||||
type=ParameterType.STRING,
|
||||
help=I18nObject(
|
||||
en_US="the format to return a response in. Format can be `json` or a JSON schema.",
|
||||
zh_Hans="返回响应的格式。目前接受的值是字符串`json`或JSON schema.",
|
||||
en_US="the format to return a response in. Currently the only accepted value is json.",
|
||||
zh_Hans="返回响应的格式。目前唯一接受的值是json。",
|
||||
),
|
||||
options=["json"],
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
|
||||
@ -462,7 +462,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
# ignore sse comments
|
||||
if chunk.startswith(":"):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().removeprefix("data: ")
|
||||
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
|
||||
if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]"
|
||||
continue
|
||||
|
||||
|
||||
@ -250,7 +250,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
# ignore sse comments
|
||||
if chunk.startswith(":"):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().removeprefix("data: ")
|
||||
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
|
||||
chunk_json = None
|
||||
try:
|
||||
chunk_json = json.loads(decoded_chunk)
|
||||
|
||||
@ -104,14 +104,13 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
# use Anthropic official SDK references
|
||||
# - https://github.com/anthropics/anthropic-sdk-python
|
||||
service_account_key = credentials.get("vertex_service_account_key", "")
|
||||
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
|
||||
project_id = credentials["vertex_project_id"]
|
||||
SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
|
||||
token = ""
|
||||
|
||||
# get access token from service account credential
|
||||
if service_account_key:
|
||||
service_account_info = json.loads(base64.b64decode(service_account_key))
|
||||
if service_account_info:
|
||||
credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=SCOPES)
|
||||
request = google.auth.transport.requests.Request()
|
||||
credentials.refresh(request)
|
||||
@ -479,11 +478,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
|
||||
if stop:
|
||||
config_kwargs["stop_sequences"] = stop
|
||||
|
||||
service_account_key = credentials.get("vertex_service_account_key", "")
|
||||
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
|
||||
project_id = credentials["vertex_project_id"]
|
||||
location = credentials["vertex_location"]
|
||||
if service_account_key:
|
||||
service_account_info = json.loads(base64.b64decode(service_account_key))
|
||||
if service_account_info:
|
||||
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
|
||||
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
|
||||
else:
|
||||
|
||||
@ -48,11 +48,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
service_account_key = credentials.get("vertex_service_account_key", "")
|
||||
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
|
||||
project_id = credentials["vertex_project_id"]
|
||||
location = credentials["vertex_location"]
|
||||
if service_account_key:
|
||||
service_account_info = json.loads(base64.b64decode(service_account_key))
|
||||
if service_account_info:
|
||||
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
|
||||
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
|
||||
else:
|
||||
@ -101,11 +100,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
service_account_key = credentials.get("vertex_service_account_key", "")
|
||||
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
|
||||
project_id = credentials["vertex_project_id"]
|
||||
location = credentials["vertex_location"]
|
||||
if service_account_key:
|
||||
service_account_info = json.loads(base64.b64decode(service_account_key))
|
||||
if service_account_info:
|
||||
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
|
||||
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
|
||||
else:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .common import ChatRole
|
||||
from .maas import MaasError, MaasService
|
||||
|
||||
__all__ = ["ChatRole", "MaasError", "MaasService"]
|
||||
__all__ = ["MaasService", "ChatRole", "MaasError"]
|
||||
|
||||
@ -17,13 +17,7 @@ class WenxinRerank(_CommonWenxin):
|
||||
def rerank(self, model: str, query: str, docs: list[str], top_n: Optional[int] = None):
|
||||
access_token = self._get_access_token()
|
||||
url = f"{self.api_bases[model]}?access_token={access_token}"
|
||||
# For issue #11252
|
||||
# for wenxin Rerank model top_n length should be equal or less than docs length
|
||||
if top_n is not None and top_n > len(docs):
|
||||
top_n = len(docs)
|
||||
# for wenxin Rerank model, query should not be an empty string
|
||||
if query == "":
|
||||
query = " " # FIXME: this is a workaround for wenxin rerank model for better user experience.
|
||||
|
||||
try:
|
||||
response = httpx.post(
|
||||
url,
|
||||
@ -31,11 +25,7 @@ class WenxinRerank(_CommonWenxin):
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
# wenxin error handling
|
||||
if "error_code" in data:
|
||||
raise InternalServerError(data["error_msg"])
|
||||
return data
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise InternalServerError(str(e))
|
||||
|
||||
@ -79,9 +69,6 @@ class WenxinRerankModel(RerankModel):
|
||||
results = wenxin_rerank.rerank(model, query, docs, top_n)
|
||||
|
||||
rerank_documents = []
|
||||
if "results" not in results:
|
||||
raise ValueError("results key not found in response")
|
||||
|
||||
for result in results["results"]:
|
||||
index = result["index"]
|
||||
if "document" in result:
|
||||
|
||||
@ -8,7 +8,6 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
||||
@ -8,7 +8,6 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
||||
@ -8,7 +8,6 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
||||
@ -8,7 +8,6 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
||||
@ -8,7 +8,6 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
||||
@ -8,7 +8,6 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
||||
@ -8,7 +8,6 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
||||
@ -8,7 +8,7 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
context_size: 10240
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
||||
@ -8,7 +8,6 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
||||
@ -4,7 +4,6 @@ label:
|
||||
model_type: llm
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2048
|
||||
features:
|
||||
- vision
|
||||
parameter_rules:
|
||||
|
||||
@ -1,52 +0,0 @@
|
||||
model: glm-4v-flash
|
||||
label:
|
||||
en_US: glm-4v-flash
|
||||
model_type: llm
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2048
|
||||
features:
|
||||
- vision
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.95
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
|
||||
en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 0.6
|
||||
help:
|
||||
zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
|
||||
en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
|
||||
- name: do_sample
|
||||
label:
|
||||
zh_Hans: 采样策略
|
||||
en_US: Sampling strategy
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 1024
|
||||
- name: web_search
|
||||
type: boolean
|
||||
label:
|
||||
zh_Hans: 联网搜索
|
||||
en_US: Web Search
|
||||
default: false
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
||||
@ -4,7 +4,6 @@ label:
|
||||
model_type: llm
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
features:
|
||||
- vision
|
||||
- video
|
||||
|
||||
@ -22,6 +22,18 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
||||
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
|
||||
from core.model_runtime.utils import helper
|
||||
|
||||
GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.
|
||||
The structure of the JSON object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
|
||||
if you are not sure about the structure.
|
||||
|
||||
And you should always end the block with a "```" to indicate the end of the JSON object.
|
||||
|
||||
<instructions>
|
||||
{{instructions}}
|
||||
</instructions>
|
||||
|
||||
```JSON""" # noqa: E501
|
||||
|
||||
|
||||
class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
def _invoke(
|
||||
@ -52,8 +64,42 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
|
||||
# invoke model
|
||||
# stop = stop or []
|
||||
# self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
# def _transform_json_prompts(self, model: str, credentials: dict,
|
||||
# prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
# tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
|
||||
# stream: bool = True, user: str | None = None) \
|
||||
# -> None:
|
||||
# """
|
||||
# Transform json prompts to model prompts
|
||||
# """
|
||||
# if "}\n\n" not in stop:
|
||||
# stop.append("}\n\n")
|
||||
|
||||
# # check if there is a system message
|
||||
# if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
|
||||
# # override the system message
|
||||
# prompt_messages[0] = SystemPromptMessage(
|
||||
# content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content)
|
||||
# )
|
||||
# else:
|
||||
# # insert the system message
|
||||
# prompt_messages.insert(0, SystemPromptMessage(
|
||||
# content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", "Please output a valid JSON object.")
|
||||
# ))
|
||||
# # check if the last message is a user message
|
||||
# if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
|
||||
# # add ```JSON\n to the last message
|
||||
# prompt_messages[-1].content += "\n```JSON\n"
|
||||
# else:
|
||||
# # append a user message
|
||||
# prompt_messages.append(UserPromptMessage(
|
||||
# content="```JSON\n"
|
||||
# ))
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
@ -124,7 +170,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
extra_model_kwargs = {}
|
||||
# request to glm-4v-plus with stop words will always respond "finish_reason":"network_error"
|
||||
# request to glm-4v-plus with stop words will always response "finish_reason":"network_error"
|
||||
if stop and model != "glm-4v-plus":
|
||||
extra_model_kwargs["stop"] = stop
|
||||
|
||||
@ -140,11 +186,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
# resolve zhipuai model not support system message and user message, assistant message must be in sequence
|
||||
new_prompt_messages: list[PromptMessage] = []
|
||||
for prompt_message in prompt_messages:
|
||||
copy_prompt_message = prompt_message.model_copy()
|
||||
copy_prompt_message = prompt_message.copy()
|
||||
if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}:
|
||||
if isinstance(copy_prompt_message.content, list):
|
||||
# check if model is 'glm-4v'
|
||||
if not model.startswith("glm-4v"):
|
||||
if model not in {"glm-4v", "glm-4v-plus"}:
|
||||
# not support list message
|
||||
continue
|
||||
# get image and
|
||||
@ -188,42 +234,63 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
else:
|
||||
model_parameters["tools"] = [web_search_params]
|
||||
|
||||
if model.startswith("glm-4v"):
|
||||
if model in {"glm-4v", "glm-4v-plus"}:
|
||||
params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
|
||||
else:
|
||||
params = {"model": model, "messages": [], **model_parameters}
|
||||
for prompt_message in new_prompt_messages:
|
||||
if prompt_message.role == PromptMessageRole.TOOL:
|
||||
params["messages"].append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": prompt_message.content,
|
||||
"tool_call_id": prompt_message.tool_call_id,
|
||||
}
|
||||
)
|
||||
elif isinstance(prompt_message, AssistantPromptMessage):
|
||||
if prompt_message.tool_calls:
|
||||
# glm model
|
||||
if not model.startswith("chatglm"):
|
||||
for prompt_message in new_prompt_messages:
|
||||
if prompt_message.role == PromptMessageRole.TOOL:
|
||||
params["messages"].append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"role": "tool",
|
||||
"content": prompt_message.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": tool_call.type,
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
},
|
||||
}
|
||||
for tool_call in prompt_message.tool_calls
|
||||
],
|
||||
"tool_call_id": prompt_message.tool_call_id,
|
||||
}
|
||||
)
|
||||
elif isinstance(prompt_message, AssistantPromptMessage):
|
||||
if prompt_message.tool_calls:
|
||||
params["messages"].append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": prompt_message.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": tool_call.type,
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
},
|
||||
}
|
||||
for tool_call in prompt_message.tool_calls
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
params["messages"].append({"role": "assistant", "content": prompt_message.content})
|
||||
else:
|
||||
params["messages"].append({"role": "assistant", "content": prompt_message.content})
|
||||
else:
|
||||
params["messages"].append({"role": prompt_message.role.value, "content": prompt_message.content})
|
||||
params["messages"].append(
|
||||
{"role": prompt_message.role.value, "content": prompt_message.content}
|
||||
)
|
||||
else:
|
||||
# chatglm model
|
||||
for prompt_message in new_prompt_messages:
|
||||
# merge system message to user message
|
||||
if prompt_message.role in {
|
||||
PromptMessageRole.SYSTEM,
|
||||
PromptMessageRole.TOOL,
|
||||
PromptMessageRole.USER,
|
||||
}:
|
||||
if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user":
|
||||
params["messages"][-1]["content"] += "\n\n" + prompt_message.content
|
||||
else:
|
||||
params["messages"].append({"role": "user", "content": prompt_message.content})
|
||||
else:
|
||||
params["messages"].append(
|
||||
{"role": prompt_message.role.value, "content": prompt_message.content}
|
||||
)
|
||||
|
||||
if tools and len(tools) > 0:
|
||||
params["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools]
|
||||
@ -339,7 +406,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
Handle llm stream response
|
||||
|
||||
:param model: model name
|
||||
:param responses: response
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator result
|
||||
"""
|
||||
@ -412,8 +479,6 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
human_prompt = "\n\nHuman:"
|
||||
ai_prompt = "\n\nAssistant:"
|
||||
content = message.content
|
||||
if isinstance(content, list):
|
||||
content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
|
||||
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
@ -440,7 +505,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
if tools and len(tools) > 0:
|
||||
text += "\n\nTools:"
|
||||
for tool in tools:
|
||||
text += f"\n{tool.model_dump_json()}"
|
||||
text += f"\n{tool.json()}"
|
||||
|
||||
# trim off the trailing ' ' that might come from the "Assistant: "
|
||||
return text.rstrip()
|
||||
|
||||
@ -5,7 +5,7 @@ BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找
|
||||
CHAT_APP_COMPLETION_PROMPT_CONFIG = {
|
||||
"completion_prompt_config": {
|
||||
"prompt": {
|
||||
"text": "{{#pre_prompt#}}\nHere are the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: " # noqa: E501
|
||||
"text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: " # noqa: E501
|
||||
},
|
||||
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
|
||||
},
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from configs import DifyConfig
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
@ -113,7 +114,7 @@ class RetrievalService:
|
||||
query=query,
|
||||
documents=all_documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
top_n=DifyConfig.RETRIEVAL_TOP_N or top_k,
|
||||
)
|
||||
|
||||
return all_documents
|
||||
@ -185,7 +186,7 @@ class RetrievalService:
|
||||
query=query,
|
||||
documents=documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=len(documents),
|
||||
top_n=DifyConfig.RETRIEVAL_TOP_N or len(documents),
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -230,7 +231,7 @@ class RetrievalService:
|
||||
query=query,
|
||||
documents=documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=len(documents),
|
||||
top_n=DifyConfig.RETRIEVAL_TOP_N or len(documents),
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
@ -104,7 +104,8 @@ class OceanBaseVector(BaseVector):
|
||||
val = int(row[6])
|
||||
vals.append(val)
|
||||
if len(vals) == 0:
|
||||
raise ValueError("ob_vector_memory_limit_percentage not found in parameters.")
|
||||
print("ob_vector_memory_limit_percentage not found in parameters.")
|
||||
exit(1)
|
||||
if any(val == 0 for val in vals):
|
||||
try:
|
||||
self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30")
|
||||
@ -199,10 +200,10 @@ class OceanBaseVectorFactory(AbstractVectorFactory):
|
||||
return OceanBaseVector(
|
||||
collection_name,
|
||||
OceanBaseVectorConfig(
|
||||
host=dify_config.OCEANBASE_VECTOR_HOST or "",
|
||||
port=dify_config.OCEANBASE_VECTOR_PORT or 0,
|
||||
user=dify_config.OCEANBASE_VECTOR_USER or "",
|
||||
host=dify_config.OCEANBASE_VECTOR_HOST,
|
||||
port=dify_config.OCEANBASE_VECTOR_PORT,
|
||||
user=dify_config.OCEANBASE_VECTOR_USER,
|
||||
password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""),
|
||||
database=dify_config.OCEANBASE_VECTOR_DATABASE or "",
|
||||
database=dify_config.OCEANBASE_VECTOR_DATABASE,
|
||||
),
|
||||
)
|
||||
|
||||
@ -375,6 +375,7 @@ class TidbOnQdrantVector(BaseVector):
|
||||
for result in results:
|
||||
if result:
|
||||
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
|
||||
document.metadata["vector"] = result.vector
|
||||
documents.append(document)
|
||||
|
||||
return documents
|
||||
@ -393,7 +394,6 @@ class TidbOnQdrantVector(BaseVector):
|
||||
) -> Document:
|
||||
return Document(
|
||||
page_content=scored_point.payload.get(content_payload_key),
|
||||
vector=scored_point.vector,
|
||||
metadata=scored_point.payload.get(metadata_payload_key) or {},
|
||||
)
|
||||
|
||||
|
||||
@ -162,7 +162,7 @@ class TidbService:
|
||||
clusters = []
|
||||
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
|
||||
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
|
||||
params = {"clusterIds": cluster_ids, "view": "BASIC"}
|
||||
params = {"clusterIds": cluster_ids, "view": "FULL"}
|
||||
response = requests.get(
|
||||
f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
|
||||
)
|
||||
|
||||
@ -15,7 +15,7 @@ class ComfyUIProvider(BuiltinToolProviderController):
|
||||
|
||||
try:
|
||||
ws.connect(ws_address)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise ToolProviderCredentialValidationError(f"can not connect to {ws_address}")
|
||||
finally:
|
||||
ws.close()
|
||||
|
||||
@ -6,9 +6,9 @@ identity:
|
||||
zh_Hans: GitLab 合并请求查询
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for query GitLab merge requests, Input should be a exists repository or branch.
|
||||
en_US: A tool for query GitLab merge requests, Input should be a exists reposity or branch.
|
||||
zh_Hans: 一个用于查询 GitLab 代码合并请求的工具,输入的内容应该是一个已存在的仓库名或者分支。
|
||||
llm: A tool for query GitLab merge requests, Input should be a exists repository or branch.
|
||||
llm: A tool for query GitLab merge requests, Input should be a exists reposity or branch.
|
||||
parameters:
|
||||
- name: repository
|
||||
type: string
|
||||
|
||||
@ -61,7 +61,7 @@ class WolframAlphaTool(BuiltinTool):
|
||||
params["input"] = query
|
||||
else:
|
||||
finished = True
|
||||
if "sources" in response_data["queryresult"]:
|
||||
if "souces" in response_data["queryresult"]:
|
||||
return self.create_link_message(response_data["queryresult"]["sources"]["url"])
|
||||
elif "pods" in response_data["queryresult"]:
|
||||
result = response_data["queryresult"]["pods"][0]["subpods"][0]["plaintext"]
|
||||
|
||||
@ -32,32 +32,32 @@ from .variables import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ArrayAnySegment",
|
||||
"ArrayAnyVariable",
|
||||
"ArrayFileSegment",
|
||||
"ArrayFileVariable",
|
||||
"ArrayNumberSegment",
|
||||
"ArrayNumberVariable",
|
||||
"ArrayObjectSegment",
|
||||
"ArrayObjectVariable",
|
||||
"ArraySegment",
|
||||
"ArrayStringSegment",
|
||||
"ArrayStringVariable",
|
||||
"FileSegment",
|
||||
"FileVariable",
|
||||
"FloatSegment",
|
||||
"FloatVariable",
|
||||
"IntegerSegment",
|
||||
"IntegerVariable",
|
||||
"NoneSegment",
|
||||
"NoneVariable",
|
||||
"ObjectSegment",
|
||||
"FloatVariable",
|
||||
"ObjectVariable",
|
||||
"SecretVariable",
|
||||
"Segment",
|
||||
"SegmentGroup",
|
||||
"SegmentType",
|
||||
"StringSegment",
|
||||
"StringVariable",
|
||||
"ArrayAnyVariable",
|
||||
"Variable",
|
||||
"SegmentType",
|
||||
"SegmentGroup",
|
||||
"Segment",
|
||||
"NoneSegment",
|
||||
"NoneVariable",
|
||||
"IntegerSegment",
|
||||
"FloatSegment",
|
||||
"ObjectSegment",
|
||||
"ArrayAnySegment",
|
||||
"StringSegment",
|
||||
"ArrayStringVariable",
|
||||
"ArrayNumberVariable",
|
||||
"ArrayObjectVariable",
|
||||
"ArraySegment",
|
||||
"ArrayFileSegment",
|
||||
"ArrayNumberSegment",
|
||||
"ArrayObjectSegment",
|
||||
"ArrayStringSegment",
|
||||
"FileSegment",
|
||||
"FileVariable",
|
||||
"ArrayFileVariable",
|
||||
]
|
||||
|
||||
@ -2,19 +2,16 @@ from enum import StrEnum
|
||||
|
||||
|
||||
class SegmentType(StrEnum):
|
||||
NONE = "none"
|
||||
NUMBER = "number"
|
||||
STRING = "string"
|
||||
OBJECT = "object"
|
||||
SECRET = "secret"
|
||||
|
||||
FILE = "file"
|
||||
|
||||
ARRAY_ANY = "array[any]"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
OBJECT = "object"
|
||||
FILE = "file"
|
||||
ARRAY_FILE = "array[file]"
|
||||
|
||||
NONE = "none"
|
||||
|
||||
GROUP = "group"
|
||||
|
||||
@ -2,6 +2,6 @@ from .base_workflow_callback import WorkflowCallback
|
||||
from .workflow_logging_callback import WorkflowLoggingCallback
|
||||
|
||||
__all__ = [
|
||||
"WorkflowCallback",
|
||||
"WorkflowLoggingCallback",
|
||||
"WorkflowCallback",
|
||||
]
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -141,8 +140,8 @@ class BaseIterationEvent(GraphEngineEvent):
|
||||
|
||||
class IterationRunStartedEvent(BaseIterationEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
predecessor_node_id: Optional[str] = None
|
||||
|
||||
|
||||
@ -154,18 +153,18 @@ class IterationRunNextEvent(BaseIterationEvent):
|
||||
|
||||
class IterationRunSucceededEvent(BaseIterationEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
steps: int = 0
|
||||
iteration_duration_map: Optional[dict[str, float]] = None
|
||||
|
||||
|
||||
class IterationRunFailedEvent(BaseIterationEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProce
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
@ -227,8 +227,7 @@ class GraphEngine:
|
||||
|
||||
# convert to specific node
|
||||
node_type = NodeType(node_config.get("data", {}).get("type"))
|
||||
node_version = node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
node_cls = node_type_classes_mapping[node_type]
|
||||
|
||||
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .answer_node import AnswerNode
|
||||
from .entities import AnswerStreamGenerateRoute
|
||||
|
||||
__all__ = ["AnswerNode", "AnswerStreamGenerateRoute"]
|
||||
__all__ = ["AnswerStreamGenerateRoute", "AnswerNode"]
|
||||
|
||||
@ -153,7 +153,7 @@ class AnswerStreamGeneratorRouter:
|
||||
NodeType.IF_ELSE,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
NodeType.ITERATION,
|
||||
NodeType.VARIABLE_ASSIGNER,
|
||||
NodeType.CONVERSATION_VARIABLE_ASSIGNER,
|
||||
}:
|
||||
answer_dependencies[answer_node_id].append(source_node_id)
|
||||
else:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
from .node import BaseNode
|
||||
|
||||
__all__ = ["BaseIterationNodeData", "BaseIterationState", "BaseNode", "BaseNodeData"]
|
||||
__all__ = ["BaseNode", "BaseNodeData", "BaseIterationNodeData", "BaseIterationState"]
|
||||
|
||||
@ -7,7 +7,6 @@ from pydantic import BaseModel
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: Optional[str] = None
|
||||
version: str = "1"
|
||||
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
|
||||
@ -55,9 +55,7 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
raise ValueError("Node ID is required.")
|
||||
|
||||
self.node_id = node_id
|
||||
|
||||
node_data = self._node_data_cls.model_validate(config.get("data", {}))
|
||||
self.node_data = cast(GenericNodeData, node_data)
|
||||
self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {})))
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import docx
|
||||
import pandas as pd
|
||||
@ -266,20 +264,14 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
|
||||
|
||||
def _extract_text_from_pptx(file_content: bytes) -> str:
|
||||
try:
|
||||
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
|
||||
with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
|
||||
temp_file.write(file_content)
|
||||
temp_file.flush()
|
||||
with open(temp_file.name, "rb") as file:
|
||||
elements = partition_via_api(
|
||||
file=file,
|
||||
metadata_filename=temp_file.name,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY,
|
||||
)
|
||||
os.unlink(temp_file.name)
|
||||
else:
|
||||
with io.BytesIO(file_content) as file:
|
||||
with io.BytesIO(file_content) as file:
|
||||
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
|
||||
elements = partition_via_api(
|
||||
file=file,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY,
|
||||
)
|
||||
else:
|
||||
elements = partition_pptx(file=file)
|
||||
return "\n".join([getattr(element, "text", "") for element in elements])
|
||||
except Exception as e:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .end_node import EndNode
|
||||
from .entities import EndStreamParam
|
||||
|
||||
__all__ = ["EndNode", "EndStreamParam"]
|
||||
__all__ = ["EndStreamParam", "EndNode"]
|
||||
|
||||
@ -14,11 +14,11 @@ class NodeType(StrEnum):
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
VARIABLE_AGGREGATOR = "variable-aggregator"
|
||||
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
|
||||
VARIABLE_ASSIGNER = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
|
||||
LOOP = "loop"
|
||||
ITERATION = "iteration"
|
||||
ITERATION_START = "iteration-start" # Fake start node for iteration.
|
||||
PARAMETER_EXTRACTOR = "parameter-extractor"
|
||||
VARIABLE_ASSIGNER = "assigner"
|
||||
CONVERSATION_VARIABLE_ASSIGNER = "assigner"
|
||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
LIST_OPERATOR = "list-operator"
|
||||
|
||||
@ -2,9 +2,9 @@ from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverRes
|
||||
from .types import NodeEvent
|
||||
|
||||
__all__ = [
|
||||
"ModelInvokeCompletedEvent",
|
||||
"NodeEvent",
|
||||
"RunCompletedEvent",
|
||||
"RunRetrieverResourceEvent",
|
||||
"RunStreamChunkEvent",
|
||||
"NodeEvent",
|
||||
"ModelInvokeCompletedEvent",
|
||||
]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .entities import BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeData
|
||||
from .node import HttpRequestNode
|
||||
|
||||
__all__ = ["BodyData", "HttpRequestNode", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "HttpRequestNodeData"]
|
||||
__all__ = ["HttpRequestNodeData", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "BodyData", "HttpRequestNode"]
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from mimetypes import guess_extension
|
||||
from os import path
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
@ -148,6 +150,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
content = response.content
|
||||
|
||||
if is_file and content_type:
|
||||
# extract filename from url
|
||||
filename = path.basename(url)
|
||||
# extract extension if possible
|
||||
extension = guess_extension(content_type) or ".bin"
|
||||
|
||||
tool_file = ToolFileManager.create_file_by_raw(
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
@ -158,6 +165,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file.id,
|
||||
"type": FileType.IMAGE.value,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE.value,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
|
||||
@ -24,7 +24,7 @@ class IfElseNode(BaseNode[IfElseNodeData]):
|
||||
"""
|
||||
node_inputs: dict[str, list] = {"conditions": []}
|
||||
|
||||
process_data: dict[str, list] = {"condition_results": []}
|
||||
process_datas: dict[str, list] = {"condition_results": []}
|
||||
|
||||
input_conditions = []
|
||||
final_result = False
|
||||
@ -40,7 +40,7 @@ class IfElseNode(BaseNode[IfElseNodeData]):
|
||||
operator=case.logical_operator,
|
||||
)
|
||||
|
||||
process_data["condition_results"].append(
|
||||
process_datas["condition_results"].append(
|
||||
{
|
||||
"group": case.model_dump(),
|
||||
"results": group_result,
|
||||
@ -65,7 +65,7 @@ class IfElseNode(BaseNode[IfElseNodeData]):
|
||||
|
||||
selected_case_id = "true" if final_result else "false"
|
||||
|
||||
process_data["condition_results"].append(
|
||||
process_datas["condition_results"].append(
|
||||
{"group": "default", "results": group_result, "final_result": final_result}
|
||||
)
|
||||
|
||||
@ -73,7 +73,7 @@ class IfElseNode(BaseNode[IfElseNodeData]):
|
||||
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_data, error=str(e)
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_datas, error=str(e)
|
||||
)
|
||||
|
||||
outputs = {"result": final_result, "selected_case_id": selected_case_id}
|
||||
@ -81,7 +81,7 @@ class IfElseNode(BaseNode[IfElseNodeData]):
|
||||
data = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=node_inputs,
|
||||
process_data=process_data,
|
||||
process_data=process_datas,
|
||||
edge_source_handle=selected_case_id or "false", # Use case ID or 'default'
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from flask import Flask, current_app
|
||||
|
||||
from configs import dify_config
|
||||
from core.variables import IntegerVariable
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import (
|
||||
NodeRunMetadataKey,
|
||||
NodeRunResult,
|
||||
@ -116,7 +116,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
|
||||
|
||||
# init graph engine
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
@ -155,34 +155,33 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
iteration_node_data=self.node_data,
|
||||
index=0,
|
||||
pre_iteration_output=None,
|
||||
duration=None,
|
||||
)
|
||||
iter_run_map: dict[str, float] = {}
|
||||
outputs: list[Any] = [None] * len(iterator_list_value)
|
||||
try:
|
||||
if self.node_data.is_parallel:
|
||||
futures: list[Future] = []
|
||||
q: Queue = Queue()
|
||||
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
|
||||
q = Queue()
|
||||
thread_pool = graph_engine.workflow_thread_pool_mapping[self.thread_pool_id]
|
||||
thread_pool._max_workers = self.node_data.parallel_nums
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
future: Future = thread_pool.submit(
|
||||
self._run_single_iter_parallel,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
q=q,
|
||||
iterator_list_value=iterator_list_value,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
start_at=start_at,
|
||||
graph_engine=graph_engine,
|
||||
iteration_graph=iteration_graph,
|
||||
index=index,
|
||||
item=item,
|
||||
iter_run_map=iter_run_map,
|
||||
current_app._get_current_object(),
|
||||
q,
|
||||
iterator_list_value,
|
||||
inputs,
|
||||
outputs,
|
||||
start_at,
|
||||
graph_engine,
|
||||
iteration_graph,
|
||||
index,
|
||||
item,
|
||||
iter_run_map,
|
||||
)
|
||||
future.add_done_callback(thread_pool.task_done_callback)
|
||||
futures.append(future)
|
||||
succeeded_count = 0
|
||||
empty_count = 0
|
||||
while True:
|
||||
try:
|
||||
event = q.get(timeout=1)
|
||||
@ -210,22 +209,17 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
else:
|
||||
for _ in range(len(iterator_list_value)):
|
||||
yield from self._run_single_iter(
|
||||
iterator_list_value=iterator_list_value,
|
||||
variable_pool=variable_pool,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
start_at=start_at,
|
||||
graph_engine=graph_engine,
|
||||
iteration_graph=iteration_graph,
|
||||
iter_run_map=iter_run_map,
|
||||
iterator_list_value,
|
||||
variable_pool,
|
||||
inputs,
|
||||
outputs,
|
||||
start_at,
|
||||
graph_engine,
|
||||
iteration_graph,
|
||||
iter_run_map,
|
||||
)
|
||||
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
outputs = [output for output in outputs if output is not None]
|
||||
|
||||
# Flatten the list of lists
|
||||
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
|
||||
outputs = [item for sublist in outputs for item in sublist]
|
||||
|
||||
yield IterationRunSucceededEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
@ -233,7 +227,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": outputs},
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
)
|
||||
@ -241,10 +235,10 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"output": outputs},
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
metadata={
|
||||
NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
|
||||
},
|
||||
)
|
||||
)
|
||||
@ -258,7 +252,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": outputs},
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=str(e),
|
||||
@ -268,6 +262,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
)
|
||||
)
|
||||
finally:
|
||||
@ -290,7 +285,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
variable_mapping: dict[str, Sequence[str]] = {
|
||||
variable_mapping = {
|
||||
f"{node_id}.input_selector": node_data.iterator_selector,
|
||||
}
|
||||
|
||||
@ -307,18 +302,17 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||
|
||||
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
|
||||
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
||||
node_cls = node_type_classes_mapping.get(node_type)
|
||||
if not node_cls:
|
||||
continue
|
||||
node_version = sub_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
)
|
||||
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
|
||||
sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
|
||||
except NotImplementedError:
|
||||
sub_node_variable_mapping = {}
|
||||
|
||||
@ -339,12 +333,8 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
return variable_mapping
|
||||
|
||||
def _handle_event_metadata(
|
||||
self,
|
||||
*,
|
||||
event: BaseNodeEvent | InNodeEvent,
|
||||
iter_run_index: int,
|
||||
parallel_mode_run_id: str | None,
|
||||
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
|
||||
self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str
|
||||
) -> NodeRunStartedEvent | BaseNodeEvent:
|
||||
"""
|
||||
add iteration metadata to event.
|
||||
"""
|
||||
@ -369,10 +359,9 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
|
||||
def _run_single_iter(
|
||||
self,
|
||||
*,
|
||||
iterator_list_value: Sequence[str],
|
||||
iterator_list_value: list[str],
|
||||
variable_pool: VariablePool,
|
||||
inputs: Mapping[str, list],
|
||||
inputs: dict[str, list],
|
||||
outputs: list,
|
||||
start_at: datetime,
|
||||
graph_engine: "GraphEngine",
|
||||
@ -388,12 +377,12 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
try:
|
||||
rst = graph_engine.run()
|
||||
# get current iteration index
|
||||
index_variable = variable_pool.get([self.node_id, "index"])
|
||||
if not isinstance(index_variable, IntegerVariable):
|
||||
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
|
||||
current_index = index_variable.value
|
||||
current_index = variable_pool.get([self.node_id, "index"]).value
|
||||
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
|
||||
next_index = int(current_index) + 1
|
||||
|
||||
if current_index is None:
|
||||
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
|
||||
for event in rst:
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
|
||||
event.in_iteration_id = self.node_id
|
||||
@ -406,9 +395,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
continue
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
yield self._handle_event_metadata(
|
||||
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
|
||||
)
|
||||
yield self._handle_event_metadata(event, current_index, parallel_mode_run_id)
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
# iteration run failed
|
||||
@ -421,7 +408,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": outputs},
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
@ -434,7 +421,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": outputs},
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
@ -446,11 +433,9 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
)
|
||||
)
|
||||
return
|
||||
elif isinstance(event, InNodeEvent):
|
||||
# event = cast(InNodeEvent, event)
|
||||
metadata_event = self._handle_event_metadata(
|
||||
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
|
||||
)
|
||||
else:
|
||||
event = cast(InNodeEvent, event)
|
||||
metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id)
|
||||
if isinstance(event, NodeRunFailedEvent):
|
||||
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
|
||||
yield NodeInIterationFailedEvent(
|
||||
@ -532,7 +517,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
iteration_node_data=self.node_data,
|
||||
index=next_index,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
pre_iteration_output=current_iteration_output or None,
|
||||
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
@ -559,11 +544,10 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
|
||||
def _run_single_iter_parallel(
|
||||
self,
|
||||
*,
|
||||
flask_app: Flask,
|
||||
q: Queue,
|
||||
iterator_list_value: Sequence[str],
|
||||
inputs: Mapping[str, list],
|
||||
iterator_list_value: list[str],
|
||||
inputs: dict[str, list],
|
||||
outputs: list,
|
||||
start_at: datetime,
|
||||
graph_engine: "GraphEngine",
|
||||
@ -571,7 +555,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
index: int,
|
||||
item: Any,
|
||||
iter_run_map: dict[str, float],
|
||||
):
|
||||
) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
run single iteration in parallel mode
|
||||
"""
|
||||
|
||||
@ -815,7 +815,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
"completion_model": {
|
||||
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
|
||||
"prompt": {
|
||||
"text": "Here are the chat histories between human and assistant, inside "
|
||||
"text": "Here is the chat histories between human and assistant, inside "
|
||||
"<histories></histories> XML tags.\n\n<histories>\n{{"
|
||||
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
|
||||
"edition_type": "basic",
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.workflow.nodes.answer import AnswerNode
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.code import CodeNode
|
||||
@ -18,87 +16,26 @@ from core.workflow.nodes.start import StartNode
|
||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||
from core.workflow.nodes.tool import ToolNode
|
||||
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
|
||||
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
|
||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode
|
||||
|
||||
LATEST_VERSION = "latest"
|
||||
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||
NodeType.START: {
|
||||
LATEST_VERSION: StartNode,
|
||||
"1": StartNode,
|
||||
},
|
||||
NodeType.END: {
|
||||
LATEST_VERSION: EndNode,
|
||||
"1": EndNode,
|
||||
},
|
||||
NodeType.ANSWER: {
|
||||
LATEST_VERSION: AnswerNode,
|
||||
"1": AnswerNode,
|
||||
},
|
||||
NodeType.LLM: {
|
||||
LATEST_VERSION: LLMNode,
|
||||
"1": LLMNode,
|
||||
},
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: {
|
||||
LATEST_VERSION: KnowledgeRetrievalNode,
|
||||
"1": KnowledgeRetrievalNode,
|
||||
},
|
||||
NodeType.IF_ELSE: {
|
||||
LATEST_VERSION: IfElseNode,
|
||||
"1": IfElseNode,
|
||||
},
|
||||
NodeType.CODE: {
|
||||
LATEST_VERSION: CodeNode,
|
||||
"1": CodeNode,
|
||||
},
|
||||
NodeType.TEMPLATE_TRANSFORM: {
|
||||
LATEST_VERSION: TemplateTransformNode,
|
||||
"1": TemplateTransformNode,
|
||||
},
|
||||
NodeType.QUESTION_CLASSIFIER: {
|
||||
LATEST_VERSION: QuestionClassifierNode,
|
||||
"1": QuestionClassifierNode,
|
||||
},
|
||||
NodeType.HTTP_REQUEST: {
|
||||
LATEST_VERSION: HttpRequestNode,
|
||||
"1": HttpRequestNode,
|
||||
},
|
||||
NodeType.TOOL: {
|
||||
LATEST_VERSION: ToolNode,
|
||||
"1": ToolNode,
|
||||
},
|
||||
NodeType.VARIABLE_AGGREGATOR: {
|
||||
LATEST_VERSION: VariableAggregatorNode,
|
||||
"1": VariableAggregatorNode,
|
||||
},
|
||||
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
|
||||
LATEST_VERSION: VariableAggregatorNode,
|
||||
"1": VariableAggregatorNode,
|
||||
}, # original name of VARIABLE_AGGREGATOR
|
||||
NodeType.ITERATION: {
|
||||
LATEST_VERSION: IterationNode,
|
||||
"1": IterationNode,
|
||||
},
|
||||
NodeType.ITERATION_START: {
|
||||
LATEST_VERSION: IterationStartNode,
|
||||
"1": IterationStartNode,
|
||||
},
|
||||
NodeType.PARAMETER_EXTRACTOR: {
|
||||
LATEST_VERSION: ParameterExtractorNode,
|
||||
"1": ParameterExtractorNode,
|
||||
},
|
||||
NodeType.VARIABLE_ASSIGNER: {
|
||||
LATEST_VERSION: VariableAssignerNodeV2,
|
||||
"1": VariableAssignerNodeV1,
|
||||
"2": VariableAssignerNodeV2,
|
||||
},
|
||||
NodeType.DOCUMENT_EXTRACTOR: {
|
||||
LATEST_VERSION: DocumentExtractorNode,
|
||||
"1": DocumentExtractorNode,
|
||||
},
|
||||
NodeType.LIST_OPERATOR: {
|
||||
LATEST_VERSION: ListOperatorNode,
|
||||
"1": ListOperatorNode,
|
||||
},
|
||||
node_type_classes_mapping: dict[NodeType, type[BaseNode]] = {
|
||||
NodeType.START: StartNode,
|
||||
NodeType.END: EndNode,
|
||||
NodeType.ANSWER: AnswerNode,
|
||||
NodeType.LLM: LLMNode,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||
NodeType.IF_ELSE: IfElseNode,
|
||||
NodeType.CODE: CodeNode,
|
||||
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
|
||||
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
|
||||
NodeType.HTTP_REQUEST: HttpRequestNode,
|
||||
NodeType.TOOL: ToolNode,
|
||||
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
||||
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
|
||||
NodeType.ITERATION: IterationNode,
|
||||
NodeType.ITERATION_START: IterationStartNode,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
|
||||
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocumentExtractorNode,
|
||||
NodeType.LIST_OPERATOR: ListOperatorNode,
|
||||
}
|
||||
|
||||
@ -98,7 +98,7 @@ Step 3: Structure the extracted parameters to JSON object as specified in <struc
|
||||
Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted.
|
||||
|
||||
### Memory
|
||||
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
|
||||
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
|
||||
<histories>
|
||||
{histories}
|
||||
</histories>
|
||||
@ -125,7 +125,7 @@ CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and out
|
||||
The structure of the JSON object you can found in the instructions.
|
||||
|
||||
### Memory
|
||||
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
|
||||
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
|
||||
<histories>
|
||||
{histories}
|
||||
</histories>
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .entities import QuestionClassifierNodeData
|
||||
from .question_classifier_node import QuestionClassifierNode
|
||||
|
||||
__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeData"]
|
||||
__all__ = ["QuestionClassifierNodeData", "QuestionClassifierNode"]
|
||||
|
||||
@ -8,7 +8,7 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
|
||||
### Constraint
|
||||
DO NOT include anything other than the JSON array in your response.
|
||||
### Memory
|
||||
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
|
||||
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
|
||||
<histories>
|
||||
{histories}
|
||||
</histories>
|
||||
@ -66,7 +66,7 @@ User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"
|
||||
Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}}
|
||||
</example>
|
||||
### Memory
|
||||
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
|
||||
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
|
||||
<histories>
|
||||
{histories}
|
||||
</histories>
|
||||
|
||||
@ -0,0 +1,8 @@
|
||||
from .node import VariableAssignerNode
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
__all__ = [
|
||||
"VariableAssignerNode",
|
||||
"VariableAssignerData",
|
||||
"WriteMode",
|
||||
]
|
||||
|
||||
@ -1,4 +0,0 @@
|
||||
class VariableOperatorNodeError(Exception):
|
||||
"""Base error type, don't use directly."""
|
||||
|
||||
pass
|
||||
@ -1,19 +0,0 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables import Variable
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable
|
||||
|
||||
|
||||
def update_conversation_variable(conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableOperatorNodeError("conversation variable not found in the database")
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
2
api/core/workflow/nodes/variable_assigner/exc.py
Normal file
2
api/core/workflow/nodes/variable_assigner/exc.py
Normal file
@ -0,0 +1,2 @@
|
||||
class VariableAssignerNodeError(Exception):
|
||||
pass
|
||||
@ -1,36 +1,40 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode, BaseNodeData
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from models import ConversationVariable
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .exc import VariableAssignerNodeError
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
_node_data_cls: type[BaseNodeData] = VariableAssignerData
|
||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
raise VariableOperatorNodeError("assigned variable not found")
|
||||
raise VariableAssignerNodeError("assigned variable not found")
|
||||
|
||||
match self.node_data.write_mode:
|
||||
case WriteMode.OVER_WRITE:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableOperatorNodeError("input value not found")
|
||||
raise VariableAssignerNodeError("input value not found")
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.value})
|
||||
|
||||
case WriteMode.APPEND:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableOperatorNodeError("input value not found")
|
||||
raise VariableAssignerNodeError("input value not found")
|
||||
updated_value = original_variable.value + [income_value.value]
|
||||
updated_variable = original_variable.model_copy(update={"value": updated_value})
|
||||
|
||||
@ -39,7 +43,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
||||
|
||||
case _:
|
||||
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
|
||||
raise VariableAssignerNodeError(f"unsupported write mode: {self.node_data.write_mode}")
|
||||
|
||||
# Over write the variable.
|
||||
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
|
||||
@ -48,8 +52,8 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
# Update conversation variable.
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
||||
if not conversation_id:
|
||||
raise VariableOperatorNodeError("conversation_id not found")
|
||||
common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
raise VariableAssignerNodeError("conversation_id not found")
|
||||
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@ -59,6 +63,18 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
)
|
||||
|
||||
|
||||
def update_conversation_variable(conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableAssignerNodeError("conversation variable not found in the database")
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_zero_value(t: SegmentType):
|
||||
match t:
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
||||
@ -70,4 +86,4 @@ def get_zero_value(t: SegmentType):
|
||||
case SegmentType.NUMBER:
|
||||
return variable_factory.build_segment(0)
|
||||
case _:
|
||||
raise VariableOperatorNodeError(f"unsupported variable type: {t}")
|
||||
raise VariableAssignerNodeError(f"unsupported variable type: {t}")
|
||||
@ -1,5 +1,6 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
@ -11,6 +12,8 @@ class WriteMode(StrEnum):
|
||||
|
||||
|
||||
class VariableAssignerData(BaseNodeData):
|
||||
title: str = "Variable Assigner"
|
||||
desc: Optional[str] = "Assign a value to a variable"
|
||||
assigned_variable_selector: Sequence[str]
|
||||
write_mode: WriteMode
|
||||
input_variable_selector: Sequence[str]
|
||||
@ -1,3 +0,0 @@
|
||||
from .node import VariableAssignerNode
|
||||
|
||||
__all__ = ["VariableAssignerNode"]
|
||||
@ -1,3 +0,0 @@
|
||||
from .node import VariableAssignerNode
|
||||
|
||||
__all__ = ["VariableAssignerNode"]
|
||||
@ -1,11 +0,0 @@
|
||||
from core.variables import SegmentType
|
||||
|
||||
# Value a variable is reset to by the "clear" operation, keyed by segment type.
# Every type for which "clear" is accepted needs an entry here; ARRAY_FILE is
# included because the operation check allows clearing any variable type, and a
# missing key would surface as a bare KeyError at run time.
EMPTY_VALUE_MAPPING = {
    SegmentType.STRING: "",
    SegmentType.NUMBER: 0,
    SegmentType.OBJECT: {},
    SegmentType.ARRAY_ANY: [],
    SegmentType.ARRAY_STRING: [],
    SegmentType.ARRAY_NUMBER: [],
    SegmentType.ARRAY_OBJECT: [],
    SegmentType.ARRAY_FILE: [],
}
|
||||
@ -1,20 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
from .enums import InputType, Operation
|
||||
|
||||
|
||||
class VariableOperationItem(BaseModel):
    """One operation to apply to one variable."""

    # Selector of the variable the operation is applied to.
    variable_selector: Sequence[str]
    # Whether ``value`` is a constant or refers to another variable.
    input_type: InputType
    # The operation to perform.
    operation: Operation
    # Operand of the operation; optional, defaults to None.
    value: Any | None = None
|
||||
|
||||
|
||||
class VariableAssignerNodeData(BaseNodeData):
    """Node configuration holding a list of variable operations."""

    # Version tag of this node-data schema.
    version: str = "2"
    # Operations declared for this node.
    items: Sequence[VariableOperationItem]
|
||||
@ -1,18 +0,0 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class Operation(StrEnum):
|
||||
OVER_WRITE = "over-write"
|
||||
CLEAR = "clear"
|
||||
APPEND = "append"
|
||||
EXTEND = "extend"
|
||||
SET = "set"
|
||||
ADD = "+="
|
||||
SUBTRACT = "-="
|
||||
MULTIPLY = "*="
|
||||
DIVIDE = "/="
|
||||
|
||||
|
||||
class InputType(StrEnum):
|
||||
VARIABLE = "variable"
|
||||
CONSTANT = "constant"
|
||||
@ -1,31 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
|
||||
from .enums import InputType, Operation
|
||||
|
||||
|
||||
class OperationNotSupportedError(VariableOperatorNodeError):
    """Raised when an operation cannot be applied to a variable type."""

    def __init__(self, *, operation: Operation, variable_type: str):
        message = f"Operation {operation} is not supported for type {variable_type}"
        super().__init__(message)
|
||||
|
||||
|
||||
class InputTypeNotSupportedError(VariableOperatorNodeError):
    """Raised when an input type cannot be used with an operation."""

    def __init__(self, *, input_type: InputType, operation: Operation):
        message = f"Input type {input_type} is not supported for operation {operation}"
        super().__init__(message)
|
||||
|
||||
|
||||
class VariableNotFoundError(VariableOperatorNodeError):
    """Raised when a selector does not resolve to a variable."""

    def __init__(self, *, variable_selector: Sequence[str]):
        message = f"Variable {variable_selector} not found"
        super().__init__(message)
|
||||
|
||||
|
||||
class InvalidInputValueError(VariableOperatorNodeError):
    """Raised when an operand is not acceptable for the requested operation."""

    def __init__(self, *, value: Any):
        message = f"Invalid input value {value}"
        super().__init__(message)
|
||||
|
||||
|
||||
class ConversationIDNotFoundError(VariableOperatorNodeError):
    """Raised when no conversation id is available."""

    def __init__(self):
        message = "conversation_id not found"
        super().__init__(message)
|
||||
@ -1,91 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from core.variables import SegmentType
|
||||
|
||||
from .enums import Operation
|
||||
|
||||
|
||||
def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
    """Return whether ``operation`` may be applied to a variable of ``variable_type``."""
    # Overwriting and clearing are accepted for every type.
    if operation in {Operation.OVER_WRITE, Operation.CLEAR}:
        return True

    if operation == Operation.SET:
        settable_types = {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER}
        return variable_type in settable_types

    # Arithmetic only makes sense on numbers.
    if operation in {Operation.ADD, Operation.SUBTRACT, Operation.MULTIPLY, Operation.DIVIDE}:
        return variable_type == SegmentType.NUMBER

    # Appending / extending only makes sense on arrays.
    if operation in {Operation.APPEND, Operation.EXTEND}:
        array_types = {
            SegmentType.ARRAY_ANY,
            SegmentType.ARRAY_OBJECT,
            SegmentType.ARRAY_STRING,
            SegmentType.ARRAY_NUMBER,
            SegmentType.ARRAY_FILE,
        }
        return variable_type in array_types

    return False
|
||||
|
||||
|
||||
def is_variable_input_supported(*, operation: Operation):
    """Return whether ``operation`` accepts another variable as its operand.

    SET and the four arithmetic operations do not; every other operation does.
    """
    constant_only_operations = {
        Operation.SET,
        Operation.ADD,
        Operation.SUBTRACT,
        Operation.MULTIPLY,
        Operation.DIVIDE,
    }
    return operation not in constant_only_operations
|
||||
|
||||
|
||||
def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation):
    """Return whether ``operation`` on ``variable_type`` accepts a constant operand."""
    if variable_type in {SegmentType.STRING, SegmentType.OBJECT}:
        return operation in {Operation.OVER_WRITE, Operation.SET}

    if variable_type == SegmentType.NUMBER:
        numeric_operations = {
            Operation.OVER_WRITE,
            Operation.SET,
            Operation.ADD,
            Operation.SUBTRACT,
            Operation.MULTIPLY,
            Operation.DIVIDE,
        }
        return operation in numeric_operations

    # Any other variable type takes no constant operand.
    return False
|
||||
|
||||
|
||||
def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any):
    """Return whether ``value`` is an acceptable operand for the operation.

    CLEAR needs no operand and is always valid. Scalars and objects are checked
    by Python type (a zero divisor is rejected for DIVIDE). For arrays, APPEND
    expects a single element of the array's element type, while EXTEND and
    OVER_WRITE expect a list whose items all have that element type. Any other
    combination is invalid.
    """
    if operation == Operation.CLEAR:
        return True

    if variable_type == SegmentType.STRING:
        return isinstance(value, str)

    if variable_type == SegmentType.NUMBER:
        is_number = isinstance(value, (int, float))
        divides_by_zero = operation == Operation.DIVIDE and value == 0
        return is_number and not divides_by_zero

    if variable_type == SegmentType.OBJECT:
        return isinstance(value, dict)

    # Accepted Python types for a single element of each array type.
    element_types = {
        SegmentType.ARRAY_ANY: (str, float, int, dict),
        SegmentType.ARRAY_STRING: (str,),
        SegmentType.ARRAY_NUMBER: (int, float),
        SegmentType.ARRAY_OBJECT: (dict,),
    }
    if variable_type not in element_types:
        return False
    accepted = element_types[variable_type]

    if operation == Operation.APPEND:
        return isinstance(value, accepted)

    if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
        return isinstance(value, list) and all(isinstance(element, accepted) for element in value)

    return False
|
||||
@ -1,159 +0,0 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from . import helpers
|
||||
from .constants import EMPTY_VALUE_MAPPING
|
||||
from .entities import VariableAssignerNodeData
|
||||
from .enums import InputType, Operation
|
||||
from .exc import (
|
||||
ConversationIDNotFoundError,
|
||||
InputTypeNotSupportedError,
|
||||
InvalidInputValueError,
|
||||
OperationNotSupportedError,
|
||||
VariableNotFoundError,
|
||||
)
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
||||
_node_data_cls = VariableAssignerNodeData
|
||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
inputs = self.node_data.model_dump()
|
||||
process_data = {}
|
||||
# NOTE: This node has no outputs
|
||||
updated_variables: list[Variable] = []
|
||||
|
||||
try:
|
||||
for item in self.node_data.items:
|
||||
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
|
||||
|
||||
# ==================== Validation Part
|
||||
|
||||
# Check if variable exists
|
||||
if not isinstance(variable, Variable):
|
||||
raise VariableNotFoundError(variable_selector=item.variable_selector)
|
||||
|
||||
# Check if operation is supported
|
||||
if not helpers.is_operation_supported(variable_type=variable.value_type, operation=item.operation):
|
||||
raise OperationNotSupportedError(operation=item.operation, variable_type=variable.value_type)
|
||||
|
||||
# Check if variable input is supported
|
||||
if item.input_type == InputType.VARIABLE and not helpers.is_variable_input_supported(
|
||||
operation=item.operation
|
||||
):
|
||||
raise InputTypeNotSupportedError(input_type=InputType.VARIABLE, operation=item.operation)
|
||||
|
||||
# Check if constant input is supported
|
||||
if item.input_type == InputType.CONSTANT and not helpers.is_constant_input_supported(
|
||||
variable_type=variable.value_type, operation=item.operation
|
||||
):
|
||||
raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation)
|
||||
|
||||
# Get value from variable pool
|
||||
if (
|
||||
item.input_type == InputType.VARIABLE
|
||||
and item.operation != Operation.CLEAR
|
||||
and item.value is not None
|
||||
):
|
||||
value = self.graph_runtime_state.variable_pool.get(item.value)
|
||||
if value is None:
|
||||
raise VariableNotFoundError(variable_selector=item.value)
|
||||
# Skip if value is NoneSegment
|
||||
if value.value_type == SegmentType.NONE:
|
||||
continue
|
||||
item.value = value.value
|
||||
|
||||
# If set string / bytes / bytearray to object, try convert string to object.
|
||||
if (
|
||||
item.operation == Operation.SET
|
||||
and variable.value_type == SegmentType.OBJECT
|
||||
and isinstance(item.value, str | bytes | bytearray)
|
||||
):
|
||||
try:
|
||||
item.value = json.loads(item.value)
|
||||
except json.JSONDecodeError:
|
||||
raise InvalidInputValueError(value=item.value)
|
||||
|
||||
# Check if input value is valid
|
||||
if not helpers.is_input_value_valid(
|
||||
variable_type=variable.value_type, operation=item.operation, value=item.value
|
||||
):
|
||||
raise InvalidInputValueError(value=item.value)
|
||||
|
||||
# ==================== Execution Part
|
||||
|
||||
updated_value = self._handle_item(
|
||||
variable=variable,
|
||||
operation=item.operation,
|
||||
value=item.value,
|
||||
)
|
||||
variable = variable.model_copy(update={"value": updated_value})
|
||||
updated_variables.append(variable)
|
||||
except VariableOperatorNodeError as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
# Update variables
|
||||
for variable in updated_variables:
|
||||
self.graph_runtime_state.variable_pool.add(variable.selector, variable)
|
||||
process_data[variable.name] = variable.value
|
||||
|
||||
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
||||
if not conversation_id:
|
||||
raise ConversationIDNotFoundError
|
||||
else:
|
||||
conversation_id = conversation_id.value
|
||||
common_helpers.update_conversation_variable(
|
||||
conversation_id=conversation_id,
|
||||
variable=variable,
|
||||
)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
)
|
||||
|
||||
def _handle_item(
|
||||
self,
|
||||
*,
|
||||
variable: Variable,
|
||||
operation: Operation,
|
||||
value: Any,
|
||||
):
|
||||
match operation:
|
||||
case Operation.OVER_WRITE:
|
||||
return value
|
||||
case Operation.CLEAR:
|
||||
return EMPTY_VALUE_MAPPING[variable.value_type]
|
||||
case Operation.APPEND:
|
||||
return variable.value + [value]
|
||||
case Operation.EXTEND:
|
||||
return variable.value + value
|
||||
case Operation.SET:
|
||||
return value
|
||||
case Operation.ADD:
|
||||
return variable.value + value
|
||||
case Operation.SUBTRACT:
|
||||
return variable.value - value
|
||||
case Operation.MULTIPLY:
|
||||
return variable.value * value
|
||||
case Operation.DIVIDE:
|
||||
return variable.value / value
|
||||
case _:
|
||||
raise OperationNotSupportedError(operation=operation, variable_type=variable.value_type)
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||
@ -19,7 +19,7 @@ from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.event import NodeEvent
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||
from factories import file_factory
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import (
|
||||
@ -145,8 +145,11 @@ class WorkflowEntry:
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(node_config.get("data", {}).get("type"))
|
||||
node_version = node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
node_cls = node_type_classes_mapping.get(node_type)
|
||||
node_cls = cast(type[BaseNode], node_cls)
|
||||
|
||||
if not node_cls:
|
||||
raise ValueError(f"Node class not found for node type {node_type}")
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(environment_variables=workflow.environment_variables)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user