Compare commits

..

33 Commits

Author SHA1 Message Date
ac0d99281e add migration 2024-12-12 09:48:25 +08:00
bbdadec1bc add download file method 2024-12-05 10:10:35 +09:00
fa9709faa8 fork for fta 2024-12-05 10:10:35 +09:00
eca466bdaa chore: fix typo (#11359) 2024-12-05 09:04:30 +08:00
d56abec195 Revert "Fix: iteration not in main thread pool" (#11358) 2024-12-04 21:22:22 +08:00
961e25f608 fix: better bedrock message handler close #10976 (#11317)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-04 19:46:40 +08:00
138bf698b0 chore: translate i18n files (#11353)
Co-authored-by: douxc <7553076+douxc@users.noreply.github.com>
2024-12-04 19:24:03 +08:00
e5bb4cca12 fix: Correct category of 'Workflow' used in Explore Apps. (#11351) 2024-12-04 18:19:12 +08:00
5e2cb0e3a8 feat: add base skeleton component (#11339) 2024-12-04 17:34:55 +08:00
16a65cb367 fix: cannot send message when debug with multiple model with conversa… (#11333) 2024-12-04 16:17:11 +08:00
1bae9b8ff7 update pricing for bedrock nova LLM models (#11336)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
2024-12-04 16:16:41 +08:00
d7c1f43b49 fix tidb full-text-search vector missed (#11337) 2024-12-04 16:13:23 +08:00
f933af9f57 fix: check valid for number variable (#11334) 2024-12-04 15:46:54 +08:00
91e1ff5e30 chore: improve zhipu LLM (#11321) 2024-12-04 15:14:30 +08:00
5908e10549 integrate amazon nove llms to dify (#11324)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
2024-12-04 15:13:08 +08:00
464e6354c5 feat: correct the prompt grammar. (#11328)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-04 15:12:47 +08:00
d470e55f8c fix: http node download file always image type (#11319) 2024-12-04 12:15:26 +08:00
98a1b01b0c fix: file download in chat (#11322) 2024-12-04 11:10:56 +08:00
e240424be5 fix: number variable can not input constant type value in tool config form (#11320) 2024-12-04 10:46:03 +08:00
1cb5a12abb fix: resolve scrolling issue in workflow-log table (#11302) 2024-12-03 21:29:42 +08:00
ff2a4a6fcd Fix: model params in logs (#11298) 2024-12-03 21:17:55 +08:00
c58d2fce89 roll back rerank topn setting (#11297) 2024-12-03 17:34:56 +08:00
7a962b9f03 chore: bump version to 0.13.0 (#11284)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-03 16:01:12 +08:00
a679079a1d fix: auto translate fail (#11286) 2024-12-03 14:21:59 +08:00
e39e776d03 fix: better wenxin rerank handler, close #11252 (#11283)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-03 13:57:16 +08:00
e135ffc2c1 Feat: upgrade variable assigner (#11285)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-12-03 13:56:40 +08:00
e79eac688a chore(lint): sort __all__ definitions (#11243) 2024-12-03 13:26:33 +08:00
643a90c48d fix: use removeprefix() instead of lstrip() to remove the data: prefix (#11272)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-03 09:16:25 +08:00
2a448a899d Fix: iteration not in main thread pool (#11271)
Co-authored-by: Novice Lee <novicelee@NovicedeMacBook-Pro.local>
2024-12-03 09:16:03 +08:00
7b86f8f024 fix: double split error on redis port and some type hint (#11270)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-03 09:15:51 +08:00
e686f12317 fix: better handle error (#11265)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-03 09:15:38 +08:00
a86f1eca79 docs: add api docs for /v1/info (#11269) 2024-12-03 09:14:13 +08:00
668c1c0792 chore(deps): bump cross-spawn from 7.0.3 to 7.0.6 in /web (#11262)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-12-02 17:30:52 +08:00
187 changed files with 4143 additions and 1029 deletions

View File

@ -413,4 +413,3 @@ RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
CREATE_TIDB_SERVICE_JOB_ENABLED=false
RETRIEVAL_TOP_N=0

View File

@ -20,6 +20,8 @@ select = [
"PLC0208", # iteration-over-set
"PLC2801", # unnecessary-dunder-call
"PLC0414", # useless-import-alias
"PLE0604", # invalid-all-object
"PLE0605", # invalid-all-format
"PLR0402", # manual-from-import
"PLR1711", # useless-return
"PLR1714", # repeated-equality-comparison
@ -28,6 +30,7 @@ select = [
"RUF100", # unused-noqa
"RUF101", # redirected-noqa
"RUF200", # invalid-pyproject-toml
"RUF022", # unsorted-dunder-all
"S506", # unsafe-yaml-load
"SIM", # flake8-simplify rules
"TRY400", # error-instead-of-exception

View File

@ -626,8 +626,6 @@ class DataSetConfig(BaseSettings):
default=30,
)
RETRIEVAL_TOP_N: int = Field(description="number of retrieval top_n", default=0)
class WorkspaceConfig(BaseSettings):
"""

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.12.1",
default="0.13.0",
)
COMMIT_SHA: str = Field(

View File

@ -62,6 +62,7 @@ from .datasets import (
external,
hit_testing,
website,
fta_test,
)
# Import explore controllers

View File

@ -100,11 +100,11 @@ class DraftWorkflowApi(Resource):
try:
environment_variables_list = args.get("environment_variables") or []
environment_variables = [
variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = args.get("conversation_variables") or []
conversation_variables = [
variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
workflow = workflow_service.sync_draft_workflow(
app_model=app_model,
@ -382,7 +382,7 @@ class DefaultBlockConfigApi(Resource):
filters = None
if args.get("q"):
try:
filters = json.loads(args.get("q"))
filters = json.loads(args.get("q", ""))
except json.JSONDecodeError:
raise ValueError("Invalid filters")

View File

@ -0,0 +1,145 @@
import json
import requests
from flask import Response
from flask_restful import Resource, reqparse
from sqlalchemy import text
from controllers.console import api
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.fta import ComponentFailure, ComponentFailureStats
class FATTestApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("log_process_data", nullable=False, required=True, type=str, location="args")
args = parser.parse_args()
print(args["log_process_data"])
# Extract the JSON string from the text field
json_str = args["log_process_data"].strip("```json\\n").strip("```").strip().replace("\\n", "")
log_data = json.loads(json_str)
db.session.query(ComponentFailure).delete()
for data in log_data:
if not isinstance(data, dict):
raise TypeError("Data must be a dictionary.")
required_keys = {"Date", "Component", "FailureMode", "Cause", "RepairAction", "Technician"}
if not required_keys.issubset(data.keys()):
raise ValueError(f"Data dictionary must contain the following keys: {required_keys}")
try:
# build and insert the failure record
component_failure = ComponentFailure(
Date=data["Date"],
Component=data["Component"],
FailureMode=data["FailureMode"],
Cause=data["Cause"],
RepairAction=data["RepairAction"],
Technician=data["Technician"],
)
db.session.add(component_failure)
db.session.commit()
except Exception as e:
print(e)
# Clear existing stats
db.session.query(ComponentFailureStats).delete()
# Insert calculated statistics
try:
db.session.execute(
text("""
INSERT INTO component_failure_stats ("Component", "FailureMode", "Cause", "PossibleAction", "Probability", "MTBF")
SELECT
cf."Component",
cf."FailureMode",
cf."Cause",
cf."RepairAction" as "PossibleAction",
COUNT(*) * 1.0 / (SELECT COUNT(*) FROM component_failure WHERE "Component" = cf."Component") AS "Probability",
COALESCE(AVG(EXTRACT(EPOCH FROM (next_failure_date::timestamp - cf."Date"::timestamp)) / 86400.0), 0) AS "MTBF"
FROM (
SELECT
"Component",
"FailureMode",
"Cause",
"RepairAction",
"Date",
LEAD("Date") OVER (PARTITION BY "Component", "FailureMode", "Cause" ORDER BY "Date") AS next_failure_date
FROM
component_failure
) cf
GROUP BY
cf."Component", cf."FailureMode", cf."Cause", cf."RepairAction";
""")
)
db.session.commit()
except Exception as e:
db.session.rollback()
print(f"Error during stats calculation: {e}")
# output format
# [
# (17, 'Hydraulic system', 'Leak', 'Hose rupture', 'Replaced hydraulic hose', 0.3333333333333333, None),
# (18, 'Hydraulic system', 'Leak', 'Seal Wear', 'Replaced the faulty seal', 0.3333333333333333, None),
# (19, 'Hydraulic system', 'Pressure drop', 'Fluid leak', 'Replaced hydraulic fluid and seals', 0.3333333333333333, None)
# ]
component_failure_stats = db.session.query(ComponentFailureStats).all()
# Convert stats to list of tuples format
stats_list = []
for stat in component_failure_stats:
stats_list.append(
(
stat.StatID,
stat.Component,
stat.FailureMode,
stat.Cause,
stat.PossibleAction,
stat.Probability,
stat.MTBF,
)
)
return {"data": stats_list}, 200
# generate-fault-tree
class GenerateFaultTreeApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("llm_text", nullable=False, required=True, type=str, location="args")
args = parser.parse_args()
entities = args["llm_text"].replace("```", "").replace("\\n", "\n")
print(entities)
request_data = {"fault_tree_text": entities}
url = "https://fta.cognitech-dev.live/generate-fault-tree"
headers = {"accept": "application/json", "Content-Type": "application/json"}
response = requests.post(url, json=request_data, headers=headers)
print(response.json())
return {"data": response.json()}, 200
class ExtractSVGApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("svg_text", nullable=False, required=True, type=str, location="args")
args = parser.parse_args()
# svg_text = ''.join(args["svg_text"].splitlines())
svg_text = args["svg_text"].replace("\n", "")
svg_text = svg_text.replace('\\"', '"')  # unescape quotes escaped inside the JSON payload
print(svg_text)
svg_text_json = json.loads(svg_text)
svg_content = svg_text_json.get("data").get("svg_content")[0]
svg_content = svg_content.replace("\n", "").replace('\\"', '"')
file_key = "fta_svg/" + "fat.svg"
if storage.exists(file_key):
storage.delete(file_key)
storage.save(file_key, svg_content.encode("utf-8"))
generator = storage.load(file_key, stream=True)
return Response(generator, mimetype="image/svg+xml")
api.add_resource(FATTestApi, "/fta/db-handler")
api.add_resource(GenerateFaultTreeApi, "/fta/generate-fault-tree")
api.add_resource(ExtractSVGApi, "/fta/extract-svg")
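A minimal Python sketch, on hypothetical data, of what the stats query above computes (illustrative only, not part of the diff): per (Component, FailureMode, Cause, RepairAction) group, Probability is the group's share of that component's failures, and MTBF is the average number of days between consecutive failures in the group — the LEAD window function paired with COALESCE(..., 0) for single-failure groups.

from datetime import date
from itertools import groupby

# Hypothetical failure log mirroring the component_failure table.
rows = [
    ("Hydraulic system", "Leak", "Hose rupture", "Replaced hydraulic hose", date(2024, 1, 1)),
    ("Hydraulic system", "Leak", "Hose rupture", "Replaced hydraulic hose", date(2024, 1, 11)),
    ("Hydraulic system", "Pressure drop", "Fluid leak", "Replaced hydraulic fluid and seals", date(2024, 2, 1)),
]

totals = {}  # failures per component, the denominator of Probability
for comp, *_ in rows:
    totals[comp] = totals.get(comp, 0) + 1

rows.sort(key=lambda r: (r[:4], r[4]))
for key, group in groupby(rows, key=lambda r: r[:4]):
    dates = [r[4] for r in group]
    probability = len(dates) / totals[key[0]]
    gaps = [(b - a).days for a, b in zip(dates, dates[1:])]  # LEAD("Date") equivalent
    mtbf = sum(gaps) / len(gaps) if gaps else 0  # COALESCE(AVG(...), 0)
    print(key, round(probability, 3), mtbf)
# ('Hydraulic system', 'Leak', 'Hose rupture', 'Replaced hydraulic hose') 0.667 10.0
# ('Hydraulic system', 'Pressure drop', 'Fluid leak', 'Replaced hydraulic fluid and seals') 0.333 0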

View File

@ -43,7 +43,7 @@ from core.workflow.graph_engine.entities.event import (
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
@ -138,7 +138,8 @@ class WorkflowBasedAppRunner(AppRunner):
# Get node class
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping[node_type]
node_version = iteration_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool
variable_pool = VariablePool(

View File

@ -7,13 +7,13 @@ from .models import (
)
__all__ = [
"FILE_MODEL_IDENTITY",
"ArrayFileAttribute",
"File",
"FileAttribute",
"FileBelongsTo",
"FileTransferMethod",
"FileType",
"FileUploadConfig",
"FileTransferMethod",
"FileBelongsTo",
"File",
"ImageConfig",
"FileAttribute",
"ArrayFileAttribute",
"FILE_MODEL_IDENTITY",
]

View File

@ -1,4 +1,6 @@
import base64
import tempfile
from pathlib import Path
from configs import dify_config
from core.file import file_repository
@ -18,6 +20,38 @@ from .models import File, FileTransferMethod, FileType
from .tool_file_parser import ToolFileParser
def download_to_target_path(f: File, temp_dir: str, /):
if f.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
suffix = Path(tool_file.file_key).suffix
target_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
_download_file_to_target_path(tool_file.file_key, target_path)
return target_path
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
suffix = Path(upload_file.key).suffix
target_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
_download_file_to_target_path(upload_file.key, target_path)
return target_path
else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
def _download_file_to_target_path(path: str, target_path: str, /):
"""
Download a file from storage and write it to the given target path.
Args:
path (str): The path to the file in storage.
target_path (str): The local path to write the downloaded file to.
"""
storage.download(path, target_path)
def get_attr(*, file: File, attr: FileAttribute):
match attr:
case FileAttribute.TYPE:
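A minimal usage sketch for the new download_to_target_path helper (the `file` value here is hypothetical; the file_extractor tool later in this diff uses the same pattern). Note that next(tempfile._get_candidate_names()) above relies on a private CPython helper for random file names — an implementation detail, not a public API.

import tempfile
from core.file.file_manager import download_to_target_path

# `file` is assumed to be a core.file.models.File whose transfer method is
# TOOL_FILE or LOCAL_FILE; any other method raises ValueError.
with tempfile.TemporaryDirectory() as temp_dir:
    path = download_to_target_path(file, temp_dir)  # e.g. <temp_dir>/<random>.pdf
    with open(path, "rb") as fp:
        data = fp.read()
# The temporary directory and the downloaded file are cleaned up on exit.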

View File

@ -18,25 +18,25 @@ from .message_entities import (
from .model_entities import ModelPropertyKey
__all__ = [
"ImagePromptMessageContent",
"VideoPromptMessageContent",
"PromptMessage",
"PromptMessageRole",
"LLMUsage",
"ModelPropertyKey",
"AssistantPromptMessage",
"PromptMessage",
"PromptMessageContent",
"PromptMessageRole",
"SystemPromptMessage",
"TextPromptMessageContent",
"UserPromptMessage",
"PromptMessageTool",
"ToolPromptMessage",
"PromptMessageContentType",
"AudioPromptMessageContent",
"DocumentPromptMessageContent",
"ImagePromptMessageContent",
"LLMResult",
"LLMResultChunk",
"LLMResultChunkDelta",
"AudioPromptMessageContent",
"DocumentPromptMessageContent",
"LLMUsage",
"ModelPropertyKey",
"PromptMessage",
"PromptMessage",
"PromptMessageContent",
"PromptMessageContentType",
"PromptMessageRole",
"PromptMessageRole",
"PromptMessageTool",
"SystemPromptMessage",
"TextPromptMessageContent",
"ToolPromptMessage",
"UserPromptMessage",
"VideoPromptMessageContent",
]

View File

@ -0,0 +1,52 @@
model: amazon.nova-lite-v1:0
label:
en_US: Nova Lite V1
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 300000
parameter_rules:
- name: max_new_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 5000
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中，Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布，并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p，但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip: AWS docs have an error; the max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.00006'
output: '0.00024'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,52 @@
model: amazon.nova-micro-v1:0
label:
en_US: Nova Micro V1
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: max_new_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 5000
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中，Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布，并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p，但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip: AWS docs have an error; the max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.000035'
output: '0.00014'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,52 @@
model: amazon.nova-pro-v1:0
label:
en_US: Nova Pro V1
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 300000
parameter_rules:
- name: max_new_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 5000
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中，Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布，并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p，但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip: AWS docs have an error; the max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.0008'
output: '0.0032'
unit: '0.001'
currency: USD

View File

@ -70,6 +70,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
{"prefix": "cohere.command-r", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "amazon.titan", "support_system_prompts": False, "support_tool_use": False},
{"prefix": "ai21.jamba-1-5", "support_system_prompts": True, "support_tool_use": False},
{"prefix": "amazon.nova", "support_system_prompts": True, "support_tool_use": False},
{"prefix": "us.amazon.nova", "support_system_prompts": True, "support_tool_use": False},
]
@staticmethod
@ -194,6 +196,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
if model_info["support_tool_use"] and tools:
parameters["toolConfig"] = self._convert_converse_tool_config(tools=tools)
try:
# for issue #10976
conversations_list = parameters["messages"]
# if two consecutive messages have the same role, combine them into one message
for i in range(len(conversations_list) - 2, -1, -1):
if conversations_list[i]["role"] == conversations_list[i + 1]["role"]:
conversations_list[i]["content"].extend(conversations_list.pop(i + 1)["content"])
if stream:
response = bedrock_client.converse_stream(**parameters)
return self._handle_converse_stream_response(
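A small illustration of the merge loop above, on a hypothetical message list. Iterating from the second-to-last index down to 0 means popping index i + 1 never skips an element, and consecutive same-role messages are folded into one, which the Converse API requires.

conversations_list = [
    {"role": "user", "content": [{"text": "hello"}]},
    {"role": "user", "content": [{"text": "are you there?"}]},
    {"role": "assistant", "content": [{"text": "yes"}]},
]
for i in range(len(conversations_list) - 2, -1, -1):
    if conversations_list[i]["role"] == conversations_list[i + 1]["role"]:
        conversations_list[i]["content"].extend(conversations_list.pop(i + 1)["content"])
assert conversations_list == [
    {"role": "user", "content": [{"text": "hello"}, {"text": "are you there?"}]},
    {"role": "assistant", "content": [{"text": "yes"}]},
]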

View File

@ -0,0 +1,52 @@
model: us.amazon.nova-lite-v1:0
label:
en_US: Nova Lite V1 (US.Cross Region Inference)
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 300000
parameter_rules:
- name: max_new_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 5000
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中，Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布，并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p，但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip: AWS docs have an error; the max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.00006'
output: '0.00024'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,52 @@
model: us.amazon.nova-micro-v1:0
label:
en_US: Nova Micro V1 (US.Cross Region Inference)
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: max_new_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 5000
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中，Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布，并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p，但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip: AWS docs have an error; the max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.000035'
output: '0.00014'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,52 @@
model: us.amazon.nova-pro-v1:0
label:
en_US: Nova Pro V1 (US.Cross Region Inference)
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 300000
parameter_rules:
- name: max_new_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 5000
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中，Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布，并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p，但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip: AWS docs have an error; the max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.0008'
output: '0.0032'
unit: '0.001'
currency: USD

View File

@ -252,7 +252,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
# ignore sse comments
if chunk.startswith(":"):
continue
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
decoded_chunk = chunk.strip().removeprefix("data: ")
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)
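The change matters because str.lstrip treats its argument as a set of characters rather than a prefix, so it can eat into the payload; str.removeprefix (Python 3.9+) strips exactly one occurrence. A quick illustration:

assert "data: dataset".lstrip("data: ") == "set"            # over-strips: 'd', 'a', 't' are in the character set
assert "data: data: [DONE]".lstrip("data: ") == "[DONE]"    # also strips the second "data: "
assert "data: dataset".removeprefix("data: ") == "dataset"  # removes exactly the prefix, once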

View File

@ -462,7 +462,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
# ignore sse comments
if chunk.startswith(":"):
continue
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
decoded_chunk = chunk.strip().removeprefix("data: ")
if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]"
continue

View File

@ -250,7 +250,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel):
# ignore sse comments
if chunk.startswith(":"):
continue
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
decoded_chunk = chunk.strip().removeprefix("data: ")
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)

View File

@ -1,4 +1,4 @@
from .common import ChatRole
from .maas import MaasError, MaasService
__all__ = ["MaasService", "ChatRole", "MaasError"]
__all__ = ["ChatRole", "MaasError", "MaasService"]

View File

@ -17,7 +17,13 @@ class WenxinRerank(_CommonWenxin):
def rerank(self, model: str, query: str, docs: list[str], top_n: Optional[int] = None):
access_token = self._get_access_token()
url = f"{self.api_bases[model]}?access_token={access_token}"
# For issue #11252
# for the wenxin rerank model, top_n must be less than or equal to the number of docs
if top_n is not None and top_n > len(docs):
top_n = len(docs)
# for wenxin Rerank model, query should not be an empty string
if query == "":
query = " " # FIXME: this is a workaround for wenxin rerank model for better user experience.
try:
response = httpx.post(
url,
@ -25,7 +31,11 @@ class WenxinRerank(_CommonWenxin):
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return response.json()
data = response.json()
# wenxin error handling
if "error_code" in data:
raise InternalServerError(data["error_msg"])
return data
except httpx.HTTPStatusError as e:
raise InternalServerError(str(e))
@ -69,6 +79,9 @@ class WenxinRerankModel(RerankModel):
results = wenxin_rerank.rerank(model, query, docs, top_n)
rerank_documents = []
if "results" not in results:
raise ValueError("results key not found in response")
for result in results["results"]:
index = result["index"]
if "document" in result:

View File

@ -8,6 +8,7 @@ features:
- stream-tool-call
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature

View File

@ -8,6 +8,7 @@ features:
- stream-tool-call
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature

View File

@ -8,6 +8,7 @@ features:
- stream-tool-call
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature

View File

@ -8,6 +8,7 @@ features:
- stream-tool-call
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature

View File

@ -8,6 +8,7 @@ features:
- stream-tool-call
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature

View File

@ -8,6 +8,7 @@ features:
- stream-tool-call
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature

View File

@ -8,6 +8,7 @@ features:
- stream-tool-call
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature

View File

@ -8,7 +8,7 @@ features:
- stream-tool-call
model_properties:
mode: chat
context_size: 10240
context_size: 1048576
parameter_rules:
- name: temperature
use_template: temperature

View File

@ -8,6 +8,7 @@ features:
- stream-tool-call
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature

View File

@ -4,6 +4,7 @@ label:
model_type: llm
model_properties:
mode: chat
context_size: 2048
features:
- vision
parameter_rules:

View File

@ -4,6 +4,7 @@ label:
model_type: llm
model_properties:
mode: chat
context_size: 8192
features:
- vision
- video

View File

@ -22,18 +22,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
from core.model_runtime.utils import helper
GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.
The structure of the JSON object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
And you should always end the block with a "```" to indicate the end of the JSON object.
<instructions>
{{instructions}}
</instructions>
```JSON""" # noqa: E501
class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
def _invoke(
@ -64,42 +52,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
credentials_kwargs = self._to_credential_kwargs(credentials)
# invoke model
# stop = stop or []
# self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)
# def _transform_json_prompts(self, model: str, credentials: dict,
# prompt_messages: list[PromptMessage], model_parameters: dict,
# tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
# stream: bool = True, user: str | None = None) \
# -> None:
# """
# Transform json prompts to model prompts
# """
# if "}\n\n" not in stop:
# stop.append("}\n\n")
# # check if there is a system message
# if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# # override the system message
# prompt_messages[0] = SystemPromptMessage(
# content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content)
# )
# else:
# # insert the system message
# prompt_messages.insert(0, SystemPromptMessage(
# content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", "Please output a valid JSON object.")
# ))
# # check if the last message is a user message
# if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
# # add ```JSON\n to the last message
# prompt_messages[-1].content += "\n```JSON\n"
# else:
# # append a user message
# prompt_messages.append(UserPromptMessage(
# content="```JSON\n"
# ))
def get_num_tokens(
self,
model: str,
@ -170,7 +124,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
extra_model_kwargs = {}
# request to glm-4v-plus with stop words will always response "finish_reason":"network_error"
# request to glm-4v-plus with stop words will always respond "finish_reason":"network_error"
if stop and model != "glm-4v-plus":
extra_model_kwargs["stop"] = stop
@ -186,7 +140,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
# resolve zhipuai model not support system message and user message, assistant message must be in sequence
new_prompt_messages: list[PromptMessage] = []
for prompt_message in prompt_messages:
copy_prompt_message = prompt_message.copy()
copy_prompt_message = prompt_message.model_copy()
if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}:
if isinstance(copy_prompt_message.content, list):
# check if model is 'glm-4v'
@ -238,59 +192,38 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
else:
params = {"model": model, "messages": [], **model_parameters}
# glm model
if not model.startswith("chatglm"):
for prompt_message in new_prompt_messages:
if prompt_message.role == PromptMessageRole.TOOL:
for prompt_message in new_prompt_messages:
if prompt_message.role == PromptMessageRole.TOOL:
params["messages"].append(
{
"role": "tool",
"content": prompt_message.content,
"tool_call_id": prompt_message.tool_call_id,
}
)
elif isinstance(prompt_message, AssistantPromptMessage):
if prompt_message.tool_calls:
params["messages"].append(
{
"role": "tool",
"role": "assistant",
"content": prompt_message.content,
"tool_call_id": prompt_message.tool_call_id,
"tool_calls": [
{
"id": tool_call.id,
"type": tool_call.type,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
}
for tool_call in prompt_message.tool_calls
],
}
)
elif isinstance(prompt_message, AssistantPromptMessage):
if prompt_message.tool_calls:
params["messages"].append(
{
"role": "assistant",
"content": prompt_message.content,
"tool_calls": [
{
"id": tool_call.id,
"type": tool_call.type,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
}
for tool_call in prompt_message.tool_calls
],
}
)
else:
params["messages"].append({"role": "assistant", "content": prompt_message.content})
else:
params["messages"].append(
{"role": prompt_message.role.value, "content": prompt_message.content}
)
else:
# chatglm model
for prompt_message in new_prompt_messages:
# merge system message to user message
if prompt_message.role in {
PromptMessageRole.SYSTEM,
PromptMessageRole.TOOL,
PromptMessageRole.USER,
}:
if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user":
params["messages"][-1]["content"] += "\n\n" + prompt_message.content
else:
params["messages"].append({"role": "user", "content": prompt_message.content})
else:
params["messages"].append(
{"role": prompt_message.role.value, "content": prompt_message.content}
)
params["messages"].append({"role": "assistant", "content": prompt_message.content})
else:
params["messages"].append({"role": prompt_message.role.value, "content": prompt_message.content})
if tools and len(tools) > 0:
params["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools]
@ -406,7 +339,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
Handle llm stream response
:param model: model name
:param response: response
:param responses: response
:param prompt_messages: prompt messages
:return: llm response chunk generator result
"""
@ -505,7 +438,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
if tools and len(tools) > 0:
text += "\n\nTools:"
for tool in tools:
text += f"\n{tool.json()}"
text += f"\n{tool.model_dump_json()}"
# trim off the trailing ' ' that might come from the "Assistant: "
return text.rstrip()

View File

@ -5,7 +5,7 @@ BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找
CHAT_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: " # noqa: E501
"text": "{{#pre_prompt#}}\nHere are the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: " # noqa: E501
},
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
},

View File

@ -3,7 +3,6 @@ from typing import Optional
from flask import Flask, current_app
from configs import DifyConfig
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
@ -114,7 +113,7 @@ class RetrievalService:
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=DifyConfig.RETRIEVAL_TOP_N or top_k,
top_n=top_k,
)
return all_documents
@ -186,7 +185,7 @@ class RetrievalService:
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=DifyConfig.RETRIEVAL_TOP_N or len(documents),
top_n=len(documents),
)
)
else:
@ -231,7 +230,7 @@ class RetrievalService:
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=DifyConfig.RETRIEVAL_TOP_N or len(documents),
top_n=len(documents),
)
)
else:

View File

@ -104,8 +104,7 @@ class OceanBaseVector(BaseVector):
val = int(row[6])
vals.append(val)
if len(vals) == 0:
print("ob_vector_memory_limit_percentage not found in parameters.")
exit(1)
raise ValueError("ob_vector_memory_limit_percentage not found in parameters.")
if any(val == 0 for val in vals):
try:
self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30")
@ -200,10 +199,10 @@ class OceanBaseVectorFactory(AbstractVectorFactory):
return OceanBaseVector(
collection_name,
OceanBaseVectorConfig(
host=dify_config.OCEANBASE_VECTOR_HOST,
port=dify_config.OCEANBASE_VECTOR_PORT,
user=dify_config.OCEANBASE_VECTOR_USER,
host=dify_config.OCEANBASE_VECTOR_HOST or "",
port=dify_config.OCEANBASE_VECTOR_PORT or 0,
user=dify_config.OCEANBASE_VECTOR_USER or "",
password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""),
database=dify_config.OCEANBASE_VECTOR_DATABASE,
database=dify_config.OCEANBASE_VECTOR_DATABASE or "",
),
)

View File

@ -375,7 +375,6 @@ class TidbOnQdrantVector(BaseVector):
for result in results:
if result:
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
document.metadata["vector"] = result.vector
documents.append(document)
return documents
@ -394,6 +393,7 @@ class TidbOnQdrantVector(BaseVector):
) -> Document:
return Document(
page_content=scored_point.payload.get(content_payload_key),
vector=scored_point.vector,
metadata=scored_point.payload.get(metadata_payload_key) or {},
)

View File

@ -15,7 +15,7 @@ class ComfyUIProvider(BuiltinToolProviderController):
try:
ws.connect(ws_address)
except Exception:
except Exception as e:
raise ToolProviderCredentialValidationError(f"can not connect to {ws_address}")
finally:
ws.close()

Binary file not shown. (After: 4.3 KiB)

View File

@ -0,0 +1,8 @@
from typing import Any
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class FileExtractorProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
pass

View File

@ -0,0 +1,15 @@
identity:
author: Jyong
name: file_extractor
label:
en_US: File Extractor
zh_Hans: 文件提取
pt_BR: File Extractor
description:
en_US: Extract text from file
zh_Hans: 从文件中提取文本
pt_BR: Extract text from file
icon: icon.png
tags:
- utilities
- productivity

View File

@ -0,0 +1,45 @@
import tempfile
from typing import Any, Union
from core.file.enums import FileType
from core.file.file_manager import download_to_target_path
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.splitter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolParameterValidationError
from core.tools.tool.builtin_tool import BuiltinTool
class FileExtractorTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
# document file for workflow mode
file = tool_parameters.get("text_file")
if file and file.type != FileType.DOCUMENT:
raise ToolParameterValidationError("Not a valid document")
if file:
with tempfile.TemporaryDirectory() as temp_dir:
file_path = download_to_target_path(file, temp_dir)
extractor = TextExtractor(file_path, autodetect_encoding=True)
documents = extractor.extract()
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=tool_parameters.get("max_token", 500),
chunk_overlap=0,
fixed_separator=tool_parameters.get("separator", "\n\n"),
separators=["\n\n", "", ". ", " ", ""],
embedding_model_instance=None,
)
chunks = character_splitter.split_documents(documents)
content = "\n".join([chunk.page_content for chunk in chunks])
return self.create_text_message(content)
else:
raise ToolParameterValidationError("Please provide either file")

View File

@ -0,0 +1,49 @@
identity:
name: text extractor
author: Jyong
label:
en_US: Text extractor
zh_Hans: Text 文本解析
description:
en_US: Extract content from text file and support split to chunks by split characters and token length
zh_Hans: 支持从文本文件中提取内容并支持通过分割字符和令牌长度分割成块
pt_BR: Extract content from text file and support split to chunks by split characters and token length
description:
human:
en_US: Text extractor is a text extract tool
zh_Hans: Text extractor 是一个文本提取工具
pt_BR: Text extractor is a text extract tool
llm: Text extractor is a tool used to extract text file
parameters:
- name: text_file
type: file
label:
en_US: Text file
human_description:
en_US: The text file to be extracted.
zh_Hans: 要提取的 text 文档。
llm_description: you should not input this parameter. just input the image_id.
form: llm
- name: separator
type: string
required: false
label:
en_US: split character
zh_Hans: 分隔符号
human_description:
en_US: Text content split character
zh_Hans: 用于文档分隔的符号
llm_description: it is used for split content to chunks
form: form
- name: max_token
type: number
required: false
label:
en_US: Maximum chunk length
zh_Hans: 最大分段长度
human_description:
en_US: Maximum chunk length
zh_Hans: 最大分段长度
llm_description: it is used for limit chunk's max length
form: form

View File

@ -6,9 +6,9 @@ identity:
zh_Hans: GitLab 合并请求查询
description:
human:
en_US: A tool for query GitLab merge requests, Input should be a exists reposity or branch.
en_US: A tool for query GitLab merge requests, Input should be a exists repository or branch.
zh_Hans: 一个用于查询 GitLab 代码合并请求的工具,输入的内容应该是一个已存在的仓库名或者分支。
llm: A tool for query GitLab merge requests, Input should be a exists reposity or branch.
llm: A tool for query GitLab merge requests, Input should be a exists repository or branch.
parameters:
- name: repository
type: string

View File

@ -32,32 +32,32 @@ from .variables import (
)
__all__ = [
"IntegerVariable",
"FloatVariable",
"ObjectVariable",
"SecretVariable",
"StringVariable",
"ArrayAnyVariable",
"Variable",
"SegmentType",
"SegmentGroup",
"Segment",
"NoneSegment",
"NoneVariable",
"IntegerSegment",
"FloatSegment",
"ObjectSegment",
"ArrayAnySegment",
"StringSegment",
"ArrayStringVariable",
"ArrayAnyVariable",
"ArrayFileSegment",
"ArrayFileVariable",
"ArrayNumberSegment",
"ArrayNumberVariable",
"ArrayObjectSegment",
"ArrayObjectVariable",
"ArraySegment",
"ArrayFileSegment",
"ArrayNumberSegment",
"ArrayObjectSegment",
"ArrayStringSegment",
"ArrayStringVariable",
"FileSegment",
"FileVariable",
"ArrayFileVariable",
"FloatSegment",
"FloatVariable",
"IntegerSegment",
"IntegerVariable",
"NoneSegment",
"NoneVariable",
"ObjectSegment",
"ObjectVariable",
"SecretVariable",
"Segment",
"SegmentGroup",
"SegmentType",
"StringSegment",
"StringVariable",
"Variable",
]

View File

@ -2,16 +2,19 @@ from enum import StrEnum
class SegmentType(StrEnum):
NONE = "none"
NUMBER = "number"
STRING = "string"
OBJECT = "object"
SECRET = "secret"
FILE = "file"
ARRAY_ANY = "array[any]"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = "array[object]"
OBJECT = "object"
FILE = "file"
ARRAY_FILE = "array[file]"
NONE = "none"
GROUP = "group"

View File

@ -2,6 +2,6 @@ from .base_workflow_callback import WorkflowCallback
from .workflow_logging_callback import WorkflowLoggingCallback
__all__ = [
"WorkflowLoggingCallback",
"WorkflowCallback",
"WorkflowLoggingCallback",
]

View File

@ -38,7 +38,7 @@ from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProce
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
@ -227,7 +227,8 @@ class GraphEngine:
# convert to specific node
node_type = NodeType(node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping[node_type]
node_version = node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None

View File

@ -1,4 +1,4 @@
from .answer_node import AnswerNode
from .entities import AnswerStreamGenerateRoute
__all__ = ["AnswerStreamGenerateRoute", "AnswerNode"]
__all__ = ["AnswerNode", "AnswerStreamGenerateRoute"]

View File

@ -153,7 +153,7 @@ class AnswerStreamGeneratorRouter:
NodeType.IF_ELSE,
NodeType.QUESTION_CLASSIFIER,
NodeType.ITERATION,
NodeType.CONVERSATION_VARIABLE_ASSIGNER,
NodeType.VARIABLE_ASSIGNER,
}:
answer_dependencies[answer_node_id].append(source_node_id)
else:

View File

@ -1,4 +1,4 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
from .node import BaseNode
__all__ = ["BaseNode", "BaseNodeData", "BaseIterationNodeData", "BaseIterationState"]
__all__ = ["BaseIterationNodeData", "BaseIterationState", "BaseNode", "BaseNodeData"]

View File

@ -7,6 +7,7 @@ from pydantic import BaseModel
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
version: str = "1"
class BaseIterationNodeData(BaseNodeData):

View File

@ -55,7 +55,9 @@ class BaseNode(Generic[GenericNodeData]):
raise ValueError("Node ID is required.")
self.node_id = node_id
self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {})))
node_data = self._node_data_cls.model_validate(config.get("data", {}))
self.node_data = cast(GenericNodeData, node_data)
@abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
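A short sketch of why model_validate is preferable to unpacking the mapping with ** and cast(): Pydantic validates and coerces the config instead of merely asserting a type to the checker. Illustrative model only, not the real node data classes:

from pydantic import BaseModel, ValidationError

class DemoNodeData(BaseModel):  # hypothetical stand-in for a BaseNodeData subclass
    title: str
    version: str = "1"

assert DemoNodeData.model_validate({"title": "start"}).version == "1"
try:
    DemoNodeData.model_validate({"title": None})  # fails loudly instead of passing a bad value through
except ValidationError as e:
    print(e.error_count())  # 1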

View File

@ -1,4 +1,4 @@
from .end_node import EndNode
from .entities import EndStreamParam
__all__ = ["EndStreamParam", "EndNode"]
__all__ = ["EndNode", "EndStreamParam"]

View File

@ -14,11 +14,11 @@ class NodeType(StrEnum):
HTTP_REQUEST = "http-request"
TOOL = "tool"
VARIABLE_AGGREGATOR = "variable-aggregator"
VARIABLE_ASSIGNER = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor"
CONVERSATION_VARIABLE_ASSIGNER = "assigner"
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"

View File

@ -2,9 +2,9 @@ from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverRes
from .types import NodeEvent
__all__ = [
"ModelInvokeCompletedEvent",
"NodeEvent",
"RunCompletedEvent",
"RunRetrieverResourceEvent",
"RunStreamChunkEvent",
"NodeEvent",
"ModelInvokeCompletedEvent",
]

View File

@ -1,4 +1,4 @@
from .entities import BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeData
from .node import HttpRequestNode
__all__ = ["HttpRequestNodeData", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "BodyData", "HttpRequestNode"]
__all__ = ["BodyData", "HttpRequestNode", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "HttpRequestNodeData"]

View File

@ -1,11 +1,9 @@
import logging
from collections.abc import Mapping, Sequence
from mimetypes import guess_extension
from os import path
from typing import Any
from configs import dify_config
from core.file import File, FileTransferMethod, FileType
from core.file import File, FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
@ -150,11 +148,6 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
content = response.content
if is_file and content_type:
# extract filename from url
filename = path.basename(url)
# extract extension if possible
extension = guess_extension(content_type) or ".bin"
tool_file = ToolFileManager.create_file_by_raw(
user_id=self.user_id,
tenant_id=self.tenant_id,
@ -165,7 +158,6 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
mapping = {
"tool_file_id": tool_file.id,
"type": FileType.IMAGE.value,
"transfer_method": FileTransferMethod.TOOL_FILE.value,
}
file = file_factory.build_from_mapping(

View File

@ -116,7 +116,7 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
@ -162,8 +162,7 @@ class IterationNode(BaseNode[IterationNodeData]):
if self.node_data.is_parallel:
futures: list[Future] = []
q = Queue()
thread_pool = graph_engine.workflow_thread_pool_mapping[self.thread_pool_id]
thread_pool._max_workers = self.node_data.parallel_nums
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
self._run_single_iter_parallel,
@ -236,10 +235,7 @@ class IterationNode(BaseNode[IterationNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": jsonable_encoder(outputs)},
metadata={
NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
},
metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map},
)
)
except IterationNodeError as e:
@ -262,7 +258,6 @@ class IterationNode(BaseNode[IterationNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
)
)
finally:
@ -302,12 +297,13 @@ class IterationNode(BaseNode[IterationNodeData]):
# variable selector to variable mapping
try:
# Get node class
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping.get(node_type)
if not node_cls:
if node_type not in NODE_TYPE_CLASSES_MAPPING:
continue
node_version = sub_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config
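Earlier in this hunk, the parallel branch now builds a dedicated GraphEngineThreadPool capped by max_submit_count, instead of mutating the shared workflow pool's private _max_workers. A hedged sketch of what such a submit-capped executor might look like (illustrative only; the real class lives in graph_engine.py and is not shown in this diff):

from concurrent.futures import ThreadPoolExecutor

class BoundedThreadPool(ThreadPoolExecutor):  # illustrative, not the actual GraphEngineThreadPool
    def __init__(self, max_workers: int, max_submit_count: int):
        super().__init__(max_workers=max_workers)
        self.max_submit_count = max_submit_count
        self.submit_count = 0

    def submit(self, fn, /, *args, **kwargs):
        # Reject new work once the cap is hit rather than queueing without bound.
        self.submit_count += 1
        if self.submit_count > self.max_submit_count:
            raise ValueError("Max submit count reached.")
        return super().submit(fn, *args, **kwargs)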

View File

@ -815,7 +815,7 @@ class LLMNode(BaseNode[LLMNodeData]):
"completion_model": {
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
"prompt": {
"text": "Here is the chat histories between human and assistant, inside "
"text": "Here are the chat histories between human and assistant, inside "
"<histories></histories> XML tags.\n\n<histories>\n{{"
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
"edition_type": "basic",

View File

@ -1,3 +1,5 @@
from collections.abc import Mapping
from core.workflow.nodes.answer import AnswerNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code import CodeNode
@ -16,26 +18,87 @@ from core.workflow.nodes.start import StartNode
from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from core.workflow.nodes.variable_assigner import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
node_type_classes_mapping: dict[NodeType, type[BaseNode]] = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.ANSWER: AnswerNode,
NodeType.LLM: LLMNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.CODE: CodeNode,
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode,
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode,
NodeType.ITERATION_START: IterationStartNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
NodeType.DOCUMENT_EXTRACTOR: DocumentExtractorNode,
NodeType.LIST_OPERATOR: ListOperatorNode,
LATEST_VERSION = "latest"
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
NodeType.START: {
LATEST_VERSION: StartNode,
"1": StartNode,
},
NodeType.END: {
LATEST_VERSION: EndNode,
"1": EndNode,
},
NodeType.ANSWER: {
LATEST_VERSION: AnswerNode,
"1": AnswerNode,
},
NodeType.LLM: {
LATEST_VERSION: LLMNode,
"1": LLMNode,
},
NodeType.KNOWLEDGE_RETRIEVAL: {
LATEST_VERSION: KnowledgeRetrievalNode,
"1": KnowledgeRetrievalNode,
},
NodeType.IF_ELSE: {
LATEST_VERSION: IfElseNode,
"1": IfElseNode,
},
NodeType.CODE: {
LATEST_VERSION: CodeNode,
"1": CodeNode,
},
NodeType.TEMPLATE_TRANSFORM: {
LATEST_VERSION: TemplateTransformNode,
"1": TemplateTransformNode,
},
NodeType.QUESTION_CLASSIFIER: {
LATEST_VERSION: QuestionClassifierNode,
"1": QuestionClassifierNode,
},
NodeType.HTTP_REQUEST: {
LATEST_VERSION: HttpRequestNode,
"1": HttpRequestNode,
},
NodeType.TOOL: {
LATEST_VERSION: ToolNode,
"1": ToolNode,
},
NodeType.VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
},
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
}, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: {
LATEST_VERSION: IterationNode,
"1": IterationNode,
},
NodeType.ITERATION_START: {
LATEST_VERSION: IterationStartNode,
"1": IterationStartNode,
},
NodeType.PARAMETER_EXTRACTOR: {
LATEST_VERSION: ParameterExtractorNode,
"1": ParameterExtractorNode,
},
NodeType.VARIABLE_ASSIGNER: {
LATEST_VERSION: VariableAssignerNodeV2,
"1": VariableAssignerNodeV1,
"2": VariableAssignerNodeV2,
},
NodeType.DOCUMENT_EXTRACTOR: {
LATEST_VERSION: DocumentExtractorNode,
"1": DocumentExtractorNode,
},
NodeType.LIST_OPERATOR: {
LATEST_VERSION: ListOperatorNode,
"1": ListOperatorNode,
},
}
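With the mapping now two-level, callers resolve a class by node type and then by the node's declared version, as the lookups added in graph_engine.py and the app runner above show. A small illustrative helper (not part of the diff) with an explicit fallback to the "latest" key:

def resolve_node_cls(node_config: dict) -> type[BaseNode]:
    data = node_config.get("data", {})
    node_type = NodeType(data.get("type"))
    node_version = data.get("version", "1")  # nodes serialized before versioning are treated as v1
    versions = NODE_TYPE_CLASSES_MAPPING[node_type]
    return versions.get(node_version, versions[LATEST_VERSION])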

View File

@ -98,7 +98,7 @@ Step 3: Structure the extracted parameters to JSON object as specified in <struc
Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted.
### Memory
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
@ -125,7 +125,7 @@ CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and out
The structure of the JSON object you can found in the instructions.
### Memory
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>

View File

@ -1,4 +1,4 @@
from .entities import QuestionClassifierNodeData
from .question_classifier_node import QuestionClassifierNode
__all__ = ["QuestionClassifierNodeData", "QuestionClassifierNode"]
__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeData"]

View File

@ -8,7 +8,7 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
### Constraint
DO NOT include anything other than the JSON array in your response.
### Memory
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
@ -66,7 +66,7 @@ User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"
Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}}
</example>
### Memory
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>

View File

@ -1,8 +0,0 @@
from .node import VariableAssignerNode
from .node_data import VariableAssignerData, WriteMode
__all__ = [
"VariableAssignerNode",
"VariableAssignerData",
"WriteMode",
]

View File

@ -0,0 +1,4 @@
class VariableOperatorNodeError(Exception):
"""Base error type, don't use directly."""
pass

View File

@ -0,0 +1,19 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables import Variable
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from extensions.ext_database import db
from models import ConversationVariable
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()

View File

@ -1,2 +0,0 @@
class VariableAssignerNodeError(Exception):
pass

View File

@ -0,0 +1,3 @@
from .node import VariableAssignerNode
__all__ = ["VariableAssignerNode"]

View File

@ -1,40 +1,36 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables import SegmentType, Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode, BaseNodeData
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory
from models import ConversationVariable
from models.workflow import WorkflowNodeExecutionStatus
from .exc import VariableAssignerNodeError
from .node_data import VariableAssignerData, WriteMode
class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
_node_type = NodeType.VARIABLE_ASSIGNER
def _run(self) -> NodeRunResult:
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError("assigned variable not found")
raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError("input value not found")
raise VariableOperatorNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError("input value not found")
raise VariableOperatorNodeError("input value not found")
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={"value": updated_value})
@ -43,7 +39,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
raise VariableAssignerNodeError(f"unsupported write mode: {self.node_data.write_mode}")
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Overwrite the variable.
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
@ -52,8 +48,8 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
# Update conversation variable.
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise VariableAssignerNodeError("conversation_id not found")
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
raise VariableOperatorNodeError("conversation_id not found")
common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -63,18 +59,6 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
)
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableAssignerNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
def get_zero_value(t: SegmentType):
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
@ -86,4 +70,4 @@ def get_zero_value(t: SegmentType):
case SegmentType.NUMBER:
return variable_factory.build_segment(0)
case _:
raise VariableAssignerNodeError(f"unsupported variable type: {t}")
raise VariableOperatorNodeError(f"unsupported variable type: {t}")

View File

@ -1,6 +1,5 @@
from collections.abc import Sequence
from enum import StrEnum
from typing import Optional
from core.workflow.nodes.base import BaseNodeData
@ -12,8 +11,6 @@ class WriteMode(StrEnum):
class VariableAssignerData(BaseNodeData):
title: str = "Variable Assigner"
desc: Optional[str] = "Assign a value to a variable"
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]

View File

@ -0,0 +1,3 @@
from .node import VariableAssignerNode
__all__ = ["VariableAssignerNode"]

View File

@ -0,0 +1,11 @@
from core.variables import SegmentType
EMPTY_VALUE_MAPPING = {
SegmentType.STRING: "",
SegmentType.NUMBER: 0,
SegmentType.OBJECT: {},
SegmentType.ARRAY_ANY: [],
SegmentType.ARRAY_STRING: [],
SegmentType.ARRAY_NUMBER: [],
SegmentType.ARRAY_OBJECT: [],
}
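EMPTY_VALUE_MAPPING provides the "zero" value that the v2 node's CLEAR operation writes back (see _handle_item further down in this diff). A two-line sketch of the contract:

assert EMPTY_VALUE_MAPPING[SegmentType.STRING] == ""
assert EMPTY_VALUE_MAPPING[SegmentType.ARRAY_NUMBER] == []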

View File

@ -0,0 +1,20 @@
from collections.abc import Sequence
from typing import Any
from pydantic import BaseModel
from core.workflow.nodes.base import BaseNodeData
from .enums import InputType, Operation
class VariableOperationItem(BaseModel):
variable_selector: Sequence[str]
input_type: InputType
operation: Operation
value: Any | None = None
class VariableAssignerNodeData(BaseNodeData):
version: str = "2"
items: Sequence[VariableOperationItem]
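Because VariableAssignerNodeData is a pydantic model, a node's config block can be validated directly. A hedged sketch with an illustrative selector and value; the module paths match the test imports later in this diff:

from core.workflow.nodes.variable_assigner.v2.entities import VariableAssignerNodeData
from core.workflow.nodes.variable_assigner.v2.enums import Operation

data = VariableAssignerNodeData.model_validate(
    {
        "title": "Variable Assigner",
        "items": [
            {
                "variable_selector": ["conversation", "counter"],  # illustrative
                "input_type": "constant",
                "operation": "+=",
                "value": 1,
            }
        ],
    }
)
assert data.version == "2"
assert data.items[0].operation == Operation.ADD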

View File

@ -0,0 +1,18 @@
from enum import StrEnum
class Operation(StrEnum):
OVER_WRITE = "over-write"
CLEAR = "clear"
APPEND = "append"
EXTEND = "extend"
SET = "set"
ADD = "+="
SUBTRACT = "-="
MULTIPLY = "*="
DIVIDE = "/="
class InputType(StrEnum):
VARIABLE = "variable"
CONSTANT = "constant"

View File

@ -0,0 +1,31 @@
from collections.abc import Sequence
from typing import Any
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from .enums import InputType, Operation
class OperationNotSupportedError(VariableOperatorNodeError):
def __init__(self, *, operation: Operation, variable_type: str):
super().__init__(f"Operation {operation} is not supported for type {variable_type}")
class InputTypeNotSupportedError(VariableOperatorNodeError):
def __init__(self, *, input_type: InputType, operation: Operation):
super().__init__(f"Input type {input_type} is not supported for operation {operation}")
class VariableNotFoundError(VariableOperatorNodeError):
def __init__(self, *, variable_selector: Sequence[str]):
super().__init__(f"Variable {variable_selector} not found")
class InvalidInputValueError(VariableOperatorNodeError):
def __init__(self, *, value: Any):
super().__init__(f"Invalid input value {value}")
class ConversationIDNotFoundError(VariableOperatorNodeError):
def __init__(self):
super().__init__("conversation_id not found")

View File

@ -0,0 +1,91 @@
from typing import Any
from core.variables import SegmentType
from .enums import Operation
def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
match operation:
case Operation.OVER_WRITE | Operation.CLEAR:
return True
case Operation.SET:
return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER}
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
# Only number variables can be added, subtracted, multiplied, or divided
return variable_type == SegmentType.NUMBER
case Operation.APPEND | Operation.EXTEND:
# Only array variables can be appended or extended
return variable_type in {
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_FILE,
}
case _:
return False
def is_variable_input_supported(*, operation: Operation):
if operation in {Operation.SET, Operation.ADD, Operation.SUBTRACT, Operation.MULTIPLY, Operation.DIVIDE}:
return False
return True
def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation):
match variable_type:
case SegmentType.STRING | SegmentType.OBJECT:
return operation in {Operation.OVER_WRITE, Operation.SET}
case SegmentType.NUMBER:
return operation in {
Operation.OVER_WRITE,
Operation.SET,
Operation.ADD,
Operation.SUBTRACT,
Operation.MULTIPLY,
Operation.DIVIDE,
}
case _:
return False
def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any):
if operation == Operation.CLEAR:
return True
match variable_type:
case SegmentType.STRING:
return isinstance(value, str)
case SegmentType.NUMBER:
if not isinstance(value, int | float):
return False
if operation == Operation.DIVIDE and value == 0:
return False
return True
case SegmentType.OBJECT:
return isinstance(value, dict)
# Array & Append
case SegmentType.ARRAY_ANY if operation == Operation.APPEND:
return isinstance(value, str | float | int | dict)
case SegmentType.ARRAY_STRING if operation == Operation.APPEND:
return isinstance(value, str)
case SegmentType.ARRAY_NUMBER if operation == Operation.APPEND:
return isinstance(value, int | float)
case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND:
return isinstance(value, dict)
# Array & Extend / Overwrite
case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, str | float | int | dict) for item in value)
case SegmentType.ARRAY_STRING if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, str) for item in value)
case SegmentType.ARRAY_NUMBER if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, int | float) for item in value)
case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, dict) for item in value)
case _:
return False
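Taken together, these helpers gate what the node will accept before it executes anything. A usage sketch of the rules encoded above:

from core.variables import SegmentType
from core.workflow.nodes.variable_assigner.v2.enums import Operation
from core.workflow.nodes.variable_assigner.v2.helpers import (
    is_constant_input_supported,
    is_input_value_valid,
    is_operation_supported,
    is_variable_input_supported,
)

assert is_operation_supported(variable_type=SegmentType.NUMBER, operation=Operation.ADD)
assert not is_operation_supported(variable_type=SegmentType.STRING, operation=Operation.DIVIDE)
# Arithmetic and SET take constants only, never another variable:
assert not is_variable_input_supported(operation=Operation.ADD)
assert is_constant_input_supported(variable_type=SegmentType.NUMBER, operation=Operation.ADD)
# Division by zero is rejected during validation, before _handle_item runs:
assert not is_input_value_valid(variable_type=SegmentType.NUMBER, operation=Operation.DIVIDE, value=0)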

View File

@ -0,0 +1,159 @@
import json
from typing import Any
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from models.workflow import WorkflowNodeExecutionStatus
from . import helpers
from .constants import EMPTY_VALUE_MAPPING
from .entities import VariableAssignerNodeData
from .enums import InputType, Operation
from .exc import (
ConversationIDNotFoundError,
InputTypeNotSupportedError,
InvalidInputValueError,
OperationNotSupportedError,
VariableNotFoundError,
)
class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_ASSIGNER
def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump()
process_data = {}
# NOTE: This node has no outputs
updated_variables: list[Variable] = []
try:
for item in self.node_data.items:
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
# ==================== Validation Part
# Check if variable exists
if not isinstance(variable, Variable):
raise VariableNotFoundError(variable_selector=item.variable_selector)
# Check if operation is supported
if not helpers.is_operation_supported(variable_type=variable.value_type, operation=item.operation):
raise OperationNotSupportedError(operation=item.operation, variable_type=variable.value_type)
# Check if variable input is supported
if item.input_type == InputType.VARIABLE and not helpers.is_variable_input_supported(
operation=item.operation
):
raise InputTypeNotSupportedError(input_type=InputType.VARIABLE, operation=item.operation)
# Check if constant input is supported
if item.input_type == InputType.CONSTANT and not helpers.is_constant_input_supported(
variable_type=variable.value_type, operation=item.operation
):
raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation)
# Get value from variable pool
if (
item.input_type == InputType.VARIABLE
and item.operation != Operation.CLEAR
and item.value is not None
):
value = self.graph_runtime_state.variable_pool.get(item.value)
if value is None:
raise VariableNotFoundError(variable_selector=item.value)
# Skip if value is NoneSegment
if value.value_type == SegmentType.NONE:
continue
item.value = value.value
# When setting a string / bytes / bytearray value on an object variable, try to parse it as JSON.
if (
item.operation == Operation.SET
and variable.value_type == SegmentType.OBJECT
and isinstance(item.value, str | bytes | bytearray)
):
try:
item.value = json.loads(item.value)
except json.JSONDecodeError:
raise InvalidInputValueError(value=item.value)
# Check if input value is valid
if not helpers.is_input_value_valid(
variable_type=variable.value_type, operation=item.operation, value=item.value
):
raise InvalidInputValueError(value=item.value)
# ==================== Execution Part
updated_value = self._handle_item(
variable=variable,
operation=item.operation,
value=item.value,
)
variable = variable.model_copy(update={"value": updated_value})
updated_variables.append(variable)
except VariableOperatorNodeError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
process_data=process_data,
error=str(e),
)
# Update variables
for variable in updated_variables:
self.graph_runtime_state.variable_pool.add(variable.selector, variable)
process_data[variable.name] = variable.value
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise ConversationIDNotFoundError
else:
conversation_id = conversation_id.value
common_helpers.update_conversation_variable(
conversation_id=conversation_id,
variable=variable,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
)
def _handle_item(
self,
*,
variable: Variable,
operation: Operation,
value: Any,
):
match operation:
case Operation.OVER_WRITE:
return value
case Operation.CLEAR:
return EMPTY_VALUE_MAPPING[variable.value_type]
case Operation.APPEND:
return variable.value + [value]
case Operation.EXTEND:
return variable.value + value
case Operation.SET:
return value
case Operation.ADD:
return variable.value + value
case Operation.SUBTRACT:
return variable.value - value
case Operation.MULTIPLY:
return variable.value * value
case Operation.DIVIDE:
return variable.value / value
case _:
raise OperationNotSupportedError(operation=operation, variable_type=variable.value_type)
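A point worth noting in _handle_item: APPEND wraps the incoming value in a single-element list before concatenating, while EXTEND concatenates the incoming list as-is. Sketched with plain Python values standing in for Variable segments:

current = ["a", "b"]
assert current + [["c", "d"]] == ["a", "b", ["c", "d"]]  # APPEND with a list value
assert current + ["c", "d"] == ["a", "b", "c", "d"]      # EXTEND
# The arithmetic operations map onto plain operators:
assert 10 + 5 == 15   # ADD ("+=")
assert 10 / 4 == 2.5  # DIVIDE ("/=") uses true division, so the result may become a float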

View File

@ -2,7 +2,7 @@ import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
@ -19,7 +19,7 @@ from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from factories import file_factory
from models.enums import UserFrom
from models.workflow import (
@ -145,11 +145,8 @@ class WorkflowEntry:
# Get node class
node_type = NodeType(node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping.get(node_type)
node_cls = cast(type[BaseNode], node_cls)
if not node_cls:
raise ValueError(f"Node class not found for node type {node_type}")
node_version = node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool
variable_pool = VariablePool(environment_variables=workflow.environment_variables)

View File

@ -1,3 +1,5 @@
from typing import Any, Union
import redis
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
@ -46,11 +48,11 @@ redis_client = RedisClientWrapper()
def init_app(app: DifyApp):
global redis_client
connection_class = Connection
connection_class: type[Union[Connection, SSLConnection]] = Connection
if dify_config.REDIS_USE_SSL:
connection_class = SSLConnection
redis_params = {
redis_params: dict[str, Any] = {
"username": dify_config.REDIS_USERNAME,
"password": dify_config.REDIS_PASSWORD,
"db": dify_config.REDIS_DB,
@ -60,6 +62,7 @@ def init_app(app: DifyApp):
}
if dify_config.REDIS_USE_SENTINEL:
assert dify_config.REDIS_SENTINELS is not None, "REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True"
sentinel_hosts = [
(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")
]
@ -74,11 +77,13 @@ def init_app(app: DifyApp):
master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
redis_client.initialize(master)
elif dify_config.REDIS_USE_CLUSTERS:
assert dify_config.REDIS_CLUSTERS is not None, "REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True"
nodes = [
ClusterNode(host=node.split(":")[0], port=int(node.split.split(":")[1]))
ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
for node in dify_config.REDIS_CLUSTERS.split(",")
]
redis_client.initialize(RedisCluster(startup_nodes=nodes, password=dify_config.REDIS_CLUSTERS_PASSWORD))
# FIXME: mypy error here, try to figure out how to fix it
redis_client.initialize(RedisCluster(startup_nodes=nodes, password=dify_config.REDIS_CLUSTERS_PASSWORD)) # type: ignore
else:
redis_params.update(
{
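Both REDIS_SENTINELS and REDIS_CLUSTERS are comma-separated host:port strings, and each node is now split exactly once per field; the old node.split.split(":") would have raised AttributeError the moment cluster mode was enabled. A parsing sketch with illustrative addresses:

raw = "10.0.0.1:26379,10.0.0.2:26379"
hosts = [(node.split(":")[0], int(node.split(":")[1])) for node in raw.split(",")]
assert hosts == [("10.0.0.1", 26379), ("10.0.0.2", 26379)]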

View File

@ -36,6 +36,7 @@ from core.variables.variables import (
StringVariable,
Variable,
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
class InvalidSelectorError(ValueError):
@ -62,11 +63,25 @@ SEGMENT_TO_VARIABLE_MAP = {
}
def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if (value_type := mapping.get("value_type")) is None:
raise VariableError("missing value type")
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
"""
This factory function is used to create the environment variable or the conversation variable;
it does not support the File type.
"""
if (value_type := mapping.get("value_type")) is None:
raise VariableError("missing value type")
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
match value_type:
@ -92,6 +107,8 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
raise VariableError(f"not supported value type {value_type}")
if result.size > dify_config.MAX_VARIABLE_SIZE:
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
if not result.selector:
result = result.model_copy(update={"selector": selector})
return result
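The two new wrappers differ from _build_variable_from_mapping only in the default selector they attach when the mapping carries none. A usage sketch; the names and values here are illustrative:

from factories import variable_factory

env_var = variable_factory.build_environment_variable_from_mapping(
    {"value_type": "string", "name": "api_key", "value": "xxx"}
)
assert list(env_var.selector) == ["env", "api_key"]

conv_var = variable_factory.build_conversation_variable_from_mapping(
    {"value_type": "number", "name": "counter", "value": 0}
)
assert list(conv_var.selector) == ["conversation", "counter"]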

View File

@ -10,10 +10,10 @@ from collections.abc import Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import Any, Optional, Union
from zoneinfo import available_timezones
from flask import Response, stream_with_context
from flask_restful import fields
from zoneinfo import available_timezones
from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator

View File

@ -0,0 +1,96 @@
"""add_fat_test
Revision ID: 49f175ff56cb
Revises: 43fa78bc3b7d
Create Date: 2024-11-05 03:26:22.578321
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '49f175ff56cb'
down_revision = '01d6889832f7'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('component_failure',
sa.Column('FailureID', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('Date', sa.Date(), nullable=False),
sa.Column('Component', sa.String(length=255), nullable=False),
sa.Column('FailureMode', sa.String(length=255), nullable=False),
sa.Column('Cause', sa.String(length=255), nullable=False),
sa.Column('RepairAction', sa.Text(), nullable=True),
sa.Column('Technician', sa.String(length=255), nullable=False),
sa.PrimaryKeyConstraint('FailureID', name=op.f('component_failure_pkey')),
sa.UniqueConstraint('Date', 'Component', 'FailureMode', 'Cause', 'Technician', name='unique_failure_entry')
)
op.create_table('component_failure_stats',
sa.Column('StatID', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('Component', sa.String(length=255), nullable=False),
sa.Column('FailureMode', sa.String(length=255), nullable=False),
sa.Column('Cause', sa.String(length=255), nullable=False),
sa.Column('PossibleAction', sa.Text(), nullable=True),
sa.Column('Probability', sa.Float(), nullable=False),
sa.Column('MTBF', sa.Float(), nullable=False),
sa.PrimaryKeyConstraint('StatID', name=op.f('component_failure_stats_pkey'))
)
op.create_table('incident_data',
sa.Column('IncidentID', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('IncidentDescription', sa.Text(), nullable=False),
sa.Column('IncidentDate', sa.Date(), nullable=False),
sa.Column('Consequences', sa.Text(), nullable=True),
sa.Column('ResponseActions', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('IncidentID', name=op.f('incident_data_pkey'))
)
op.create_table('maintenance',
sa.Column('MaintenanceID', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('MaintenanceType', sa.String(length=255), nullable=False),
sa.Column('MaintenanceDate', sa.Date(), nullable=False),
sa.Column('ServiceDescription', sa.Text(), nullable=True),
sa.Column('PartsReplaced', sa.Text(), nullable=True),
sa.Column('Technician', sa.String(length=255), nullable=False),
sa.PrimaryKeyConstraint('MaintenanceID', name=op.f('maintenance_pkey'))
)
op.create_table('operational_data',
sa.Column('OperationID', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('CraneUsage', sa.Integer(), nullable=False),
sa.Column('LoadWeight', sa.Float(), nullable=False),
sa.Column('LoadFrequency', sa.Integer(), nullable=False),
sa.Column('EnvironmentalConditions', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('OperationID', name=op.f('operational_data_pkey'))
)
op.create_table('reliability_data',
sa.Column('ComponentID', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('ComponentName', sa.String(length=255), nullable=False),
sa.Column('MTBF', sa.Float(), nullable=False),
sa.Column('FailureRate', sa.Float(), nullable=False),
sa.PrimaryKeyConstraint('ComponentID', name=op.f('reliability_data_pkey'))
)
op.create_table('safety_data',
sa.Column('SafetyID', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('SafetyInspectionDate', sa.Date(), nullable=False),
sa.Column('SafetyFindings', sa.Text(), nullable=True),
sa.Column('SafetyIncidentDescription', sa.Text(), nullable=True),
sa.Column('ComplianceStatus', sa.String(length=50), nullable=False),
sa.PrimaryKeyConstraint('SafetyID', name=op.f('safety_data_pkey'))
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('safety_data')
op.drop_table('reliability_data')
op.drop_table('operational_data')
op.drop_table('maintenance')
op.drop_table('incident_data')
op.drop_table('component_failure_stats')
op.drop_table('component_failure')
# ### end Alembic commands ###

View File

@ -24,30 +24,30 @@ from .workflow import (
)
__all__ = [
"Account",
"AccountIntegrate",
"ApiToken",
"App",
"AppMode",
"Conversation",
"ConversationVariable",
"Document",
"DataSourceOauthBinding",
"Dataset",
"DatasetProcessRule",
"Document",
"DocumentSegment",
"DataSourceOauthBinding",
"AppMode",
"Workflow",
"App",
"Message",
"EndUser",
"InstalledApp",
"InvitationCode",
"Message",
"MessageAnnotation",
"MessageFile",
"RecommendedApp",
"Site",
"Tenant",
"ToolFile",
"UploadFile",
"Account",
"Workflow",
"WorkflowAppLog",
"WorkflowRun",
"Site",
"InstalledApp",
"RecommendedApp",
"ApiToken",
"AccountIntegrate",
"InvitationCode",
"Tenant",
"Conversation",
"MessageAnnotation",
"ToolFile",
]

api/models/fta.py (new file, 78 lines)
View File

@ -0,0 +1,78 @@
from extensions.ext_database import db
class ComponentFailure(db.Model):
__tablename__ = "component_failure"
__table_args__ = (
db.UniqueConstraint("Date", "Component", "FailureMode", "Cause", "Technician", name="unique_failure_entry"),
)
FailureID = db.Column(db.Integer, primary_key=True, autoincrement=True)
Date = db.Column(db.Date, nullable=False)
Component = db.Column(db.String(255), nullable=False)
FailureMode = db.Column(db.String(255), nullable=False)
Cause = db.Column(db.String(255), nullable=False)
RepairAction = db.Column(db.Text, nullable=True)
Technician = db.Column(db.String(255), nullable=False)
class Maintenance(db.Model):
__tablename__ = "maintenance"
MaintenanceID = db.Column(db.Integer, primary_key=True, autoincrement=True)
MaintenanceType = db.Column(db.String(255), nullable=False)
MaintenanceDate = db.Column(db.Date, nullable=False)
ServiceDescription = db.Column(db.Text, nullable=True)
PartsReplaced = db.Column(db.Text, nullable=True)
Technician = db.Column(db.String(255), nullable=False)
class OperationalData(db.Model):
__tablename__ = "operational_data"
OperationID = db.Column(db.Integer, primary_key=True, autoincrement=True)
CraneUsage = db.Column(db.Integer, nullable=False)
LoadWeight = db.Column(db.Float, nullable=False)
LoadFrequency = db.Column(db.Integer, nullable=False)
EnvironmentalConditions = db.Column(db.Text, nullable=True)
class IncidentData(db.Model):
__tablename__ = "incident_data"
IncidentID = db.Column(db.Integer, primary_key=True, autoincrement=True)
IncidentDescription = db.Column(db.Text, nullable=False)
IncidentDate = db.Column(db.Date, nullable=False)
Consequences = db.Column(db.Text, nullable=True)
ResponseActions = db.Column(db.Text, nullable=True)
class ReliabilityData(db.Model):
__tablename__ = "reliability_data"
ComponentID = db.Column(db.Integer, primary_key=True, autoincrement=True)
ComponentName = db.Column(db.String(255), nullable=False)
MTBF = db.Column(db.Float, nullable=False)
FailureRate = db.Column(db.Float, nullable=False)
class SafetyData(db.Model):
__tablename__ = "safety_data"
SafetyID = db.Column(db.Integer, primary_key=True, autoincrement=True)
SafetyInspectionDate = db.Column(db.Date, nullable=False)
SafetyFindings = db.Column(db.Text, nullable=True)
SafetyIncidentDescription = db.Column(db.Text, nullable=True)
ComplianceStatus = db.Column(db.String(50), nullable=False)
class ComponentFailureStats(db.Model):
__tablename__ = "component_failure_stats"
StatID = db.Column(db.Integer, primary_key=True, autoincrement=True)
Component = db.Column(db.String(255), nullable=False)
FailureMode = db.Column(db.String(255), nullable=False)
Cause = db.Column(db.String(255), nullable=False)
PossibleAction = db.Column(db.Text, nullable=True)
Probability = db.Column(db.Float, nullable=False)
MTBF = db.Column(db.Float, nullable=False)

View File

@ -238,7 +238,9 @@ class Workflow(db.Model):
tenant_id = contexts.tenant_id.get()
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables)
results = [variable_factory.build_variable_from_mapping(v) for v in environment_variables_dict.values()]
results = [
variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
]
# decrypt secret variables value
decrypt_func = (
@ -303,7 +305,7 @@ class Workflow(db.Model):
self._conversation_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
results = [variable_factory.build_variable_from_mapping(v) for v in variables_dict.values()]
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
return results
@conversation_variables.setter
@ -793,4 +795,4 @@ class ConversationVariable(db.Model):
def to_variable(self) -> Variable:
mapping = json.loads(self.data)
return variable_factory.build_variable_from_mapping(mapping)
return variable_factory.build_conversation_variable_from_mapping(mapping)

View File

@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
IMPORT_INFO_REDIS_EXPIRY = 180 # 3 minutes
CURRENT_DSL_VERSION = "0.1.3"
CURRENT_DSL_VERSION = "0.1.4"
class ImportMode(StrEnum):
@ -387,11 +387,11 @@ class AppDslService:
environment_variables_list = workflow_data.get("environment_variables", [])
environment_variables = [
variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = workflow_data.get("conversation_variables", [])
conversation_variables = [
variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
workflow_service = WorkflowService()

View File

@ -14,16 +14,16 @@ from . import (
)
__all__ = [
"base",
"conversation",
"message",
"index",
"app_model_config",
"account",
"document",
"dataset",
"app",
"completion",
"app_model_config",
"audio",
"base",
"completion",
"conversation",
"dataset",
"document",
"file",
"index",
"message",
]

View File

@ -12,7 +12,7 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.nodes import NodeType
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
@ -176,7 +176,8 @@ class WorkflowService:
"""
# return default block config
default_block_configs = []
for node_type, node_class in node_type_classes_mapping.items():
for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
node_class = node_class_mapping[LATEST_VERSION]
default_config = node_class.get_default_config()
if default_config:
default_block_configs.append(default_config)
@ -190,13 +191,13 @@ class WorkflowService:
:param filters: filter by node config parameters.
:return:
"""
node_type_enum: NodeType = NodeType(node_type)
node_type_enum = NodeType(node_type)
# return default block config
node_class = node_type_classes_mapping.get(node_type_enum)
if not node_class:
if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
return None
node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
default_config = node_class.get_default_config(filters=filters)
if not default_config:
return None

View File

@ -1,4 +1,4 @@
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbConfig, AnalyticdbVector
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis

View File

@ -19,36 +19,36 @@ from factories import variable_factory
def test_string_variable():
test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"}
result = variable_factory.build_variable_from_mapping(test_data)
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, StringVariable)
def test_integer_variable():
test_data = {"value_type": "number", "name": "test_int", "value": 42}
result = variable_factory.build_variable_from_mapping(test_data)
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, IntegerVariable)
def test_float_variable():
test_data = {"value_type": "number", "name": "test_float", "value": 3.14}
result = variable_factory.build_variable_from_mapping(test_data)
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, FloatVariable)
def test_secret_variable():
test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"}
result = variable_factory.build_variable_from_mapping(test_data)
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, SecretVariable)
def test_invalid_value_type():
test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"}
with pytest.raises(VariableError):
variable_factory.build_variable_from_mapping(test_data)
variable_factory.build_conversation_variable_from_mapping(test_data)
def test_build_a_blank_string():
result = variable_factory.build_variable_from_mapping(
result = variable_factory.build_conversation_variable_from_mapping(
{
"value_type": "string",
"name": "blank",
@ -80,7 +80,7 @@ def test_object_variable():
"key2": 2,
},
}
variable = variable_factory.build_variable_from_mapping(mapping)
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ObjectSegment)
assert isinstance(variable.value["key1"], str)
assert isinstance(variable.value["key2"], int)
@ -97,7 +97,7 @@ def test_array_string_variable():
"text",
],
}
variable = variable_factory.build_variable_from_mapping(mapping)
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayStringVariable)
assert isinstance(variable.value[0], str)
assert isinstance(variable.value[1], str)
@ -114,7 +114,7 @@ def test_array_number_variable():
2.0,
],
}
variable = variable_factory.build_variable_from_mapping(mapping)
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayNumberVariable)
assert isinstance(variable.value[0], int)
assert isinstance(variable.value[1], float)
@ -137,7 +137,7 @@ def test_array_object_variable():
},
],
}
variable = variable_factory.build_variable_from_mapping(mapping)
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayObjectVariable)
assert isinstance(variable.value[0], dict)
assert isinstance(variable.value[1], dict)
@ -149,7 +149,7 @@ def test_array_object_variable():
def test_variable_cannot_large_than_200_kb():
with pytest.raises(VariableError):
variable_factory.build_variable_from_mapping(
variable_factory.build_conversation_variable_from_mapping(
{
"id": str(uuid4()),
"value_type": "string",

View File

@ -10,7 +10,8 @@ from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
from models.enums import UserFrom
from models.workflow import WorkflowType
@ -84,6 +85,7 @@ def test_overwrite_string_variable():
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.OVER_WRITE.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
@ -91,7 +93,7 @@ def test_overwrite_string_variable():
},
)
with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run:
with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
list(node.run())
mock_run.assert_called_once()
@ -166,6 +168,7 @@ def test_append_variable_to_array():
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.APPEND.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
@ -173,7 +176,7 @@ def test_append_variable_to_array():
},
)
with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run:
with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
list(node.run())
mock_run.assert_called_once()
@ -237,6 +240,7 @@ def test_clear_array():
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.CLEAR.value,
"input_variable_selector": [],
@ -244,7 +248,7 @@ def test_clear_array():
},
)
with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run:
with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
list(node.run())
mock_run.assert_called_once()

View File

@ -0,0 +1,24 @@
import pytest
from core.variables import SegmentType
from core.workflow.nodes.variable_assigner.v2.enums import Operation
from core.workflow.nodes.variable_assigner.v2.helpers import is_input_value_valid
def test_is_input_value_valid_overwrite_array_string():
# Valid cases
assert is_input_value_valid(
variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["hello", "world"]
)
assert is_input_value_valid(variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[])
# Invalid cases
assert not is_input_value_valid(
variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value="not an array"
)
assert not is_input_value_valid(
variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[1, 2, 3]
)
assert not is_input_value_valid(
variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["valid", 123, "invalid"]
)

View File

@ -6,7 +6,7 @@ from models import ConversationVariable
def test_from_variable_and_to_variable():
variable = variable_factory.build_variable_from_mapping(
variable = variable_factory.build_conversation_variable_from_mapping(
{
"id": str(uuid4()),
"name": "name",

View File

@ -24,10 +24,18 @@ def test_environment_variables():
)
# Create some EnvironmentVariable instances
variable1 = StringVariable.model_validate({"name": "var1", "value": "value1", "id": str(uuid4())})
variable2 = IntegerVariable.model_validate({"name": "var2", "value": 123, "id": str(uuid4())})
variable3 = SecretVariable.model_validate({"name": "var3", "value": "secret", "id": str(uuid4())})
variable4 = FloatVariable.model_validate({"name": "var4", "value": 3.14, "id": str(uuid4())})
variable1 = StringVariable.model_validate(
{"name": "var1", "value": "value1", "id": str(uuid4()), "selector": ["env", "var1"]}
)
variable2 = IntegerVariable.model_validate(
{"name": "var2", "value": 123, "id": str(uuid4()), "selector": ["env", "var2"]}
)
variable3 = SecretVariable.model_validate(
{"name": "var3", "value": "secret", "id": str(uuid4()), "selector": ["env", "var3"]}
)
variable4 = FloatVariable.model_validate(
{"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]}
)
with (
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
@ -58,10 +66,18 @@ def test_update_environment_variables():
)
# Create some EnvironmentVariable instances
variable1 = StringVariable.model_validate({"name": "var1", "value": "value1", "id": str(uuid4())})
variable2 = IntegerVariable.model_validate({"name": "var2", "value": 123, "id": str(uuid4())})
variable3 = SecretVariable.model_validate({"name": "var3", "value": "secret", "id": str(uuid4())})
variable4 = FloatVariable.model_validate({"name": "var4", "value": 3.14, "id": str(uuid4())})
variable1 = StringVariable.model_validate(
{"name": "var1", "value": "value1", "id": str(uuid4()), "selector": ["env", "var1"]}
)
variable2 = IntegerVariable.model_validate(
{"name": "var2", "value": 123, "id": str(uuid4()), "selector": ["env", "var2"]}
)
variable3 = SecretVariable.model_validate(
{"name": "var3", "value": "secret", "id": str(uuid4()), "selector": ["env", "var3"]}
)
variable4 = FloatVariable.model_validate(
{"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]}
)
with (
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),

View File

@ -2,7 +2,7 @@ version: '3'
services:
# API service
api:
image: langgenius/dify-api:0.12.1
image: langgenius/dify-api:0.13.0
restart: always
environment:
# Startup mode, 'api' starts the API server.
@ -227,7 +227,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.12.1
image: langgenius/dify-api:0.13.0
restart: always
environment:
CONSOLE_WEB_URL: ''
@ -397,7 +397,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.12.1
image: langgenius/dify-web:0.13.0
restart: always
environment:
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is

View File

@ -292,7 +292,7 @@ x-shared-env: &shared-api-worker-env
services:
# API service
api:
image: langgenius/dify-api:0.12.1
image: langgenius/dify-api:0.13.0
restart: always
environment:
# Use the shared environment variables.
@ -312,7 +312,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.12.1
image: langgenius/dify-api:0.13.0
restart: always
environment:
# Use the shared environment variables.
@ -331,7 +331,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.12.1
image: langgenius/dify-web:0.13.0
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

Some files were not shown because too many files have changed in this diff.