Compare commits

..

83 Commits

Author SHA1 Message Date
bdfb31f332 chore(api): ignore google.cloud.storage in MyPy 2025-08-26 12:45:04 +08:00
f762c969b3 chore(api): resolve MyPy issues 2025-08-26 12:44:25 +08:00
6b87b0fa42 Merge remote-tracking branch 'upstream/feat/support-bool-variable-fe' into feat/boolean-type 2025-08-26 10:07:32 +08:00
4386c9f4f1 main 2025-08-25 10:35:52 +08:00
62955f4d11 chore: hide json entrance 2025-08-25 10:26:11 +08:00
e38ba8403d feat: workflow support input json 2025-08-22 14:54:34 +08:00
a1f278bdc8 fix: one step run form 2025-08-22 14:43:23 +08:00
65c0f8d657 feat: webapp support json field 2025-08-22 14:36:12 +08:00
c3e4d5a059 feat: support choose and save deep obj 2025-08-21 15:31:34 +08:00
30775109fd feat: can edit and choose obj 2025-08-21 14:47:29 +08:00
76659843d7 fix: item value type 2025-08-20 16:32:29 +08:00
42258bf451 chore: node tracing node value 2025-08-20 16:24:53 +08:00
399d510e1b fix(api): handle array[boolean] in LoopNode
The frontend save `array[boolean]` as a JSON array with string
element (e.g. `["true", "false", "true"]`), which is not handled
by LoopNode correctly.

This commit fix the issue by converting each element into Python
bool type before building segment.
2025-08-20 15:33:19 +08:00
8462c79179 fix: can not use converstation vars in var aggregator node 2025-08-20 14:59:58 +08:00
f5f11f5a9e feat: basic input field edit 2025-08-20 14:23:33 +08:00
96e8321462 chore: add var can not has exist key 2025-08-20 14:06:46 +08:00
af6e5e8663 chore: edit filed exists 2025-08-20 14:00:38 +08:00
d4f976270d fix: chore: if boolean array bug 2025-08-19 16:16:10 +08:00
dc3cbd9a74 chore: handle duliplicate keys 2025-08-19 15:57:18 +08:00
056564c100 merge main 2025-08-19 15:48:01 +08:00
d5b99610fc chore: iteratino input support bool list 2025-08-07 10:18:45 +08:00
04d2b0775d fix: chatbot bool 2025-08-04 11:14:33 +08:00
6614e97b53 fix: basic chatbot bool 2025-08-04 10:59:20 +08:00
ca37e304cc fix: basic bool error 2025-08-04 10:53:35 +08:00
437729db97 feat: support var arrs 2025-08-01 18:36:18 +08:00
fbd0effa04 chore: handle bool value in vars 2025-08-01 18:26:58 +08:00
14c62850e0 fix: list filter operation filter boolean problem 2025-08-01 18:06:27 +08:00
c4824a1d4c fix: assigner set false can not public 2025-08-01 17:55:20 +08:00
d642e9a98e chore: fix can not publish 2025-08-01 17:39:25 +08:00
72c6195c10 fix: list operation boolean default 2025-08-01 17:32:09 +08:00
2bd096b454 fix: con vars set value 2025-08-01 17:24:53 +08:00
fa98505e09 fix(api): enhance compatibility with existing workflow
Fix issues flagged by unit tests.
2025-08-01 15:27:33 +08:00
65ae25e09b test(api): add test cases for ParameterConfig validation logic 2025-08-01 15:25:28 +08:00
4943d8bc93 test(api): Introduce test cases for ParameterExtractorNode._validate_result 2025-08-01 15:25:28 +08:00
168d82f9b0 feat(api): support boolean types in parameter extractor node 2025-08-01 15:25:28 +08:00
0e3ccb4dcc feat(api): update boolean type handling in ListOperatorNode 2025-08-01 13:11:24 +08:00
f7f6210a4c fix(api): Fix incorrect handling of array types in SegmentType.is_valid
Add comphensive tests for the behavior of the `is_valid` method.
2025-08-01 13:09:46 +08:00
ed8e06d1bf fix(api): Allow bool value in the FilterCondition of list operator 2025-07-31 12:24:18 +08:00
6287e3cd9e fix: boolean can picker var 2025-07-31 10:28:17 +08:00
94c33c6eed chore: boolean string to boolean 2025-07-30 17:56:59 +08:00
383695b820 chore(api): put import for type annotation inside TYPE_CHECKING in llm/nodes.py
this causes unnecessary dependencies and cycle imports.
2025-07-30 16:13:04 +08:00
3751702efc feat(api): Implement boolean types support for list operator node 2025-07-30 10:58:18 +08:00
111c34e50f feat(api): Implement boolean types support for Code Node 2025-07-30 10:56:51 +08:00
a2246be4f5 feat(api): Implement checkbox input support in Workflow and Chat App 2025-07-30 10:56:51 +08:00
80e08562be feat(api): Initial support for boolean / array[boolean] types 2025-07-30 10:56:51 +08:00
6d03a15e0f chore: start form boolean to checkbox 2025-07-30 10:53:12 +08:00
bd04ddd544 merge main 2025-07-30 10:26:31 +08:00
cb51360aa3 chore: remove arr bool 2025-07-29 11:34:16 +08:00
269d34304e fix : loop var default value 2025-07-28 16:17:52 +08:00
590c14c977 chore: conveersion pass boolean string boolean 2025-07-28 16:02:50 +08:00
5fbeb275b2 fix: loop not choose bool var value 2025-07-28 15:20:18 +08:00
ff9d051635 chore: fix if not value set value 2025-07-28 14:49:44 +08:00
1e18c828d9 chore: remove transform boolean to string in editor 2025-07-28 14:09:34 +08:00
5b895be412 feat: template support bools input 2025-07-28 13:46:55 +08:00
ea2a1b203c fix: list operator bool not show the bool input 2025-07-28 11:20:22 +08:00
7bfdfe73c2 chore: list filter bool value 2025-07-28 11:09:37 +08:00
171c5055f8 fix: loop arry bool inner item type value 2025-07-25 14:52:13 +08:00
e3e4369358 fix: webapp get boolean value 2025-07-25 14:05:14 +08:00
efa28453be fix: if value show 2025-07-24 17:54:22 +08:00
37de3b1b68 fix: single run bool type 2025-07-24 17:39:10 +08:00
75d4e519fc chore: loop false can run 2025-07-24 16:06:13 +08:00
cf63926e16 chore: assigner node 2025-07-24 15:59:44 +08:00
14eb5b7930 feat: list operator support bool 2025-07-24 15:15:35 +08:00
544ebde054 chore: if bool value problem 2025-07-24 15:05:09 +08:00
6012ad57ac feat: code support boolean 2025-07-24 14:35:41 +08:00
70f6d8b42a main 2025-07-24 14:10:29 +08:00
1d738f3fa6 feat: chat support bool 2025-07-09 16:48:17 +08:00
912d68a148 chore: competion webapp 2025-07-09 14:56:11 +08:00
8c8c250570 feat: completion support boolean 2025-07-09 14:28:36 +08:00
b5782fff8f feat: basic chat app support bool 2025-07-09 14:06:25 +08:00
f2dfb5363f feat: one step run support boolean 2025-07-09 11:23:21 +08:00
fc0dea5647 feat: workflow debug add bool 2025-07-09 10:53:52 +08:00
fcf0ae4c85 feat: extract parameter support bool 2025-07-08 16:41:58 +08:00
9ac629bcf5 feat: if node support bool 2025-07-08 16:30:54 +08:00
79bdb22f79 feat: loop termin support bool 2025-07-08 16:06:56 +08:00
b72d977871 feat: loop var support bool type 2025-07-08 15:23:24 +08:00
657d3d1dae fix: boolean problem 2025-07-08 14:34:18 +08:00
3ba23238e5 chore: boolarray in varibale 2025-07-04 17:47:44 +08:00
e1a7b59160 chore: missing files 2025-07-04 16:48:56 +08:00
18941beb34 feat: converstation support bool 2025-07-04 16:48:33 +08:00
6ff18d13e6 feat: struct output boolean 2025-07-04 16:12:47 +08:00
1094b3da23 feat: add var type 2025-07-04 16:02:28 +08:00
77aa35ff15 feat: add boolean type 2025-07-04 15:27:52 +08:00
174 changed files with 4005 additions and 1430 deletions

View File

@ -97,16 +97,8 @@ uv run celery -A app.celery beat
uv sync --dev
```
1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`, more can check [Claude.md](../CLAUDE.md)
1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
```cli
uv run --project api pytest # Run all tests
uv run --project api pytest tests/unit_tests/ # Unit tests only
uv run --project api pytest tests/integration_tests/ # Integration tests
# Code quality
./dev/reformat # Run all formatters and linters
uv run --project api ruff check --fix ./ # Fix linting issues
uv run --project api ruff format ./ # Format code
uv run --project api mypy . # Type checking
```bash
uv run -P api bash dev/pytest/pytest_all_tests.sh
```

View File

@ -5,8 +5,6 @@ from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from dify_app import DifyApp
logger = logging.getLogger(__name__)
# ----------------------------
# Application Factory Function
@ -34,7 +32,7 @@ def create_app() -> DifyApp:
initialize_extensions(app)
end_time = time.perf_counter()
if dify_config.DEBUG:
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
logging.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
return app
@ -95,14 +93,14 @@ def initialize_extensions(app: DifyApp):
is_enabled = ext.is_enabled() if hasattr(ext, "is_enabled") else True
if not is_enabled:
if dify_config.DEBUG:
logger.info("Skipped %s", short_name)
logging.info("Skipped %s", short_name)
continue
start_time = time.perf_counter()
ext.init_app(app)
end_time = time.perf_counter()
if dify_config.DEBUG:
logger.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))
logging.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))
def create_migrations_app():

11
api/child_class.py Normal file
View File

@ -0,0 +1,11 @@
from tests.integration_tests.utils.parent_class import ParentClass
class ChildClass(ParentClass):
"""Test child class for module import helper tests"""
def __init__(self, name):
super().__init__(name)
def get_name(self):
return f"Child: {self.name}"

View File

@ -221,7 +221,7 @@ class EmailCodeLoginApi(Resource):
email=user_email, name=user_email, interface_language=languages[0]
)
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
return NotAllowedCreateWorkspace()
except AccountRegisterError as are:
raise AccountInFreezeError()
except WorkspacesLimitExceededError:

View File

@ -107,22 +107,6 @@ class PluginIconApi(Resource):
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
class PluginAssetApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
req = reqparse.RequestParser()
req.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
req.add_argument("file_name", type=str, required=True, location="args")
args = req.parse_args()
tenant_id = current_user.current_tenant_id
try:
binary = PluginService.extract_asset(tenant_id, args["plugin_unique_identifier"], args["file_name"])
return send_file(io.BytesIO(binary), mimetype="application/octet-stream")
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginUploadFromPkgApi(Resource):
@setup_required
@ -659,34 +643,11 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
class PluginReadmeApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
parser.add_argument("language", type=str, required=False, location="args")
args = parser.parse_args()
return jsonable_encoder(
{
"readme": PluginService.fetch_plugin_readme(
tenant_id,
args["plugin_unique_identifier"],
args.get("language", "en-US")
)
}
)
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginReadmeApi, "/workspaces/current/plugin/readme")
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
api.add_resource(PluginAssetApi, "/workspaces/current/plugin/asset")
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github")
api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle")

View File

@ -13,7 +13,7 @@ api = ExternalApi(
doc="/docs", # Enable Swagger UI at /files/docs
)
files_ns = Namespace("files", description="File operations", path="/")
files_ns = Namespace("files", description="File operations")
from . import image_preview, tool_files, upload

View File

@ -13,7 +13,7 @@ api = ExternalApi(
doc="/docs", # Enable Swagger UI at /mcp/docs
)
mcp_ns = Namespace("mcp", description="MCP operations", path="/")
mcp_ns = Namespace("mcp", description="MCP operations")
from . import mcp

View File

@ -13,7 +13,7 @@ api = ExternalApi(
doc="/docs", # Enable Swagger UI at /v1/docs
)
service_api_ns = Namespace("service_api", description="Service operations", path="/")
service_api_ns = Namespace("service_api", description="Service operations")
from . import index
from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow

View File

@ -3,6 +3,17 @@ import re
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
from core.external_data_tool.factory import ExternalDataToolFactory
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
[
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.EXTERNAL_DATA_TOOL,
VariableEntityType.CHECKBOX,
]
)
class BasicVariablesConfigManager:
@classmethod
@ -47,6 +58,7 @@ class BasicVariablesConfigManager:
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.SELECT,
VariableEntityType.CHECKBOX,
}:
variable = variables[variable_type]
variable_entities.append(
@ -96,8 +108,17 @@ class BasicVariablesConfigManager:
variables = []
for item in config["user_input_form"]:
key = list(item.keys())[0]
if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
# if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
if key not in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.EXTERNAL_DATA_TOOL,
VariableEntityType.CHECKBOX,
}:
allowed_keys = ", ".join(i.value for i in _ALLOWED_VARIABLE_ENTITY_TYPE)
raise ValueError(f"Keys in user_input_form list can only be {allowed_keys}")
form_item = item[key]
if "label" not in form_item:

View File

@ -97,6 +97,7 @@ class VariableEntityType(StrEnum):
EXTERNAL_DATA_TOOL = "external_data_tool"
FILE = "file"
FILE_LIST = "file-list"
CHECKBOX = "checkbox"
class VariableEntity(BaseModel):

View File

@ -103,18 +103,23 @@ class BaseAppGenerator:
f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string"
)
if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str):
# handle empty string case
if not value.strip():
return None
# may raise ValueError if user_input_value is not a valid number
try:
if "." in value:
return float(value)
else:
return int(value)
except ValueError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
if variable_entity.type == VariableEntityType.NUMBER:
if isinstance(value, (int, float)):
return value
elif isinstance(value, str):
# handle empty string case
if not value.strip():
return None
# may raise ValueError if user_input_value is not a valid number
try:
if "." in value:
return float(value)
else:
return int(value)
except ValueError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
else:
raise TypeError(f"expected value type int, float or str, got {type(value)}, value: {value}")
match variable_entity.type:
case VariableEntityType.SELECT:
@ -144,6 +149,11 @@ class BaseAppGenerator:
raise ValueError(
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
)
case VariableEntityType.CHECKBOX:
if not isinstance(value, bool):
raise ValueError(f"{variable_entity.variable} in input form must be a valid boolean value")
case _:
raise AssertionError("this statement should be unreachable.")
return value

View File

@ -196,7 +196,3 @@ class PluginListResponse(BaseModel):
class PluginDynamicSelectOptionsResponse(BaseModel):
options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")
class PluginReadmeResponse(BaseModel):
content: str = Field(description="The readme of the plugin.")
language: str = Field(description="The language of the readme.")

View File

@ -10,9 +10,3 @@ class PluginAssetManager(BasePluginClient):
if response.status_code != 200:
raise ValueError(f"can not found asset {id}")
return response.content
def extract_asset(self, tenant_id: str, plugin_unique_identifier: str, filename: str) -> bytes:
response = self._request(method="GET", path=f"plugin/{tenant_id}/asset/{plugin_unique_identifier}")
if response.status_code != 200:
raise ValueError(f"can not found asset {plugin_unique_identifier}, {str(response.status_code)}")
return response.content

View File

@ -1,7 +1,5 @@
from collections.abc import Sequence
from requests import HTTPError
from core.plugin.entities.bundle import PluginBundleDependency
from core.plugin.entities.plugin import (
GenericProviderID,
@ -16,34 +14,11 @@ from core.plugin.entities.plugin_daemon import (
PluginInstallTask,
PluginInstallTaskStartResponse,
PluginListResponse,
PluginReadmeResponse,
)
from core.plugin.impl.base import BasePluginClient
class PluginInstaller(BasePluginClient):
def fetch_plugin_readme(self, tenant_id: str, plugin_unique_identifier: str, language: str) -> str:
"""
Fetch plugin readme
"""
try:
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/fetch/readme",
PluginReadmeResponse,
params={
"tenant_id":tenant_id,
"plugin_unique_identifier": plugin_unique_identifier,
"language": language
}
)
return response.content
except HTTPError as e:
message = e.args[0]
if "404" in message:
return ""
raise e
def fetch_plugin_by_identifier(
self,
tenant_id: str,

View File

@ -151,6 +151,11 @@ class FileSegment(Segment):
return ""
class BooleanSegment(Segment):
value_type: SegmentType = SegmentType.BOOLEAN
value: bool
class ArrayAnySegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_ANY
value: Sequence[Any]
@ -198,6 +203,11 @@ class ArrayFileSegment(ArraySegment):
return ""
class ArrayBooleanSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
value: Sequence[bool]
def get_segment_discriminator(v: Any) -> SegmentType | None:
if isinstance(v, Segment):
return v.value_type
@ -231,11 +241,13 @@ SegmentUnion: TypeAlias = Annotated[
| Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
| Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
| Annotated[FileSegment, Tag(SegmentType.FILE)]
| Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)]
| Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
),
Discriminator(get_segment_discriminator),
]

View File

@ -6,7 +6,12 @@ from core.file.models import File
class ArrayValidation(StrEnum):
"""Strategy for validating array elements"""
"""Strategy for validating array elements.
Note:
The `NONE` and `FIRST` strategies are primarily for compatibility purposes.
Avoid using them in new code whenever possible.
"""
# Skip element validation (only check array container)
NONE = "none"
@ -27,12 +32,14 @@ class SegmentType(StrEnum):
SECRET = "secret"
FILE = "file"
BOOLEAN = "boolean"
ARRAY_ANY = "array[any]"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILE = "array[file]"
ARRAY_BOOLEAN = "array[boolean]"
NONE = "none"
@ -76,12 +83,18 @@ class SegmentType(StrEnum):
return SegmentType.ARRAY_FILE
case SegmentType.NONE:
return SegmentType.ARRAY_ANY
case SegmentType.BOOLEAN:
return SegmentType.ARRAY_BOOLEAN
case _:
# This should be unreachable.
raise ValueError(f"not supported value {value}")
if value is None:
return SegmentType.NONE
elif isinstance(value, int) and not isinstance(value, bool):
# Important: The check for `bool` must precede the check for `int`,
# as `bool` is a subclass of `int` in Python's type hierarchy.
elif isinstance(value, bool):
return SegmentType.BOOLEAN
elif isinstance(value, int):
return SegmentType.INTEGER
elif isinstance(value, float):
return SegmentType.FLOAT
@ -111,7 +124,7 @@ class SegmentType(StrEnum):
else:
return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value)
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.ALL) -> bool:
"""
Check if a value matches the segment type.
Users of `SegmentType` should call this method, instead of using
@ -126,6 +139,10 @@ class SegmentType(StrEnum):
"""
if self.is_array_type():
return self._validate_array(value, array_validation)
# Important: The check for `bool` must precede the check for `int`,
# as `bool` is a subclass of `int` in Python's type hierarchy.
elif self == SegmentType.BOOLEAN:
return isinstance(value, bool)
elif self in [SegmentType.INTEGER, SegmentType.FLOAT, SegmentType.NUMBER]:
return isinstance(value, (int, float))
elif self == SegmentType.STRING:
@ -141,6 +158,27 @@ class SegmentType(StrEnum):
else:
raise AssertionError("this statement should be unreachable.")
@staticmethod
def cast_value(value: Any, type_: "SegmentType") -> Any:
# Cast Python's `bool` type to `int` when the runtime type requires
# an integer or number.
#
# This ensures compatibility with existing workflows that may use `bool` as
# `int`, since in Python's type system, `bool` is a subtype of `int`.
#
# This function exists solely to maintain compatibility with existing workflows.
# It should not be used to compromise the integrity of the runtime type system.
# No additional casting rules should be introduced to this function.
if type_ in (
SegmentType.INTEGER,
SegmentType.NUMBER,
) and isinstance(value, bool):
return int(value)
if type_ == SegmentType.ARRAY_NUMBER and all(isinstance(i, bool) for i in value):
return [int(i) for i in value]
return value
def exposed_type(self) -> "SegmentType":
"""Returns the type exposed to the frontend.
@ -150,6 +188,20 @@ class SegmentType(StrEnum):
return SegmentType.NUMBER
return self
def element_type(self) -> "SegmentType | None":
"""Return the element type of the current segment type, or `None` if the element type is undefined.
Raises:
ValueError: If the current segment type is not an array type.
Note:
For certain array types, such as `SegmentType.ARRAY_ANY`, their element types are not defined
by the runtime system. In such cases, this method will return `None`.
"""
if not self.is_array_type():
raise ValueError(f"element_type is only supported by array type, got {self}")
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
# ARRAY_ANY does not have corresponding element type.
@ -157,6 +209,7 @@ _ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
SegmentType.ARRAY_FILE: SegmentType.FILE,
SegmentType.ARRAY_BOOLEAN: SegmentType.BOOLEAN,
}
_ARRAY_TYPES = frozenset(

View File

@ -8,11 +8,13 @@ from core.helper import encrypter
from .segments import (
ArrayAnySegment,
ArrayBooleanSegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArraySegment,
ArrayStringSegment,
BooleanSegment,
FileSegment,
FloatSegment,
IntegerSegment,
@ -96,10 +98,18 @@ class FileVariable(FileSegment, Variable):
pass
class BooleanVariable(BooleanSegment, Variable):
pass
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass
class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
pass
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
# Use `Variable` for type hinting when serialization is not required.
#
@ -114,11 +124,13 @@ VariableUnion: TypeAlias = Annotated[
| Annotated[IntegerVariable, Tag(SegmentType.INTEGER)]
| Annotated[ObjectVariable, Tag(SegmentType.OBJECT)]
| Annotated[FileVariable, Tag(SegmentType.FILE)]
| Annotated[BooleanVariable, Tag(SegmentType.BOOLEAN)]
| Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
),
Discriminator(get_segment_discriminator),

View File

@ -8,6 +8,7 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@ -119,6 +120,14 @@ class CodeNode(BaseNode):
return value.replace("\x00", "")
def _check_boolean(self, value: bool | None, variable: str) -> bool | None:
if value is None:
return None
if not isinstance(value, bool):
raise OutputValidationError(f"Output variable `{variable}` must be a boolean")
return value
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
"""
Check number
@ -173,6 +182,8 @@ class CodeNode(BaseNode):
prefix=f"{prefix}.{output_name}" if prefix else output_name,
depth=depth + 1,
)
elif isinstance(output_value, bool):
self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name)
elif isinstance(output_value, int | float):
self._check_number(
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
@ -232,7 +243,7 @@ class CodeNode(BaseNode):
if output_name not in result:
raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
if output_config.type == "object":
if output_config.type == SegmentType.OBJECT:
# check if output is object
if not isinstance(result.get(output_name), dict):
if result[output_name] is None:
@ -249,18 +260,28 @@ class CodeNode(BaseNode):
prefix=f"{prefix}.{output_name}",
depth=depth + 1,
)
elif output_config.type == "number":
elif output_config.type == SegmentType.NUMBER:
# check if number available
transformed_result[output_name] = self._check_number(
value=result[output_name], variable=f"{prefix}{dot}{output_name}"
)
elif output_config.type == "string":
checked = self._check_number(value=result[output_name], variable=f"{prefix}{dot}{output_name}")
# If the output is a boolean and the output schema specifies a NUMBER type,
# convert the boolean value to an integer.
#
# This ensures compatibility with existing workflows that may use
# `True` and `False` as values for NUMBER type outputs.
transformed_result[output_name] = self._convert_boolean_to_int(checked)
elif output_config.type == SegmentType.STRING:
# check if string available
transformed_result[output_name] = self._check_string(
value=result[output_name],
variable=f"{prefix}{dot}{output_name}",
)
elif output_config.type == "array[number]":
elif output_config.type == SegmentType.BOOLEAN:
transformed_result[output_name] = self._check_boolean(
value=result[output_name],
variable=f"{prefix}{dot}{output_name}",
)
elif output_config.type == SegmentType.ARRAY_NUMBER:
# check if array of number available
if not isinstance(result[output_name], list):
if result[output_name] is None:
@ -278,10 +299,17 @@ class CodeNode(BaseNode):
)
transformed_result[output_name] = [
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
# If the element is a boolean and the output schema specifies a `array[number]` type,
# convert the boolean value to an integer.
#
# This ensures compatibility with existing workflows that may use
# `True` and `False` as values for NUMBER type outputs.
self._convert_boolean_to_int(
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]"),
)
for i, value in enumerate(result[output_name])
]
elif output_config.type == "array[string]":
elif output_config.type == SegmentType.ARRAY_STRING:
# check if array of string available
if not isinstance(result[output_name], list):
if result[output_name] is None:
@ -302,7 +330,7 @@ class CodeNode(BaseNode):
self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
for i, value in enumerate(result[output_name])
]
elif output_config.type == "array[object]":
elif output_config.type == SegmentType.ARRAY_OBJECT:
# check if array of object available
if not isinstance(result[output_name], list):
if result[output_name] is None:
@ -340,6 +368,22 @@ class CodeNode(BaseNode):
)
for i, value in enumerate(result[output_name])
]
elif output_config.type == SegmentType.ARRAY_BOOLEAN:
# check if array of object available
if not isinstance(result[output_name], list):
if result[output_name] is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
)
else:
transformed_result[output_name] = [
self._check_boolean(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
for i, value in enumerate(result[output_name])
]
else:
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
@ -374,3 +418,16 @@ class CodeNode(BaseNode):
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
@staticmethod
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:
"""This function convert boolean to integers when the output schema specifies a NUMBER type.
This ensures compatibility with existing workflows that may use
`True` and `False` as values for NUMBER type outputs.
"""
if value is None:
return None
if isinstance(value, bool):
return int(value)
return value

View File

@ -1,11 +1,31 @@
from typing import Literal, Optional
from typing import Annotated, Literal, Optional
from pydantic import BaseModel
from pydantic import AfterValidator, BaseModel
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.types import SegmentType
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
[
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.OBJECT,
SegmentType.BOOLEAN,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_BOOLEAN,
]
)
def _validate_type(segment_type: SegmentType) -> SegmentType:
if segment_type not in _ALLOWED_OUTPUT_FROM_CODE:
raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}")
return segment_type
class CodeNodeData(BaseNodeData):
"""
@ -13,7 +33,7 @@ class CodeNodeData(BaseNodeData):
"""
class Output(BaseModel):
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
type: Annotated[SegmentType, AfterValidator(_validate_type)]
children: Optional[dict[str, "CodeNodeData.Output"]] = None
class Dependency(BaseModel):

View File

@ -1,36 +1,43 @@
from collections.abc import Sequence
from typing import Literal
from enum import StrEnum
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
_Condition = Literal[
class FilterOperator(StrEnum):
# string conditions
"contains",
"start with",
"end with",
"is",
"in",
"empty",
"not contains",
"is not",
"not in",
"not empty",
CONTAINS = "contains"
START_WITH = "start with"
END_WITH = "end with"
IS = "is"
IN = "in"
EMPTY = "empty"
NOT_CONTAINS = "not contains"
IS_NOT = "is not"
NOT_IN = "not in"
NOT_EMPTY = "not empty"
# number conditions
"=",
"",
"<",
">",
"",
"",
]
EQUAL = "="
NOT_EQUAL = ""
LESS_THAN = "<"
GREATER_THAN = ">"
GREATER_THAN_OR_EQUAL = ""
LESS_THAN_OR_EQUAL = ""
class Order(StrEnum):
ASC = "asc"
DESC = "desc"
class FilterCondition(BaseModel):
key: str = ""
comparison_operator: _Condition = "contains"
value: str | Sequence[str] = ""
comparison_operator: FilterOperator = FilterOperator.CONTAINS
# the value is bool if the filter operator is comparing with
# a boolean constant.
value: str | Sequence[str] | bool = ""
class FilterBy(BaseModel):
@ -38,10 +45,10 @@ class FilterBy(BaseModel):
conditions: Sequence[FilterCondition] = Field(default_factory=list)
class OrderBy(BaseModel):
class OrderByConfig(BaseModel):
enabled: bool = False
key: str = ""
value: Literal["asc", "desc"] = "asc"
value: Order = Order.ASC
class Limit(BaseModel):
@ -57,6 +64,6 @@ class ExtractConfig(BaseModel):
class ListOperatorNodeData(BaseNodeData):
variable: Sequence[str] = Field(default_factory=list)
filter_by: FilterBy
order_by: OrderBy
order_by: OrderByConfig
limit: Limit
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)

View File

@ -1,18 +1,40 @@
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal, Optional, Union
from typing import Any, Optional, TypeAlias, TypeVar
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .entities import ListOperatorNodeData
from .entities import FilterOperator, ListOperatorNodeData, Order
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
_SUPPORTED_TYPES_TUPLE = (
ArrayFileSegment,
ArrayNumberSegment,
ArrayStringSegment,
ArrayBooleanSegment,
)
_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment
_T = TypeVar("_T")
def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
"""Returns the negation of a given filter function. If the original filter
returns `True` for a value, the negated filter will return `False`, and vice versa.
"""
def wrapper(value: _T) -> bool:
return not filter_(value)
return wrapper
class ListOperatorNode(BaseNode):
_node_type = NodeType.LIST_OPERATOR
@ -69,11 +91,8 @@ class ListOperatorNode(BaseNode):
process_data=process_data,
outputs=outputs,
)
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = (
f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment"
)
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
@ -122,9 +141,7 @@ class ListOperatorNode(BaseNode):
outputs=outputs,
)
def _apply_filter(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
filter_func: Callable[[Any], bool]
result: list[Any] = []
for condition in self._node_data.filter_by.conditions:
@ -154,33 +171,35 @@ class ListOperatorNode(BaseNode):
)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayBooleanSegment):
if not isinstance(condition.value, bool):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
else:
raise AssertionError("this statment should be unreachable.")
return variable
def _apply_order(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self._node_data.order_by.value, array=variable.value)
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
result = sorted(variable.value, reverse=self._node_data.order_by == Order.DESC)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
else:
raise AssertionError("this statement should be unreachable")
return variable
def _apply_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
result = variable.value[: self._node_data.limit.size]
return variable.model_copy(update={"value": result})
def _extract_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
@ -232,11 +251,11 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
case "empty":
return lambda x: x == ""
case "not contains":
return lambda x: not _contains(value)(x)
return _negation(_contains(value))
case "is not":
return lambda x: not _is(value)(x)
return _negation(_is(value))
case "not in":
return lambda x: not _in(value)(x)
return _negation(_in(value))
case "not empty":
return lambda x: x != ""
case _:
@ -248,7 +267,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
case "in":
return _in(value)
case "not in":
return lambda x: not _in(value)(x)
return _negation(_in(value))
case _:
raise InvalidConditionError(f"Invalid condition: {condition}")
@ -271,6 +290,16 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]:
match condition:
case FilterOperator.IS:
return _is(value)
case FilterOperator.IS_NOT:
return _negation(_is(value))
case _:
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
extract_func: Callable[[File], Any]
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
@ -298,7 +327,7 @@ def _endswith(value: str) -> Callable[[str], bool]:
return lambda x: x.endswith(value)
def _is(value: str) -> Callable[[str], bool]:
def _is(value: _T) -> Callable[[_T], bool]:
return lambda x: x == value
@ -330,21 +359,13 @@ def _ge(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x >= value
def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]):
return sorted(array, key=lambda x: x, reverse=order == "desc")
def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
return sorted(array, key=lambda x: x, reverse=order == "desc")
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]):
extract_func: Callable[[File], Any]
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
extract_func = _get_file_extract_string_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
elif order_by == "size":
extract_func = _get_file_extract_number_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
else:
raise InvalidKeyError(f"Invalid order key: {order_by}")

View File

@ -3,7 +3,7 @@ import io
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
@ -55,7 +55,6 @@ from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
@ -90,6 +89,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.event import InNodeEvent
logger = logging.getLogger(__name__)
@ -161,7 +161,7 @@ class LLMNode(BaseNode):
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def _run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
node_inputs: Optional[dict[str, Any]] = None
process_data = None
result_text = ""

View File

@ -12,9 +12,11 @@ _VALID_VAR_TYPE = frozenset(
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.OBJECT,
SegmentType.BOOLEAN,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_BOOLEAN,
]
)

View File

@ -404,11 +404,11 @@ class LoopNode(BaseNode):
for node_id in loop_graph.node_ids:
variable_pool.remove([node_id])
_outputs = {}
_outputs: dict[str, Segment | int | None] = {}
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
_loop_variable_segment = variable_pool.get(loop_variable_selector)
if _loop_variable_segment:
_outputs[loop_variable_key] = _loop_variable_segment.value
_outputs[loop_variable_key] = _loop_variable_segment
else:
_outputs[loop_variable_key] = None
@ -522,21 +522,30 @@ class LoopNode(BaseNode):
return variable_mapping
@staticmethod
def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""
if var_type in ["array[string]", "array[number]", "array[object]"]:
if value and isinstance(value, str):
value = json.loads(value)
if var_type in [
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
]:
if original_value and isinstance(original_value, str):
value = json.loads(original_value)
else:
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
value = []
elif var_type == SegmentType.ARRAY_BOOLEAN:
value = original_value
else:
raise AssertionError("this statement should be unreachable.")
try:
return build_segment_with_type(var_type, value)
return build_segment_with_type(var_type, value=value)
except TypeMismatchError as type_exc:
# Attempt to parse the value as a JSON-encoded string, if applicable.
if not isinstance(value, str):
if not isinstance(original_value, str):
raise
try:
value = json.loads(value)
value = json.loads(original_value)
except ValueError:
raise type_exc
return build_segment_with_type(var_type, value)

View File

@ -1,10 +1,46 @@
from typing import Any, Literal, Optional
from typing import Annotated, Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
from pydantic import (
BaseModel,
BeforeValidator,
Field,
field_validator,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
_OLD_BOOL_TYPE_NAME = "bool"
_OLD_SELECT_TYPE_NAME = "select"
_VALID_PARAMETER_TYPES = frozenset(
[
SegmentType.STRING, # "string",
SegmentType.NUMBER, # "number",
SegmentType.BOOLEAN,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_BOOLEAN,
_OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node
_OLD_SELECT_TYPE_NAME, # string type with enumeration choices.
]
)
def _validate_type(parameter_type: str) -> SegmentType:
if not isinstance(parameter_type, str):
raise TypeError(f"type should be str, got {type(parameter_type)}, value={parameter_type}")
if parameter_type not in _VALID_PARAMETER_TYPES:
raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.")
if parameter_type == _OLD_BOOL_TYPE_NAME:
return SegmentType.BOOLEAN
elif parameter_type == _OLD_SELECT_TYPE_NAME:
return SegmentType.STRING
return SegmentType(parameter_type)
class _ParameterConfigError(Exception):
@ -17,7 +53,7 @@ class ParameterConfig(BaseModel):
"""
name: str
type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"]
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
options: Optional[list[str]] = None
description: str
required: bool
@ -32,17 +68,20 @@ class ParameterConfig(BaseModel):
return str(value)
def is_array_type(self) -> bool:
return self.type in ("array[string]", "array[number]", "array[object]")
return self.type.is_array_type()
def element_type(self) -> Literal["string", "number", "object"]:
if self.type == "array[number]":
return "number"
elif self.type == "array[string]":
return "string"
elif self.type == "array[object]":
return "object"
else:
raise _ParameterConfigError(f"{self.type} is not array type.")
def element_type(self) -> SegmentType:
"""Return the element type of the parameter.
Raises a ValueError if the parameter's type is not an array type.
"""
element_type = self.type.element_type()
# At this point, self.type is guaranteed to be one of `ARRAY_STRING`,
# `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`.
#
# See: _VALID_PARAMETER_TYPES for reference.
assert element_type is not None, f"the element type should not be None, {self.type=}"
return element_type
class ParameterExtractorNodeData(BaseNodeData):
@ -74,16 +113,18 @@ class ParameterExtractorNodeData(BaseNodeData):
for parameter in self.parameters:
parameter_schema: dict[str, Any] = {"description": parameter.description}
if parameter.type in {"string", "select"}:
if parameter.type == SegmentType.STRING:
parameter_schema["type"] = "string"
elif parameter.type.startswith("array"):
elif parameter.type.is_array_type():
parameter_schema["type"] = "array"
nested_type = parameter.type[6:-1]
parameter_schema["items"] = {"type": nested_type}
element_type = parameter.type.element_type()
if element_type is None:
raise AssertionError("element type should not be None.")
parameter_schema["items"] = {"type": element_type.value}
else:
parameter_schema["type"] = parameter.type
if parameter.type == "select":
if parameter.options:
parameter_schema["enum"] = parameter.options
parameters["properties"][parameter.name] = parameter_schema

View File

@ -1,3 +1,8 @@
from typing import Any
from core.variables.types import SegmentType
class ParameterExtractorNodeError(ValueError):
"""Base error for ParameterExtractorNode."""
@ -48,3 +53,23 @@ class InvalidArrayValueError(ParameterExtractorNodeError):
class InvalidModelModeError(ParameterExtractorNodeError):
"""Raised when the model mode is invalid."""
class InvalidValueTypeError(ParameterExtractorNodeError):
def __init__(
self,
/,
parameter_name: str,
expected_type: SegmentType,
actual_type: SegmentType | None,
value: Any,
) -> None:
message = (
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"
)
super().__init__(message)
self.parameter_name = parameter_name
self.expected_type = expected_type
self.actual_type = actual_type
self.value = value

View File

@ -26,7 +26,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import SegmentType
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -39,16 +39,13 @@ from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData
from .exc import (
InvalidArrayValueError,
InvalidBoolValueError,
InvalidInvokeResultError,
InvalidModelModeError,
InvalidModelTypeError,
InvalidNumberOfParametersError,
InvalidNumberValueError,
InvalidSelectValueError,
InvalidStringValueError,
InvalidTextContentTypeError,
InvalidValueTypeError,
ModelSchemaNotFoundError,
ParameterExtractorNodeError,
RequiredParameterMissingError,
@ -549,9 +546,6 @@ class ParameterExtractorNode(BaseNode):
return prompt_messages
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
"""
Validate result.
"""
if len(data.parameters) != len(result):
raise InvalidNumberOfParametersError("Invalid number of parameters")
@ -559,101 +553,106 @@ class ParameterExtractorNode(BaseNode):
if parameter.required and parameter.name not in result:
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
if parameter.type.startswith("array"):
parameters = result.get(parameter.name)
if not isinstance(parameters, list):
raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
nested_type = parameter.type[6:-1]
for item in parameters:
if nested_type == "number" and not isinstance(item, int | float):
raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
if nested_type == "string" and not isinstance(item, str):
raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
if nested_type == "object" and not isinstance(item, dict):
raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
param_value = result.get(parameter.name)
if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL):
inferred_type = SegmentType.infer_segment_type(param_value)
raise InvalidValueTypeError(
parameter_name=parameter.name,
expected_type=parameter.type,
actual_type=inferred_type,
value=param_value,
)
if parameter.type == SegmentType.STRING and parameter.options:
if param_value not in parameter.options:
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
return result
@staticmethod
def _transform_number(value: int | float | str | bool) -> int | float | None:
"""
Attempts to transform the input into an integer or float.
Returns:
int or float: The transformed number if the conversion is successful.
None: If the transformation fails.
Note:
Boolean values `True` and `False` are converted to integers `1` and `0`, respectively.
This behavior ensures compatibility with existing workflows that may use boolean types as integers.
"""
if isinstance(value, bool):
return int(value)
elif isinstance(value, (int, float)):
return value
elif not isinstance(value, str):
return None
if "." in value:
try:
return float(value)
except ValueError:
return None
else:
try:
return int(value)
except ValueError:
return None
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
"""
Transform result into standard format.
"""
transformed_result = {}
transformed_result: dict[str, Any] = {}
for parameter in data.parameters:
if parameter.name in result:
param_value = result[parameter.name]
# transform value
if parameter.type == "number":
if isinstance(result[parameter.name], int | float):
transformed_result[parameter.name] = result[parameter.name]
elif isinstance(result[parameter.name], str):
try:
if "." in result[parameter.name]:
result[parameter.name] = float(result[parameter.name])
else:
result[parameter.name] = int(result[parameter.name])
except ValueError:
pass
else:
pass
# TODO: bool is not supported in the current version
# elif parameter.type == 'bool':
# if isinstance(result[parameter.name], bool):
# transformed_result[parameter.name] = bool(result[parameter.name])
# elif isinstance(result[parameter.name], str):
# if result[parameter.name].lower() in ['true', 'false']:
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true')
# elif isinstance(result[parameter.name], int):
# transformed_result[parameter.name] = bool(result[parameter.name])
elif parameter.type in {"string", "select"}:
if isinstance(result[parameter.name], str):
transformed_result[parameter.name] = result[parameter.name]
if parameter.type == SegmentType.NUMBER:
transformed = self._transform_number(param_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
elif parameter.type == SegmentType.BOOLEAN:
if isinstance(result[parameter.name], (bool, int)):
transformed_result[parameter.name] = bool(result[parameter.name])
# elif isinstance(result[parameter.name], str):
# if result[parameter.name].lower() in ["true", "false"]:
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true")
elif parameter.type == SegmentType.STRING:
if isinstance(param_value, str):
transformed_result[parameter.name] = param_value
elif parameter.is_array_type():
if isinstance(result[parameter.name], list):
if isinstance(param_value, list):
nested_type = parameter.element_type()
assert nested_type is not None
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
transformed_result[parameter.name] = segment_value
for item in result[parameter.name]:
if nested_type == "number":
if isinstance(item, int | float):
segment_value.value.append(item)
elif isinstance(item, str):
try:
if "." in item:
segment_value.value.append(float(item))
else:
segment_value.value.append(int(item))
except ValueError:
pass
elif nested_type == "string":
for item in param_value:
if nested_type == SegmentType.NUMBER:
transformed = self._transform_number(item)
if transformed is not None:
segment_value.value.append(transformed)
elif nested_type == SegmentType.STRING:
if isinstance(item, str):
segment_value.value.append(item)
elif nested_type == "object":
elif nested_type == SegmentType.OBJECT:
if isinstance(item, dict):
segment_value.value.append(item)
elif nested_type == SegmentType.BOOLEAN:
if isinstance(item, bool):
segment_value.value.append(item)
if parameter.name not in transformed_result:
if parameter.type == "number":
transformed_result[parameter.name] = 0
elif parameter.type == "bool":
transformed_result[parameter.name] = False
elif parameter.type in {"string", "select"}:
transformed_result[parameter.name] = ""
elif parameter.type.startswith("array"):
if parameter.type.is_array_type():
transformed_result[parameter.name] = build_segment_with_type(
segment_type=SegmentType(parameter.type), value=[]
)
elif parameter.type in (SegmentType.STRING, SegmentType.SECRET):
transformed_result[parameter.name] = ""
elif parameter.type == SegmentType.NUMBER:
transformed_result[parameter.name] = 0
elif parameter.type == SegmentType.BOOLEAN:
transformed_result[parameter.name] = False
else:
raise AssertionError("this statement should be unreachable.")
return transformed_result

View File

@ -2,6 +2,7 @@ from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
from core.variables import SegmentType, Variable
from core.variables.segments import BooleanSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
@ -158,8 +159,8 @@ class VariableAssignerNode(BaseNode):
def get_zero_value(t: SegmentType):
# TODO(QuantumGhost): this should be a method of `SegmentType`.
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return variable_factory.build_segment([])
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN:
return variable_factory.build_segment_with_type(t, [])
case SegmentType.OBJECT:
return variable_factory.build_segment({})
case SegmentType.STRING:
@ -170,5 +171,7 @@ def get_zero_value(t: SegmentType):
return variable_factory.build_segment(0.0)
case SegmentType.NUMBER:
return variable_factory.build_segment(0)
case SegmentType.BOOLEAN:
return BooleanSegment(value=False)
case _:
raise VariableOperatorNodeError(f"unsupported variable type: {t}")

View File

@ -4,9 +4,11 @@ from core.variables import SegmentType
EMPTY_VALUE_MAPPING = {
SegmentType.STRING: "",
SegmentType.NUMBER: 0,
SegmentType.BOOLEAN: False,
SegmentType.OBJECT: {},
SegmentType.ARRAY_ANY: [],
SegmentType.ARRAY_STRING: [],
SegmentType.ARRAY_NUMBER: [],
SegmentType.ARRAY_OBJECT: [],
SegmentType.ARRAY_BOOLEAN: [],
}

View File

@ -16,28 +16,15 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
SegmentType.NUMBER,
SegmentType.INTEGER,
SegmentType.FLOAT,
SegmentType.BOOLEAN,
}
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
# Only number variable can be added, subtracted, multiplied or divided
return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
case Operation.APPEND | Operation.EXTEND:
case Operation.APPEND | Operation.EXTEND | Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
# Only array variable can be appended or extended
return variable_type in {
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_FILE,
}
case Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
# Only array variable can have elements removed
return variable_type in {
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_FILE,
}
return variable_type.is_array_type()
case _:
return False
@ -50,7 +37,7 @@ def is_variable_input_supported(*, operation: Operation):
def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation):
match variable_type:
case SegmentType.STRING | SegmentType.OBJECT:
case SegmentType.STRING | SegmentType.OBJECT | SegmentType.BOOLEAN:
return operation in {Operation.OVER_WRITE, Operation.SET}
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return operation in {
@ -72,6 +59,9 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
case SegmentType.STRING:
return isinstance(value, str)
case SegmentType.BOOLEAN:
return isinstance(value, bool)
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
if not isinstance(value, int | float):
return False
@ -91,6 +81,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
return isinstance(value, int | float)
case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND:
return isinstance(value, dict)
case SegmentType.ARRAY_BOOLEAN if operation == Operation.APPEND:
return isinstance(value, bool)
# Array & Extend / Overwrite
case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
@ -101,6 +93,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
return isinstance(value, list) and all(isinstance(item, int | float) for item in value)
case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, dict) for item in value)
case SegmentType.ARRAY_BOOLEAN if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, bool) for item in value)
case _:
return False

View File

@ -45,5 +45,5 @@ class SubVariableCondition(BaseModel):
class Condition(BaseModel):
variable_selector: list[str]
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None = None
value: str | Sequence[str] | bool | None = None
sub_variable_condition: SubVariableCondition | None = None

View File

@ -1,13 +1,27 @@
import json
from collections.abc import Sequence
from typing import Any, Literal
from typing import Any, Literal, Union
from core.file import FileAttribute, file_manager
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayBooleanSegment, BooleanSegment
from core.workflow.entities.variable_pool import VariablePool
from .entities import Condition, SubCondition, SupportedComparisonOperator
def _convert_to_bool(value: Any) -> bool:
if isinstance(value, int):
return bool(value)
if isinstance(value, str):
loaded = json.loads(value)
if isinstance(loaded, (int, bool)):
return bool(loaded)
raise TypeError(f"unexpected value: type={type(value)}, value={value}")
class ConditionProcessor:
def process_conditions(
self,
@ -48,9 +62,16 @@ class ConditionProcessor:
)
else:
actual_value = variable.value if variable else None
expected_value = condition.value
expected_value: str | Sequence[str] | bool | list[bool] | None = condition.value
if isinstance(expected_value, str):
expected_value = variable_pool.convert_template(expected_value).text
# Here we need to explicit convet the input string to boolean.
if isinstance(variable, (BooleanSegment, ArrayBooleanSegment)) and expected_value is not None:
# The following two lines is for compatibility with existing workflows.
if isinstance(expected_value, list):
expected_value = [_convert_to_bool(i) for i in expected_value]
else:
expected_value = _convert_to_bool(expected_value)
input_conditions.append(
{
"actual_value": actual_value,
@ -77,7 +98,7 @@ def _evaluate_condition(
*,
operator: SupportedComparisonOperator,
value: Any,
expected: str | Sequence[str] | None,
expected: Union[str, Sequence[str], bool | Sequence[bool], None],
) -> bool:
match operator:
case "contains":
@ -130,7 +151,7 @@ def _assert_contains(*, value: Any, expected: Any) -> bool:
if not value:
return False
if not isinstance(value, str | list):
if not isinstance(value, (str, list)):
raise ValueError("Invalid actual value type: string or array")
if expected not in value:
@ -142,7 +163,7 @@ def _assert_not_contains(*, value: Any, expected: Any) -> bool:
if not value:
return True
if not isinstance(value, str | list):
if not isinstance(value, (str, list)):
raise ValueError("Invalid actual value type: string or array")
if expected in value:
@ -178,8 +199,8 @@ def _assert_is(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, str):
raise ValueError("Invalid actual value type: string")
if not isinstance(value, (str, bool)):
raise ValueError("Invalid actual value type: string or boolean")
if value != expected:
return False
@ -190,8 +211,8 @@ def _assert_is_not(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, str):
raise ValueError("Invalid actual value type: string")
if not isinstance(value, (str, bool)):
raise ValueError("Invalid actual value type: string or boolean")
if value == expected:
return False
@ -214,10 +235,13 @@ def _assert_equal(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
raise ValueError("Invalid actual value type: number")
if not isinstance(value, (int, float, bool)):
raise ValueError("Invalid actual value type: number or boolean")
if isinstance(value, int):
# Handle boolean comparison
if isinstance(value, bool):
expected = bool(expected)
elif isinstance(value, int):
expected = int(expected)
else:
expected = float(expected)
@ -231,10 +255,13 @@ def _assert_not_equal(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
raise ValueError("Invalid actual value type: number")
if not isinstance(value, (int, float, bool)):
raise ValueError("Invalid actual value type: number or boolean")
if isinstance(value, int):
# Handle boolean comparison
if isinstance(value, bool):
expected = bool(expected)
elif isinstance(value, int):
expected = int(expected)
else:
expected = float(expected)
@ -248,7 +275,7 @@ def _assert_greater_than(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
if not isinstance(value, (int, float)):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
@ -265,7 +292,7 @@ def _assert_less_than(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
if not isinstance(value, (int, float)):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
@ -282,7 +309,7 @@ def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
if not isinstance(value, (int, float)):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
@ -299,7 +326,7 @@ def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
if not isinstance(value, (int, float)):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):

View File

@ -7,11 +7,13 @@ from core.file import File
from core.variables.exc import VariableError
from core.variables.segments import (
ArrayAnySegment,
ArrayBooleanSegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArraySegment,
ArrayStringSegment,
BooleanSegment,
FileSegment,
FloatSegment,
IntegerSegment,
@ -23,10 +25,12 @@ from core.variables.segments import (
from core.variables.types import SegmentType
from core.variables.variables import (
ArrayAnyVariable,
ArrayBooleanVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
BooleanVariable,
FileVariable,
FloatVariable,
IntegerVariable,
@ -49,17 +53,19 @@ class TypeMismatchError(Exception):
# Define the constant
SEGMENT_TO_VARIABLE_MAP = {
StringSegment: StringVariable,
IntegerSegment: IntegerVariable,
FloatSegment: FloatVariable,
ObjectSegment: ObjectVariable,
FileSegment: FileVariable,
ArrayStringSegment: ArrayStringVariable,
ArrayAnySegment: ArrayAnyVariable,
ArrayBooleanSegment: ArrayBooleanVariable,
ArrayFileSegment: ArrayFileVariable,
ArrayNumberSegment: ArrayNumberVariable,
ArrayObjectSegment: ArrayObjectVariable,
ArrayFileSegment: ArrayFileVariable,
ArrayAnySegment: ArrayAnyVariable,
ArrayStringSegment: ArrayStringVariable,
BooleanSegment: BooleanVariable,
FileSegment: FileVariable,
FloatSegment: FloatVariable,
IntegerSegment: IntegerVariable,
NoneSegment: NoneVariable,
ObjectSegment: ObjectVariable,
StringSegment: StringVariable,
}
@ -99,6 +105,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
mapping = dict(mapping)
mapping["value_type"] = SegmentType.FLOAT
result = FloatVariable.model_validate(mapping)
case SegmentType.BOOLEAN:
result = BooleanVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f"invalid number value {value}")
case SegmentType.OBJECT if isinstance(value, dict):
@ -109,6 +117,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
result = ArrayNumberVariable.model_validate(mapping)
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_BOOLEAN if isinstance(value, list):
result = ArrayBooleanVariable.model_validate(mapping)
case _:
raise VariableError(f"not supported value type {value_type}")
if result.size > dify_config.MAX_VARIABLE_SIZE:
@ -129,6 +139,8 @@ def build_segment(value: Any, /) -> Segment:
return NoneSegment()
if isinstance(value, str):
return StringSegment(value=value)
if isinstance(value, bool):
return BooleanSegment(value=value)
if isinstance(value, int):
return IntegerSegment(value=value)
if isinstance(value, float):
@ -152,6 +164,8 @@ def build_segment(value: Any, /) -> Segment:
return ArrayStringSegment(value=value)
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return ArrayNumberSegment(value=value)
case SegmentType.BOOLEAN:
return ArrayBooleanSegment(value=value)
case SegmentType.OBJECT:
return ArrayObjectSegment(value=value)
case SegmentType.FILE:
@ -170,6 +184,7 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
SegmentType.INTEGER: IntegerSegment,
SegmentType.FLOAT: FloatSegment,
SegmentType.FILE: FileSegment,
SegmentType.BOOLEAN: BooleanSegment,
SegmentType.OBJECT: ObjectSegment,
# Array types
SegmentType.ARRAY_ANY: ArrayAnySegment,
@ -177,6 +192,7 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
SegmentType.ARRAY_NUMBER: ArrayNumberSegment,
SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
SegmentType.ARRAY_FILE: ArrayFileSegment,
SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment,
}
@ -225,6 +241,8 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
return ArrayAnySegment(value=value)
elif segment_type == SegmentType.ARRAY_STRING:
return ArrayStringSegment(value=value)
elif segment_type == SegmentType.ARRAY_BOOLEAN:
return ArrayBooleanSegment(value=value)
elif segment_type == SegmentType.ARRAY_NUMBER:
return ArrayNumberSegment(value=value)
elif segment_type == SegmentType.ARRAY_OBJECT:

11
api/lazy_load_class.py Normal file
View File

@ -0,0 +1,11 @@
from tests.integration_tests.utils.parent_class import ParentClass
class LazyLoadChildClass(ParentClass):
"""Test lazy load child class for module import helper tests"""
def __init__(self, name):
super().__init__(name)
def get_name(self):
return self.name

View File

@ -1,10 +1,10 @@
import re
import sys
from collections.abc import Mapping
from typing import Any
from flask import current_app, got_request_exception
from flask_restx import Api
from werkzeug.datastructures import Headers
from werkzeug.exceptions import HTTPException
from werkzeug.http import HTTP_STATUS_CODES
@ -12,97 +12,125 @@ from core.errors.error import AppInvokeQuotaExceededError
def http_status_message(code):
"""Maps an HTTP status code to the textual status"""
return HTTP_STATUS_CODES.get(code, "")
def register_external_error_handlers(api: Api) -> None:
"""Register error handlers for the API using decorators.
:param api: The Flask-RestX Api instance
"""
@api.errorhandler(HTTPException)
def handle_http_exception(e: HTTPException):
"""Handle HTTP exceptions."""
got_request_exception.send(current_app, exception=e)
# If Werkzeug already prepared a Response, just use it.
if getattr(e, "response", None) is not None:
return e.response
if e.response is not None:
return e.get_response()
status_code = getattr(e, "code", 500) or 500
# Build a safe, dict-like payload
headers = Headers()
status_code = e.code
default_data = {
"code": re.sub(r"(?<!^)(?=[A-Z])", "_", type(e).__name__).lower(),
"message": getattr(e, "description", http_status_message(status_code)),
"status": status_code,
}
if default_data["message"] == "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)":
if (
default_data["message"]
and default_data["message"] == "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"
):
default_data["message"] = "Invalid JSON payload received or JSON payload is empty."
# Use headers on the exception if present; otherwise none.
headers = {}
exc_headers = getattr(e, "headers", None)
if exc_headers:
headers.update(exc_headers)
headers = e.get_response().headers
# Payload per status
# Handle specific status codes
if status_code == 406 and api.default_mediatype is None:
data = {"code": "not_acceptable", "message": default_data["message"], "status": status_code}
return data, status_code, headers
supported_mediatypes = list(api.representations.keys())
fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain"
data = {"code": "not_acceptable", "message": default_data.get("message")}
resp = api.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype)
elif status_code == 400:
msg = default_data["message"]
if isinstance(msg, Mapping) and msg:
# Convert param errors like {"field": "reason"} into a friendly shape
param_key, param_value = next(iter(msg.items()))
data = {
"code": "invalid_param",
"message": str(param_value),
"params": param_key,
"status": status_code,
}
if isinstance(default_data.get("message"), dict):
param_key, param_value = list(default_data.get("message", {}).items())[0]
data = {"code": "invalid_param", "message": param_value, "params": param_key}
else:
data = {**default_data}
data.setdefault("code", "unknown")
return data, status_code, headers
data = default_data
if "code" not in data:
data["code"] = "unknown"
resp = api.make_response(data, status_code, headers)
else:
data = {**default_data}
data.setdefault("code", "unknown")
# If you need WWW-Authenticate for 401, add it to headers
if status_code == 401:
headers["WWW-Authenticate"] = 'Bearer realm="api"'
return data, status_code, headers
data = default_data
if "code" not in data:
data["code"] = "unknown"
resp = api.make_response(data, status_code, headers)
if status_code == 401:
resp = api.unauthorized(resp)
# Remove duplicate Content-Length header
remove_headers = ("Content-Length",)
for header in remove_headers:
headers.pop(header, None)
return resp
@api.errorhandler(ValueError)
def handle_value_error(e: ValueError):
"""Handle ValueError exceptions."""
got_request_exception.send(current_app, exception=e)
status_code = 400
data = {"code": "invalid_param", "message": str(e), "status": status_code}
return data, status_code
data = {
"code": "invalid_param",
"message": str(e),
"status": status_code,
}
return api.make_response(data, status_code)
@api.errorhandler(AppInvokeQuotaExceededError)
def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
"""Handle AppInvokeQuotaExceededError exceptions."""
got_request_exception.send(current_app, exception=e)
status_code = 429
data = {"code": "too_many_requests", "message": str(e), "status": status_code}
return data, status_code
data = {
"code": "too_many_requests",
"message": str(e),
"status": status_code,
}
return api.make_response(data, status_code)
@api.errorhandler(Exception)
def handle_general_exception(e: Exception):
"""Handle general exceptions."""
got_request_exception.send(current_app, exception=e)
headers = Headers()
status_code = 500
data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)})
default_data = {
"message": http_status_message(status_code),
}
# 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response)
if not isinstance(data, Mapping):
data = {"message": str(e)}
data = getattr(e, "data", default_data)
data.setdefault("code", "unknown")
data.setdefault("status", status_code)
# Log stack
# Log server errors
exc_info: Any = sys.exc_info()
if exc_info[1] is None:
exc_info = None
current_app.log_exception(exc_info)
return data, status_code
if "code" not in data:
data["code"] = "unknown"
# Remove duplicate Content-Length header
remove_headers = ("Content-Length",)
for header in remove_headers:
headers.pop(header, None)
return api.make_response(data, status_code, headers)
class ExternalApi(Api):

View File

@ -20,3 +20,6 @@ ignore_missing_imports=True
[mypy-flask_restx.inputs]
ignore_missing_imports=True
[mypy-google.cloud.storage]
ignore_missing_imports=True

View File

@ -2344,9 +2344,13 @@ class SegmentService:
@classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
segments = (
db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count)
.filter(
# Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0:
return
index_node_ids = (
db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id)
.where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
@ -2354,15 +2358,7 @@ class SegmentService:
)
.all()
)
if not segments:
return
index_node_ids = [seg.index_node_id for seg in segments]
total_words = sum(seg.word_count for seg in segments)
document.word_count -= total_words
db.session.add(document)
index_node_ids = [index_node_id[0] for index_node_id in index_node_ids]
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()

View File

@ -186,11 +186,6 @@ class PluginService:
mime_type, _ = guess_type(asset_file)
return manager.fetch_asset(tenant_id, asset_file), mime_type or "application/octet-stream"
@staticmethod
def extract_asset(tenant_id: str, plugin_unique_identifier: str, file_name: str) -> bytes:
manager = PluginAssetManager()
return manager.extract_asset(tenant_id, plugin_unique_identifier, file_name)
@staticmethod
def check_plugin_unique_identifier(tenant_id: str, plugin_unique_identifier: str) -> bool:
"""
@ -497,11 +492,3 @@ class PluginService:
"""
manager = PluginInstaller()
return manager.check_tools_existence(tenant_id, provider_ids)
@staticmethod
def fetch_plugin_readme(tenant_id: str, plugin_unique_identifier: str, language: str) -> str:
"""
Fetch plugin readme
"""
manager = PluginInstaller()
return manager.fetch_plugin_readme(tenant_id, plugin_unique_identifier, language)

View File

@ -410,18 +410,18 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create annotations with specific keywords
unique_keyword = f"unique_{fake.uuid4()[:8]}"
unique_keyword = fake.word()
annotation_args = {
"question": f"Question with {unique_keyword} keyword",
"answer": f"Answer with {unique_keyword} keyword",
}
AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id)
# Create another annotation without the keyword
other_args = {
"question": "Different question without special term",
"answer": "Different answer without special content",
"question": "Question without keyword",
"answer": "Answer without keyword",
}
AppAnnotationService.insert_app_annotation_directly(other_args, app.id)
# Search with keyword

View File

@ -23,6 +23,7 @@ class TestSegmentTypeIsArrayType:
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
SegmentType.ARRAY_BOOLEAN,
]
expected_non_array_types = [
SegmentType.INTEGER,
@ -34,6 +35,7 @@ class TestSegmentTypeIsArrayType:
SegmentType.FILE,
SegmentType.NONE,
SegmentType.GROUP,
SegmentType.BOOLEAN,
]
for seg_type in expected_array_types:

View File

@ -0,0 +1,729 @@
"""
Comprehensive unit tests for SegmentType.is_valid and SegmentType._validate_array methods.
This module provides thorough testing of the validation logic for all SegmentType values,
including edge cases, error conditions, and different ArrayValidation strategies.
"""
from dataclasses import dataclass
from typing import Any
import pytest
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
from core.variables.types import ArrayValidation, SegmentType
def create_test_file(
file_type: FileType = FileType.DOCUMENT,
transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
filename: str = "test.txt",
extension: str = ".txt",
mime_type: str = "text/plain",
size: int = 1024,
) -> File:
"""Factory function to create File objects for testing."""
return File(
tenant_id="test-tenant",
type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None,
remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None,
storage_key="test-storage-key",
)
@dataclass
class ValidationTestCase:
"""Test case data structure for validation tests."""
segment_type: SegmentType
value: Any
expected: bool
description: str
def get_id(self):
return self.description
@dataclass
class ArrayValidationTestCase:
"""Test case data structure for array validation tests."""
segment_type: SegmentType
value: Any
array_validation: ArrayValidation
expected: bool
description: str
def get_id(self):
return self.description
# Test data construction functions
def get_boolean_cases() -> list[ValidationTestCase]:
return [
# valid values
ValidationTestCase(SegmentType.BOOLEAN, True, True, "True boolean"),
ValidationTestCase(SegmentType.BOOLEAN, False, True, "False boolean"),
# Invalid values
ValidationTestCase(SegmentType.BOOLEAN, 1, False, "Integer 1 (not boolean)"),
ValidationTestCase(SegmentType.BOOLEAN, 0, False, "Integer 0 (not boolean)"),
ValidationTestCase(SegmentType.BOOLEAN, "true", False, "String 'true'"),
ValidationTestCase(SegmentType.BOOLEAN, "false", False, "String 'false'"),
ValidationTestCase(SegmentType.BOOLEAN, None, False, "None value"),
ValidationTestCase(SegmentType.BOOLEAN, [], False, "Empty list"),
ValidationTestCase(SegmentType.BOOLEAN, {}, False, "Empty dict"),
]
def get_number_cases() -> list[ValidationTestCase]:
"""Get test cases for valid number values."""
return [
# valid values
ValidationTestCase(SegmentType.NUMBER, 42, True, "Positive integer"),
ValidationTestCase(SegmentType.NUMBER, -42, True, "Negative integer"),
ValidationTestCase(SegmentType.NUMBER, 0, True, "Zero integer"),
ValidationTestCase(SegmentType.NUMBER, 3.14, True, "Positive float"),
ValidationTestCase(SegmentType.NUMBER, -3.14, True, "Negative float"),
ValidationTestCase(SegmentType.NUMBER, 0.0, True, "Zero float"),
ValidationTestCase(SegmentType.NUMBER, float("inf"), True, "Positive infinity"),
ValidationTestCase(SegmentType.NUMBER, float("-inf"), True, "Negative infinity"),
ValidationTestCase(SegmentType.NUMBER, float("nan"), True, "float(NaN)"),
# invalid number values
ValidationTestCase(SegmentType.NUMBER, "42", False, "String number"),
ValidationTestCase(SegmentType.NUMBER, None, False, "None value"),
ValidationTestCase(SegmentType.NUMBER, [], False, "Empty list"),
ValidationTestCase(SegmentType.NUMBER, {}, False, "Empty dict"),
ValidationTestCase(SegmentType.NUMBER, "3.14", False, "String float"),
]
def get_string_cases() -> list[ValidationTestCase]:
"""Get test cases for valid string values."""
return [
# valid values
ValidationTestCase(SegmentType.STRING, "", True, "Empty string"),
ValidationTestCase(SegmentType.STRING, "hello", True, "Simple string"),
ValidationTestCase(SegmentType.STRING, "🚀", True, "Unicode emoji"),
ValidationTestCase(SegmentType.STRING, "line1\nline2", True, "Multiline string"),
# invalid values
ValidationTestCase(SegmentType.STRING, 123, False, "Integer"),
ValidationTestCase(SegmentType.STRING, 3.14, False, "Float"),
ValidationTestCase(SegmentType.STRING, True, False, "Boolean"),
ValidationTestCase(SegmentType.STRING, None, False, "None value"),
ValidationTestCase(SegmentType.STRING, [], False, "Empty list"),
ValidationTestCase(SegmentType.STRING, {}, False, "Empty dict"),
]
def get_object_cases() -> list[ValidationTestCase]:
"""Get test cases for valid object values."""
return [
# valid cases
ValidationTestCase(SegmentType.OBJECT, {}, True, "Empty dict"),
ValidationTestCase(SegmentType.OBJECT, {"key": "value"}, True, "Simple dict"),
ValidationTestCase(SegmentType.OBJECT, {"a": 1, "b": 2}, True, "Dict with numbers"),
ValidationTestCase(SegmentType.OBJECT, {"nested": {"key": "value"}}, True, "Nested dict"),
ValidationTestCase(SegmentType.OBJECT, {"list": [1, 2, 3]}, True, "Dict with list"),
ValidationTestCase(SegmentType.OBJECT, {"mixed": [1, "two", {"three": 3}]}, True, "Complex dict"),
# invalid cases
ValidationTestCase(SegmentType.OBJECT, "not a dict", False, "String"),
ValidationTestCase(SegmentType.OBJECT, 123, False, "Integer"),
ValidationTestCase(SegmentType.OBJECT, 3.14, False, "Float"),
ValidationTestCase(SegmentType.OBJECT, True, False, "Boolean"),
ValidationTestCase(SegmentType.OBJECT, None, False, "None value"),
ValidationTestCase(SegmentType.OBJECT, [], False, "Empty list"),
ValidationTestCase(SegmentType.OBJECT, [1, 2, 3], False, "List with values"),
]
def get_secret_cases() -> list[ValidationTestCase]:
"""Get test cases for valid secret values."""
return [
# valid cases
ValidationTestCase(SegmentType.SECRET, "", True, "Empty secret"),
ValidationTestCase(SegmentType.SECRET, "secret", True, "Simple secret"),
ValidationTestCase(SegmentType.SECRET, "api_key_123", True, "API key format"),
ValidationTestCase(SegmentType.SECRET, "very_long_secret_key_with_special_chars!@#", True, "Complex secret"),
# invalid cases
ValidationTestCase(SegmentType.SECRET, 123, False, "Integer"),
ValidationTestCase(SegmentType.SECRET, 3.14, False, "Float"),
ValidationTestCase(SegmentType.SECRET, True, False, "Boolean"),
ValidationTestCase(SegmentType.SECRET, None, False, "None value"),
ValidationTestCase(SegmentType.SECRET, [], False, "Empty list"),
ValidationTestCase(SegmentType.SECRET, {}, False, "Empty dict"),
]
def get_file_cases() -> list[ValidationTestCase]:
"""Get test cases for valid file values."""
test_file = create_test_file()
image_file = create_test_file(
file_type=FileType.IMAGE, filename="image.jpg", extension=".jpg", mime_type="image/jpeg"
)
remote_file = create_test_file(
transfer_method=FileTransferMethod.REMOTE_URL, filename="remote.pdf", extension=".pdf"
)
return [
# valid cases
ValidationTestCase(SegmentType.FILE, test_file, True, "Document file"),
ValidationTestCase(SegmentType.FILE, image_file, True, "Image file"),
ValidationTestCase(SegmentType.FILE, remote_file, True, "Remote file"),
# invalid cases
ValidationTestCase(SegmentType.FILE, "not a file", False, "String"),
ValidationTestCase(SegmentType.FILE, 123, False, "Integer"),
ValidationTestCase(SegmentType.FILE, {"filename": "test.txt"}, False, "Dict resembling file"),
ValidationTestCase(SegmentType.FILE, None, False, "None value"),
ValidationTestCase(SegmentType.FILE, [], False, "Empty list"),
ValidationTestCase(SegmentType.FILE, True, False, "Boolean"),
]
def get_none_cases() -> list[ValidationTestCase]:
"""Get test cases for valid none values."""
return [
# valid cases
ValidationTestCase(SegmentType.NONE, None, True, "None value"),
# invalid cases
ValidationTestCase(SegmentType.NONE, "", False, "Empty string"),
ValidationTestCase(SegmentType.NONE, 0, False, "Zero integer"),
ValidationTestCase(SegmentType.NONE, 0.0, False, "Zero float"),
ValidationTestCase(SegmentType.NONE, False, False, "False boolean"),
ValidationTestCase(SegmentType.NONE, [], False, "Empty list"),
ValidationTestCase(SegmentType.NONE, {}, False, "Empty dict"),
ValidationTestCase(SegmentType.NONE, "null", False, "String 'null'"),
]
def get_array_any_validation_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_ANY validation."""
return [
ArrayValidationTestCase(
SegmentType.ARRAY_ANY,
[1, "string", 3.14, {"key": "value"}, True],
ArrayValidation.NONE,
True,
"Mixed types with NONE validation",
),
ArrayValidationTestCase(
SegmentType.ARRAY_ANY,
[1, "string", 3.14, {"key": "value"}, True],
ArrayValidation.FIRST,
True,
"Mixed types with FIRST validation",
),
ArrayValidationTestCase(
SegmentType.ARRAY_ANY,
[1, "string", 3.14, {"key": "value"}, True],
ArrayValidation.ALL,
True,
"Mixed types with ALL validation",
),
ArrayValidationTestCase(
SegmentType.ARRAY_ANY, [None, None, None], ArrayValidation.ALL, True, "All None values"
),
]
def get_array_string_validation_none_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_STRING validation with NONE strategy."""
return [
ArrayValidationTestCase(
SegmentType.ARRAY_STRING,
["hello", "world"],
ArrayValidation.NONE,
True,
"Valid strings with NONE validation",
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING,
[123, 456],
ArrayValidation.NONE,
True,
"Invalid elements with NONE validation",
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING,
["valid", 123, True],
ArrayValidation.NONE,
True,
"Mixed types with NONE validation",
),
]
def get_array_string_validation_first_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_STRING validation with FIRST strategy."""
return [
ArrayValidationTestCase(
SegmentType.ARRAY_STRING, ["hello", "world"], ArrayValidation.FIRST, True, "All valid strings"
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING,
["hello", 123, True],
ArrayValidation.FIRST,
True,
"First valid, others invalid",
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING,
[123, "hello", "world"],
ArrayValidation.FIRST,
False,
"First invalid, others valid",
),
ArrayValidationTestCase(SegmentType.ARRAY_STRING, [None, "hello"], ArrayValidation.FIRST, False, "First None"),
]
def get_array_string_validation_all_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_STRING validation with ALL strategy."""
return [
ArrayValidationTestCase(
SegmentType.ARRAY_STRING, ["hello", "world", "test"], ArrayValidation.ALL, True, "All valid strings"
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING, ["hello", 123, "world"], ArrayValidation.ALL, False, "One invalid element"
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING, [123, 456, 789], ArrayValidation.ALL, False, "All invalid elements"
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING, ["valid", None, "also_valid"], ArrayValidation.ALL, False, "Contains None"
),
]
def get_array_number_validation_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_NUMBER validation with different strategies."""
return [
# NONE strategy
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, [1, 2.5, 3], ArrayValidation.NONE, True, "Valid numbers with NONE"
),
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, ["not", "numbers"], ArrayValidation.NONE, True, "Invalid elements with NONE"
),
# FIRST strategy
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, [42, "not a number"], ArrayValidation.FIRST, True, "First valid number"
),
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, ["not a number", 42], ArrayValidation.FIRST, False, "First invalid"
),
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, [3.14, 2.71, 1.41], ArrayValidation.FIRST, True, "All valid floats"
),
# ALL strategy
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, [1, 2, 3, 4.5], ArrayValidation.ALL, True, "All valid numbers"
),
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, [1, "invalid", 3], ArrayValidation.ALL, False, "One invalid element"
),
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER,
[float("inf"), float("-inf"), float("nan")],
ArrayValidation.ALL,
True,
"Special float values",
),
]
def get_array_object_validation_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_OBJECT validation with different strategies."""
return [
# NONE strategy
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT, [{}, {"key": "value"}], ArrayValidation.NONE, True, "Valid objects with NONE"
),
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT, ["not", "objects"], ArrayValidation.NONE, True, "Invalid elements with NONE"
),
# FIRST strategy
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT,
[{"valid": "object"}, "not an object"],
ArrayValidation.FIRST,
True,
"First valid object",
),
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT,
["not an object", {"valid": "object"}],
ArrayValidation.FIRST,
False,
"First invalid",
),
# ALL strategy
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT,
[{}, {"a": 1}, {"nested": {"key": "value"}}],
ArrayValidation.ALL,
True,
"All valid objects",
),
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT,
[{"valid": "object"}, "invalid", {"another": "object"}],
ArrayValidation.ALL,
False,
"One invalid element",
),
]
def get_array_file_validation_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_FILE validation with different strategies."""
file1 = create_test_file(filename="file1.txt")
file2 = create_test_file(filename="file2.txt")
return [
# NONE strategy
ArrayValidationTestCase(
SegmentType.ARRAY_FILE, [file1, file2], ArrayValidation.NONE, True, "Valid files with NONE"
),
ArrayValidationTestCase(
SegmentType.ARRAY_FILE, ["not", "files"], ArrayValidation.NONE, True, "Invalid elements with NONE"
),
# FIRST strategy
ArrayValidationTestCase(
SegmentType.ARRAY_FILE, [file1, "not a file"], ArrayValidation.FIRST, True, "First valid file"
),
ArrayValidationTestCase(
SegmentType.ARRAY_FILE, ["not a file", file1], ArrayValidation.FIRST, False, "First invalid"
),
# ALL strategy
ArrayValidationTestCase(SegmentType.ARRAY_FILE, [file1, file2], ArrayValidation.ALL, True, "All valid files"),
ArrayValidationTestCase(
SegmentType.ARRAY_FILE, [file1, "invalid", file2], ArrayValidation.ALL, False, "One invalid element"
),
]
def get_array_boolean_validation_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_BOOLEAN validation with different strategies."""
return [
# NONE strategy
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [True, False, True], ArrayValidation.NONE, True, "Valid booleans with NONE"
),
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [1, 0, "true"], ArrayValidation.NONE, True, "Invalid elements with NONE"
),
# FIRST strategy
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [True, 1, 0], ArrayValidation.FIRST, True, "First valid boolean"
),
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [1, True, False], ArrayValidation.FIRST, False, "First invalid (integer 1)"
),
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [0, True, False], ArrayValidation.FIRST, False, "First invalid (integer 0)"
),
# ALL strategy
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [True, False, True, False], ArrayValidation.ALL, True, "All valid booleans"
),
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [True, 1, False], ArrayValidation.ALL, False, "One invalid element (integer)"
),
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN,
[True, "false", False],
ArrayValidation.ALL,
False,
"One invalid element (string)",
),
]
class TestSegmentTypeIsValid:
"""Test suite for SegmentType.is_valid method covering all non-array types."""
@pytest.mark.parametrize("case", get_boolean_cases(), ids=lambda case: case.description)
def test_boolean_validation(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_number_cases(), ids=lambda case: case.description)
def test_number_validation(self, case: ValidationTestCase):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_string_cases(), ids=lambda case: case.description)
def test_string_validation(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_object_cases(), ids=lambda case: case.description)
def test_object_validation(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_secret_cases(), ids=lambda case: case.description)
def test_secret_validation(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_file_cases(), ids=lambda case: case.description)
def test_file_validation(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_none_cases(), ids=lambda case: case.description)
def test_none_validation_valid_cases(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
def test_unsupported_segment_type_raises_assertion_error(self):
"""Test that unsupported SegmentType values raise AssertionError."""
# GROUP is not handled in is_valid method
with pytest.raises(AssertionError, match="this statement should be unreachable"):
SegmentType.GROUP.is_valid("any value")
class TestSegmentTypeArrayValidation:
"""Test suite for SegmentType._validate_array method and array type validation."""
def test_array_validation_non_list_values(self):
"""Test that non-list values return False for all array types."""
array_types = [
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
SegmentType.ARRAY_BOOLEAN,
]
non_list_values = [
"not a list",
123,
3.14,
True,
None,
{"key": "value"},
create_test_file(),
]
for array_type in array_types:
for value in non_list_values:
assert array_type.is_valid(value) is False, f"{array_type} should reject {type(value).__name__}"
def test_empty_array_validation(self):
"""Test that empty arrays are valid for all array types regardless of validation strategy."""
array_types = [
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
SegmentType.ARRAY_BOOLEAN,
]
validation_strategies = [ArrayValidation.NONE, ArrayValidation.FIRST, ArrayValidation.ALL]
for array_type in array_types:
for strategy in validation_strategies:
assert array_type.is_valid([], strategy) is True, (
f"{array_type} should accept empty array with {strategy}"
)
@pytest.mark.parametrize("case", get_array_any_validation_cases(), ids=lambda case: case.description)
def test_array_any_validation(self, case):
"""Test ARRAY_ANY validation accepts any list regardless of content."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_string_validation_none_cases(), ids=lambda case: case.description)
def test_array_string_validation_with_none_strategy(self, case):
"""Test ARRAY_STRING validation with NONE strategy (no element validation)."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_string_validation_first_cases(), ids=lambda case: case.description)
def test_array_string_validation_with_first_strategy(self, case):
"""Test ARRAY_STRING validation with FIRST strategy (validate first element only)."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_string_validation_all_cases(), ids=lambda case: case.description)
def test_array_string_validation_with_all_strategy(self, case):
"""Test ARRAY_STRING validation with ALL strategy (validate all elements)."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_number_validation_cases(), ids=lambda case: case.description)
def test_array_number_validation_with_different_strategies(self, case):
"""Test ARRAY_NUMBER validation with different validation strategies."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_object_validation_cases(), ids=lambda case: case.description)
def test_array_object_validation_with_different_strategies(self, case):
"""Test ARRAY_OBJECT validation with different validation strategies."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_file_validation_cases(), ids=lambda case: case.description)
def test_array_file_validation_with_different_strategies(self, case):
"""Test ARRAY_FILE validation with different validation strategies."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_boolean_validation_cases(), ids=lambda case: case.description)
def test_array_boolean_validation_with_different_strategies(self, case):
"""Test ARRAY_BOOLEAN validation with different validation strategies."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
def test_default_array_validation_strategy(self):
"""Test that default array validation strategy is FIRST."""
# When no array_validation parameter is provided, it should default to FIRST
assert SegmentType.ARRAY_STRING.is_valid(["valid", 123]) is False # First element valid
assert SegmentType.ARRAY_STRING.is_valid([123, "valid"]) is False # First element invalid
assert SegmentType.ARRAY_NUMBER.is_valid([42, "invalid"]) is False # First element valid
assert SegmentType.ARRAY_NUMBER.is_valid(["invalid", 42]) is False # First element invalid
def test_array_validation_edge_cases(self):
"""Test edge cases for array validation."""
# Test with nested arrays (should be invalid for specific array types)
nested_array = [["nested", "array"], ["another", "nested"]]
assert SegmentType.ARRAY_STRING.is_valid(nested_array, ArrayValidation.FIRST) is False
assert SegmentType.ARRAY_STRING.is_valid(nested_array, ArrayValidation.ALL) is False
assert SegmentType.ARRAY_ANY.is_valid(nested_array, ArrayValidation.ALL) is True
# Test with very large arrays (performance consideration)
large_valid_array = ["string"] * 1000
large_mixed_array = ["string"] * 999 + [123] # Last element invalid
assert SegmentType.ARRAY_STRING.is_valid(large_valid_array, ArrayValidation.ALL) is True
assert SegmentType.ARRAY_STRING.is_valid(large_mixed_array, ArrayValidation.ALL) is False
assert SegmentType.ARRAY_STRING.is_valid(large_mixed_array, ArrayValidation.FIRST) is True
class TestSegmentTypeValidationIntegration:
"""Integration tests for SegmentType validation covering interactions between methods."""
def test_non_array_types_ignore_array_validation_parameter(self):
"""Test that non-array types ignore the array_validation parameter."""
non_array_types = [
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.BOOLEAN,
SegmentType.OBJECT,
SegmentType.SECRET,
SegmentType.FILE,
SegmentType.NONE,
]
for segment_type in non_array_types:
# Create appropriate valid value for each type
valid_value: Any
if segment_type == SegmentType.STRING:
valid_value = "test"
elif segment_type == SegmentType.NUMBER:
valid_value = 42
elif segment_type == SegmentType.BOOLEAN:
valid_value = True
elif segment_type == SegmentType.OBJECT:
valid_value = {"key": "value"}
elif segment_type == SegmentType.SECRET:
valid_value = "secret"
elif segment_type == SegmentType.FILE:
valid_value = create_test_file()
elif segment_type == SegmentType.NONE:
valid_value = None
else:
continue # Skip unsupported types
# All array validation strategies should give the same result
result_none = segment_type.is_valid(valid_value, ArrayValidation.NONE)
result_first = segment_type.is_valid(valid_value, ArrayValidation.FIRST)
result_all = segment_type.is_valid(valid_value, ArrayValidation.ALL)
assert result_none == result_first == result_all == True, (
f"{segment_type} should ignore array_validation parameter"
)
def test_comprehensive_type_coverage(self):
"""Test that all SegmentType enum values are covered in validation tests."""
all_segment_types = set(SegmentType)
# Types that should be handled by is_valid method
handled_types = {
# Non-array types
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.BOOLEAN,
SegmentType.OBJECT,
SegmentType.SECRET,
SegmentType.FILE,
SegmentType.NONE,
# Array types
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
SegmentType.ARRAY_BOOLEAN,
}
# Types that are not handled by is_valid (should raise AssertionError)
unhandled_types = {
SegmentType.GROUP,
SegmentType.INTEGER, # Handled by NUMBER validation logic
SegmentType.FLOAT, # Handled by NUMBER validation logic
}
# Verify all types are accounted for
assert handled_types | unhandled_types == all_segment_types, "All SegmentType values should be categorized"
# Test that handled types work correctly
for segment_type in handled_types:
if segment_type.is_array_type():
# Test with empty array (should always be valid)
assert segment_type.is_valid([]) is True, f"{segment_type} should accept empty array"
else:
# Test with appropriate valid value
if segment_type == SegmentType.STRING:
assert segment_type.is_valid("test") is True
elif segment_type == SegmentType.NUMBER:
assert segment_type.is_valid(42) is True
elif segment_type == SegmentType.BOOLEAN:
assert segment_type.is_valid(True) is True
elif segment_type == SegmentType.OBJECT:
assert segment_type.is_valid({}) is True
elif segment_type == SegmentType.SECRET:
assert segment_type.is_valid("secret") is True
elif segment_type == SegmentType.FILE:
assert segment_type.is_valid(create_test_file()) is True
elif segment_type == SegmentType.NONE:
assert segment_type.is_valid(None) is True
def test_boolean_vs_integer_type_distinction(self):
"""Test the important distinction between boolean and integer types in validation."""
# This tests the comment in the code about bool being a subclass of int
# Boolean type should only accept actual booleans, not integers
assert SegmentType.BOOLEAN.is_valid(True) is True
assert SegmentType.BOOLEAN.is_valid(False) is True
assert SegmentType.BOOLEAN.is_valid(1) is False # Integer 1, not boolean
assert SegmentType.BOOLEAN.is_valid(0) is False # Integer 0, not boolean
# Number type should accept both integers and floats, including booleans (since bool is subclass of int)
assert SegmentType.NUMBER.is_valid(42) is True
assert SegmentType.NUMBER.is_valid(3.14) is True
assert SegmentType.NUMBER.is_valid(True) is True # bool is subclass of int
assert SegmentType.NUMBER.is_valid(False) is True # bool is subclass of int
def test_array_validation_recursive_behavior(self):
"""Test that array validation correctly handles recursive validation calls."""
# When validating array elements, _validate_array calls is_valid recursively
# with ArrayValidation.NONE to avoid infinite recursion
# Test nested validation doesn't cause issues
nested_arrays = [["inner", "array"], ["another", "inner"]]
# ARRAY_ANY should accept nested arrays
assert SegmentType.ARRAY_ANY.is_valid(nested_arrays, ArrayValidation.ALL) is True
# ARRAY_STRING should reject nested arrays (first element is not a string)
assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.FIRST) is False
assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.ALL) is False

View File

@ -0,0 +1,27 @@
from core.variables.types import SegmentType
from core.workflow.nodes.parameter_extractor.entities import ParameterConfig
class TestParameterConfig:
def test_select_type(self):
data = {
"name": "yes_or_no",
"type": "select",
"options": ["yes", "no"],
"description": "a simple select made of `yes` and `no`",
"required": True,
}
pc = ParameterConfig.model_validate(data)
assert pc.type == SegmentType.STRING
assert pc.options == data["options"]
def test_validate_bool_type(self):
data = {
"name": "boolean",
"type": "bool",
"description": "a simple boolean parameter",
"required": True,
}
pc = ParameterConfig.model_validate(data)
assert pc.type == SegmentType.BOOLEAN

View File

@ -0,0 +1,567 @@
"""
Test cases for ParameterExtractorNode._validate_result and _transform_result methods.
"""
from dataclasses import dataclass
from typing import Any
import pytest
from core.model_runtime.entities import LLMMode
from core.variables.types import SegmentType
from core.workflow.nodes.llm import ModelConfig, VisionConfig
from core.workflow.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData
from core.workflow.nodes.parameter_extractor.exc import (
InvalidNumberOfParametersError,
InvalidSelectValueError,
InvalidValueTypeError,
RequiredParameterMissingError,
)
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from factories.variable_factory import build_segment_with_type
@dataclass
class ValidTestCase:
"""Test case data for valid scenarios."""
name: str
parameters: list[ParameterConfig]
result: dict[str, Any]
def get_name(self) -> str:
return self.name
@dataclass
class ErrorTestCase:
"""Test case data for error scenarios."""
name: str
parameters: list[ParameterConfig]
result: dict[str, Any]
expected_exception: type[Exception]
expected_message: str
def get_name(self) -> str:
return self.name
@dataclass
class TransformTestCase:
"""Test case data for transformation scenarios."""
name: str
parameters: list[ParameterConfig]
input_result: dict[str, Any]
expected_result: dict[str, Any]
def get_name(self) -> str:
return self.name
class TestParameterExtractorNodeMethods:
"""Test helper class that provides access to the methods under test."""
def validate_result(self, data: ParameterExtractorNodeData, result: dict[str, Any]) -> dict[str, Any]:
"""Wrapper to call _validate_result method."""
node = ParameterExtractorNode.__new__(ParameterExtractorNode)
return node._validate_result(data=data, result=result)
def transform_result(self, data: ParameterExtractorNodeData, result: dict[str, Any]) -> dict[str, Any]:
"""Wrapper to call _transform_result method."""
node = ParameterExtractorNode.__new__(ParameterExtractorNode)
return node._transform_result(data=data, result=result)
class TestValidateResult:
"""Test cases for _validate_result method."""
@staticmethod
def get_valid_test_cases() -> list[ValidTestCase]:
"""Get test cases that should pass validation."""
return [
ValidTestCase(
name="single_string_parameter",
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
result={"name": "John"},
),
ValidTestCase(
name="single_number_parameter_int",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
result={"age": 25},
),
ValidTestCase(
name="single_number_parameter_float",
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
result={"price": 19.99},
),
ValidTestCase(
name="single_bool_parameter_true",
parameters=[
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
],
result={"active": True},
),
ValidTestCase(
name="single_bool_parameter_true",
parameters=[
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
],
result={"active": True},
),
ValidTestCase(
name="single_bool_parameter_false",
parameters=[
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
],
result={"active": False},
),
ValidTestCase(
name="select_parameter_valid_option",
parameters=[
ParameterConfig(
name="status",
type="select", # pyright: ignore[reportArgumentType]
description="Status",
required=True,
options=["active", "inactive"],
)
],
result={"status": "active"},
),
ValidTestCase(
name="array_string_parameter",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
result={"tags": ["tag1", "tag2", "tag3"]},
),
ValidTestCase(
name="array_number_parameter",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
result={"scores": [85, 92.5, 78]},
),
ValidTestCase(
name="array_object_parameter",
parameters=[
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
],
result={"items": [{"name": "item1"}, {"name": "item2"}]},
),
ValidTestCase(
name="multiple_parameters",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True),
],
result={"name": "John", "age": 25, "active": True},
),
ValidTestCase(
name="optional_parameter_present",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
ParameterConfig(name="nickname", type=SegmentType.STRING, description="Nickname", required=False),
],
result={"name": "John", "nickname": "Johnny"},
),
ValidTestCase(
name="empty_array_parameter",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
result={"tags": []},
),
]
@staticmethod
def get_error_test_cases() -> list[ErrorTestCase]:
"""Get test cases that should raise exceptions."""
return [
ErrorTestCase(
name="invalid_number_of_parameters_too_few",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
],
result={"name": "John"},
expected_exception=InvalidNumberOfParametersError,
expected_message="Invalid number of parameters",
),
ErrorTestCase(
name="invalid_number_of_parameters_too_many",
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
result={"name": "John", "age": 25},
expected_exception=InvalidNumberOfParametersError,
expected_message="Invalid number of parameters",
),
ErrorTestCase(
name="invalid_string_value_none",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
],
result={"name": None}, # Parameter present but None value, will trigger type check first
expected_exception=InvalidValueTypeError,
expected_message="Invalid value for parameter name, expected segment type: string, actual_type: none",
),
ErrorTestCase(
name="invalid_select_value",
parameters=[
ParameterConfig(
name="status",
type="select", # type: ignore
description="Status",
required=True,
options=["active", "inactive"],
)
],
result={"status": "pending"},
expected_exception=InvalidSelectValueError,
expected_message="Invalid `select` value for parameter status",
),
ErrorTestCase(
name="invalid_number_value_string",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
result={"age": "twenty-five"},
expected_exception=InvalidValueTypeError,
expected_message="Invalid value for parameter age, expected segment type: number, actual_type: string",
),
ErrorTestCase(
name="invalid_bool_value_string",
parameters=[
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
],
result={"active": "yes"},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter active, expected segment type: boolean, actual_type: string"
),
),
ErrorTestCase(
name="invalid_string_value_number",
parameters=[
ParameterConfig(
name="description", type=SegmentType.STRING, description="Description", required=True
)
],
result={"description": 123},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter description, expected segment type: string, actual_type: integer"
),
),
ErrorTestCase(
name="invalid_array_value_not_list",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
result={"tags": "tag1,tag2,tag3"},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter tags, expected segment type: array[string], actual_type: string"
),
),
ErrorTestCase(
name="invalid_array_number_wrong_element_type",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
result={"scores": [85, "ninety-two", 78]},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter scores, expected segment type: array[number], actual_type: array[any]"
),
),
ErrorTestCase(
name="invalid_array_string_wrong_element_type",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
result={"tags": ["tag1", 123, "tag3"]},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter tags, expected segment type: array[string], actual_type: array[any]"
),
),
ErrorTestCase(
name="invalid_array_object_wrong_element_type",
parameters=[
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
],
result={"items": [{"name": "item1"}, "item2"]},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter items, expected segment type: array[object], actual_type: array[any]"
),
),
ErrorTestCase(
name="required_parameter_missing",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=False),
],
result={"age": 25, "other": "value"}, # Missing required 'name' parameter, but has correct count
expected_exception=RequiredParameterMissingError,
expected_message="Parameter name is required",
),
]
@pytest.mark.parametrize("test_case", get_valid_test_cases(), ids=ValidTestCase.get_name)
def test_validate_result_valid_cases(self, test_case):
"""Test _validate_result with valid inputs."""
helper = TestParameterExtractorNodeMethods()
node_data = ParameterExtractorNodeData(
title="Test Node",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
query=["test_query"],
parameters=test_case.parameters,
reasoning_mode="function_call",
vision=VisionConfig(),
)
result = helper.validate_result(data=node_data, result=test_case.result)
assert result == test_case.result, f"Failed for case: {test_case.name}"
@pytest.mark.parametrize("test_case", get_error_test_cases(), ids=ErrorTestCase.get_name)
def test_validate_result_error_cases(self, test_case):
"""Test _validate_result with invalid inputs that should raise exceptions."""
helper = TestParameterExtractorNodeMethods()
node_data = ParameterExtractorNodeData(
title="Test Node",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
query=["test_query"],
parameters=test_case.parameters,
reasoning_mode="function_call",
vision=VisionConfig(),
)
with pytest.raises(test_case.expected_exception) as exc_info:
helper.validate_result(data=node_data, result=test_case.result)
assert test_case.expected_message in str(exc_info.value), f"Failed for case: {test_case.name}"
class TestTransformResult:
"""Test cases for _transform_result method."""
@staticmethod
def get_transform_test_cases() -> list[TransformTestCase]:
"""Get test cases for result transformation."""
return [
# String parameter transformation
TransformTestCase(
name="string_parameter_present",
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
input_result={"name": "John"},
expected_result={"name": "John"},
),
TransformTestCase(
name="string_parameter_missing",
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
input_result={},
expected_result={"name": ""},
),
# Number parameter transformation
TransformTestCase(
name="number_parameter_int_present",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
input_result={"age": 25},
expected_result={"age": 25},
),
TransformTestCase(
name="number_parameter_float_present",
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
input_result={"price": 19.99},
expected_result={"price": 19.99},
),
TransformTestCase(
name="number_parameter_missing",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
input_result={},
expected_result={"age": 0},
),
# Bool parameter transformation
TransformTestCase(
name="bool_parameter_missing",
parameters=[
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
],
input_result={},
expected_result={"active": False},
),
# Select parameter transformation
TransformTestCase(
name="select_parameter_present",
parameters=[
ParameterConfig(
name="status",
type="select", # type: ignore
description="Status",
required=True,
options=["active", "inactive"],
)
],
input_result={"status": "active"},
expected_result={"status": "active"},
),
TransformTestCase(
name="select_parameter_missing",
parameters=[
ParameterConfig(
name="status",
type="select", # type: ignore
description="Status",
required=True,
options=["active", "inactive"],
)
],
input_result={},
expected_result={"status": ""},
),
# Array parameter transformation - present cases
TransformTestCase(
name="array_string_parameter_present",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
input_result={"tags": ["tag1", "tag2"]},
expected_result={
"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=["tag1", "tag2"])
},
),
TransformTestCase(
name="array_number_parameter_present",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
input_result={"scores": [85, 92.5]},
expected_result={
"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[85, 92.5])
},
),
TransformTestCase(
name="array_number_parameter_with_string_conversion",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
input_result={"scores": [85, "92.5", "78"]},
expected_result={
"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[85, 92.5, 78])
},
),
TransformTestCase(
name="array_object_parameter_present",
parameters=[
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
],
input_result={"items": [{"name": "item1"}, {"name": "item2"}]},
expected_result={
"items": build_segment_with_type(
segment_type=SegmentType.ARRAY_OBJECT, value=[{"name": "item1"}, {"name": "item2"}]
)
},
),
# Array parameter transformation - missing cases
TransformTestCase(
name="array_string_parameter_missing",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
input_result={},
expected_result={"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=[])},
),
TransformTestCase(
name="array_number_parameter_missing",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
input_result={},
expected_result={"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[])},
),
TransformTestCase(
name="array_object_parameter_missing",
parameters=[
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
],
input_result={},
expected_result={"items": build_segment_with_type(segment_type=SegmentType.ARRAY_OBJECT, value=[])},
),
# Multiple parameters transformation
TransformTestCase(
name="multiple_parameters_mixed",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True),
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True),
],
input_result={"name": "John", "age": 25},
expected_result={
"name": "John",
"age": 25,
"active": False,
"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=[]),
},
),
# Number parameter transformation with string conversion
TransformTestCase(
name="number_parameter_string_to_float",
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
input_result={"price": "19.99"},
expected_result={"price": 19.99}, # String not converted, falls back to default
),
TransformTestCase(
name="number_parameter_string_to_int",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
input_result={"age": "25"},
expected_result={"age": 25}, # String not converted, falls back to default
),
TransformTestCase(
name="number_parameter_invalid_string",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
input_result={"age": "invalid_number"},
expected_result={"age": 0}, # Invalid string conversion fails, falls back to default
),
TransformTestCase(
name="number_parameter_non_string_non_number",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
input_result={"age": ["not_a_number"]}, # Non-string, non-number value
expected_result={"age": 0}, # Falls back to default
),
TransformTestCase(
name="array_number_parameter_with_invalid_string_conversion",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
input_result={"scores": [85, "invalid", "78"]},
expected_result={
"scores": build_segment_with_type(
segment_type=SegmentType.ARRAY_NUMBER, value=[85, 78]
) # Invalid string skipped
},
),
]
@pytest.mark.parametrize("test_case", get_transform_test_cases(), ids=TransformTestCase.get_name)
def test_transform_result_cases(self, test_case):
"""Test _transform_result with various inputs."""
helper = TestParameterExtractorNodeMethods()
node_data = ParameterExtractorNodeData(
title="Test Node",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
query=["test_query"],
parameters=test_case.parameters,
reasoning_mode="function_call",
vision=VisionConfig(),
)
result = helper.transform_result(data=node_data, result=test_case.input_result)
assert result == test_case.expected_result, (
f"Failed for case: {test_case.name}. Expected: {test_case.expected_result}, Got: {result}"
)

View File

@ -2,6 +2,8 @@ import time
import uuid
from unittest.mock import MagicMock, Mock
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileTransferMethod, FileType
from core.variables import ArrayFileSegment
@ -272,3 +274,220 @@ def test_array_file_contains_file_name():
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is True
def _get_test_conditions() -> list:
conditions = [
# Test boolean "is" operator
{"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "true"},
# Test boolean "is not" operator
{"comparison_operator": "is not", "variable_selector": ["start", "bool_false"], "value": "true"},
# Test boolean "=" operator
{"comparison_operator": "=", "variable_selector": ["start", "bool_true"], "value": "1"},
# Test boolean "≠" operator
{"comparison_operator": "", "variable_selector": ["start", "bool_false"], "value": "1"},
# Test boolean "not null" operator
{"comparison_operator": "not null", "variable_selector": ["start", "bool_true"]},
# Test boolean array "contains" operator
{"comparison_operator": "contains", "variable_selector": ["start", "bool_array"], "value": "true"},
# Test boolean "in" operator
{
"comparison_operator": "in",
"variable_selector": ["start", "bool_true"],
"value": ["true", "false"],
},
]
return [Condition.model_validate(i) for i in conditions]
def _get_condition_test_id(c: Condition):
return c.comparison_operator
@pytest.mark.parametrize("condition", _get_test_conditions(), ids=_get_condition_test_id)
def test_execute_if_else_boolean_conditions(condition: Condition):
"""Test IfElseNode with boolean conditions using various operators"""
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool with boolean values
pool = VariablePool(
system_variables=SystemVariable(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)
pool.add(["start", "bool_false"], False)
pool.add(["start", "bool_array"], [True, False, True])
pool.add(["start", "mixed_array"], [True, "false", 1, 0])
node_data = {
"title": "Boolean Test",
"type": "if-else",
"logical_operator": "and",
"conditions": [condition.model_dump()],
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={"id": "if-else", "data": node_data},
)
node.init_node_data(node_data)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is True
def test_execute_if_else_boolean_false_conditions():
"""Test IfElseNode with boolean conditions that should evaluate to false"""
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool with boolean values
pool = VariablePool(
system_variables=SystemVariable(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)
pool.add(["start", "bool_false"], False)
pool.add(["start", "bool_array"], [True, False, True])
node_data = {
"title": "Boolean False Test",
"type": "if-else",
"logical_operator": "or",
"conditions": [
# Test boolean "is" operator (should be false)
{"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "false"},
# Test boolean "=" operator (should be false)
{"comparison_operator": "=", "variable_selector": ["start", "bool_false"], "value": "1"},
# Test boolean "not contains" operator (should be false)
{
"comparison_operator": "not contains",
"variable_selector": ["start", "bool_array"],
"value": "true",
},
],
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "if-else",
"data": node_data,
},
)
node.init_node_data(node_data)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is False
def test_execute_if_else_boolean_cases_structure():
"""Test IfElseNode with boolean conditions using the new cases structure"""
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool with boolean values
pool = VariablePool(
system_variables=SystemVariable(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)
pool.add(["start", "bool_false"], False)
node_data = {
"title": "Boolean Cases Test",
"type": "if-else",
"cases": [
{
"case_id": "true",
"logical_operator": "and",
"conditions": [
{
"comparison_operator": "is",
"variable_selector": ["start", "bool_true"],
"value": "true",
},
{
"comparison_operator": "is not",
"variable_selector": ["start", "bool_false"],
"value": "true",
},
],
}
],
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={"id": "if-else", "data": node_data},
)
node.init_node_data(node_data)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is True
assert result.outputs["selected_case_id"] == "true"

View File

@ -11,7 +11,8 @@ from core.workflow.nodes.list_operator.entities import (
FilterCondition,
Limit,
ListOperatorNodeData,
OrderBy,
Order,
OrderByConfig,
)
from core.workflow.nodes.list_operator.exc import InvalidKeyError
from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func
@ -27,7 +28,7 @@ def list_operator_node():
FilterCondition(key="type", comparison_operator="in", value=[FileType.IMAGE, FileType.DOCUMENT])
],
),
"order_by": OrderBy(enabled=False, value="asc"),
"order_by": OrderByConfig(enabled=False, value=Order.ASC),
"limit": Limit(enabled=False, size=0),
"extract_by": ExtractConfig(enabled=False, serial="1"),
"title": "Test Title",

View File

@ -24,16 +24,18 @@ from core.variables.segments import (
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
BooleanSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
StringSegment,
)
from core.variables.types import SegmentType
from factories import variable_factory
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type
def test_string_variable():
@ -139,6 +141,26 @@ def test_array_number_variable():
assert isinstance(variable.value[1], float)
def test_build_segment_scalar_values():
@dataclass
class TestCase:
value: Any
expected: Segment
description: str
cases = [
TestCase(
value=True,
expected=BooleanSegment(value=True),
description="build_segment with boolean should yield BooleanSegment",
)
]
for idx, c in enumerate(cases, 1):
seg = build_segment(c.value)
assert seg == c.expected, f"Test case {idx} failed: {c.description}"
def test_array_object_variable():
mapping = {
"id": str(uuid4()),
@ -847,15 +869,22 @@ class TestBuildSegmentValueErrors:
f"but got: {error_message}"
)
def test_build_segment_boolean_type_note(self):
"""Note: Boolean values are actually handled as integers in Python, so they don't raise ValueError."""
# Boolean values in Python are subclasses of int, so they get processed as integers
# True becomes IntegerSegment(value=1) and False becomes IntegerSegment(value=0)
def test_build_segment_boolean_type(self):
"""Test that Boolean values are correctly handled as boolean type, not integers."""
# Boolean values should now be processed as BooleanSegment, not IntegerSegment
# This is because the bool check now comes before the int check in build_segment
true_segment = variable_factory.build_segment(True)
false_segment = variable_factory.build_segment(False)
# Verify they are processed as integers, not as errors
assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
assert true_segment.value_type == SegmentType.INTEGER
assert false_segment.value_type == SegmentType.INTEGER
# Verify they are processed as booleans, not integers
assert true_segment.value is True, "Test case 1 (boolean_true): Expected True to be processed as boolean True"
assert false_segment.value is False, (
"Test case 2 (boolean_false): Expected False to be processed as boolean False"
)
assert true_segment.value_type == SegmentType.BOOLEAN
assert false_segment.value_type == SegmentType.BOOLEAN
# Test array of booleans
bool_array_segment = variable_factory.build_segment([True, False, True])
assert bool_array_segment.value_type == SegmentType.ARRAY_BOOLEAN
assert bool_array_segment.value == [True, False, True]

47
simple_boolean_test.py Normal file
View File

@ -0,0 +1,47 @@
#!/usr/bin/env python3
"""
Simple test to verify boolean classes can be imported correctly.
"""
import sys
import os
# Add the api directory to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "api"))
try:
# Test that we can import the boolean classes
from core.variables.segments import BooleanSegment, ArrayBooleanSegment
from core.variables.variables import BooleanVariable, ArrayBooleanVariable
from core.variables.types import SegmentType
print("✅ Successfully imported BooleanSegment")
print("✅ Successfully imported ArrayBooleanSegment")
print("✅ Successfully imported BooleanVariable")
print("✅ Successfully imported ArrayBooleanVariable")
print("✅ Successfully imported SegmentType")
# Test that the segment types exist
print(f"✅ SegmentType.BOOLEAN = {SegmentType.BOOLEAN}")
print(f"✅ SegmentType.ARRAY_BOOLEAN = {SegmentType.ARRAY_BOOLEAN}")
# Test creating boolean segments directly
bool_seg = BooleanSegment(value=True)
print(f"✅ Created BooleanSegment: {bool_seg}")
print(f" Value type: {bool_seg.value_type}")
print(f" Value: {bool_seg.value}")
array_bool_seg = ArrayBooleanSegment(value=[True, False, True])
print(f"✅ Created ArrayBooleanSegment: {array_bool_seg}")
print(f" Value type: {array_bool_seg.value_type}")
print(f" Value: {array_bool_seg.value}")
print("\n🎉 All boolean class imports and basic functionality work correctly!")
except ImportError as e:
print(f"❌ Import error: {e}")
except Exception as e:
print(f"❌ Error: {e}")
import traceback
traceback.print_exc()

118
test_boolean_conditions.py Normal file
View File

@ -0,0 +1,118 @@
#!/usr/bin/env python3
"""
Simple test script to verify boolean condition support in IfElseNode
"""
import sys
import os
# Add the api directory to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "api"))
from core.workflow.utils.condition.processor import (
ConditionProcessor,
_evaluate_condition,
)
def test_boolean_conditions():
"""Test boolean condition evaluation"""
print("Testing boolean condition support...")
# Test boolean "is" operator
result = _evaluate_condition(value=True, operator="is", expected="true")
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'is' with True value passed")
result = _evaluate_condition(value=False, operator="is", expected="false")
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'is' with False value passed")
# Test boolean "is not" operator
result = _evaluate_condition(value=True, operator="is not", expected="false")
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'is not' with True value passed")
result = _evaluate_condition(value=False, operator="is not", expected="true")
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'is not' with False value passed")
# Test boolean "=" operator
result = _evaluate_condition(value=True, operator="=", expected="1")
assert result == True, f"Expected True, got {result}"
print("✓ Boolean '=' with True=1 passed")
result = _evaluate_condition(value=False, operator="=", expected="0")
assert result == True, f"Expected True, got {result}"
print("✓ Boolean '=' with False=0 passed")
# Test boolean "≠" operator
result = _evaluate_condition(value=True, operator="", expected="0")
assert result == True, f"Expected True, got {result}"
print("✓ Boolean '' with True≠0 passed")
result = _evaluate_condition(value=False, operator="", expected="1")
assert result == True, f"Expected True, got {result}"
print("✓ Boolean '' with False≠1 passed")
# Test boolean "in" operator
result = _evaluate_condition(value=True, operator="in", expected=["true", "false"])
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'in' with True in array passed")
result = _evaluate_condition(value=False, operator="in", expected=["true", "false"])
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'in' with False in array passed")
# Test boolean "not in" operator
result = _evaluate_condition(value=True, operator="not in", expected=["false", "0"])
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'not in' with True not in [false, 0] passed")
# Test boolean "null" and "not null" operators
result = _evaluate_condition(value=True, operator="not null", expected=None)
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'not null' with True passed")
result = _evaluate_condition(value=False, operator="not null", expected=None)
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'not null' with False passed")
print("\n🎉 All boolean condition tests passed!")
def test_backward_compatibility():
"""Test that existing string and number conditions still work"""
print("\nTesting backward compatibility...")
# Test string conditions
result = _evaluate_condition(value="hello", operator="is", expected="hello")
assert result == True, f"Expected True, got {result}"
print("✓ String 'is' condition still works")
result = _evaluate_condition(value="hello", operator="contains", expected="ell")
assert result == True, f"Expected True, got {result}"
print("✓ String 'contains' condition still works")
# Test number conditions
result = _evaluate_condition(value=42, operator="=", expected="42")
assert result == True, f"Expected True, got {result}"
print("✓ Number '=' condition still works")
result = _evaluate_condition(value=42, operator=">", expected="40")
assert result == True, f"Expected True, got {result}"
print("✓ Number '>' condition still works")
print("✓ Backward compatibility maintained!")
if __name__ == "__main__":
try:
test_boolean_conditions()
test_backward_compatibility()
print(
"\n✅ All tests passed! Boolean support has been successfully added to IfElseNode."
)
except Exception as e:
print(f"\n❌ Test failed: {e}")
sys.exit(1)

View File

@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""
Test script to verify the boolean array comparison fix in condition processor.
"""
import sys
import os
# Add the api directory to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "api"))
from core.workflow.utils.condition.processor import (
_assert_contains,
_assert_not_contains,
)
def test_boolean_array_contains():
"""Test that boolean arrays work correctly with string comparisons."""
# Test case 1: Boolean array [True, False, True] contains "true"
bool_array = [True, False, True]
# Should return True because "true" converts to True and True is in the array
result1 = _assert_contains(value=bool_array, expected="true")
print(f"Test 1 - [True, False, True] contains 'true': {result1}")
assert result1 == True, "Expected True but got False"
# Should return True because "false" converts to False and False is in the array
result2 = _assert_contains(value=bool_array, expected="false")
print(f"Test 2 - [True, False, True] contains 'false': {result2}")
assert result2 == True, "Expected True but got False"
# Test case 2: Boolean array [True, True] does not contain "false"
bool_array2 = [True, True]
result3 = _assert_contains(value=bool_array2, expected="false")
print(f"Test 3 - [True, True] contains 'false': {result3}")
assert result3 == False, "Expected False but got True"
# Test case 3: Test not_contains
result4 = _assert_not_contains(value=bool_array2, expected="false")
print(f"Test 4 - [True, True] not contains 'false': {result4}")
assert result4 == True, "Expected True but got False"
result5 = _assert_not_contains(value=bool_array, expected="true")
print(f"Test 5 - [True, False, True] not contains 'true': {result5}")
assert result5 == False, "Expected False but got True"
# Test case 4: Test with different string representations
result6 = _assert_contains(
value=bool_array, expected="1"
) # "1" should convert to True
print(f"Test 6 - [True, False, True] contains '1': {result6}")
assert result6 == True, "Expected True but got False"
result7 = _assert_contains(
value=bool_array, expected="0"
) # "0" should convert to False
print(f"Test 7 - [True, False, True] contains '0': {result7}")
assert result7 == True, "Expected True but got False"
print("\n✅ All boolean array comparison tests passed!")
if __name__ == "__main__":
test_boolean_array_contains()

99
test_boolean_factory.py Normal file
View File

@ -0,0 +1,99 @@
#!/usr/bin/env python3
"""
Simple test script to verify boolean type inference in variable factory.
"""
import sys
import os
# Add the api directory to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "api"))
try:
from factories.variable_factory import build_segment, segment_to_variable
from core.variables.segments import BooleanSegment, ArrayBooleanSegment
from core.variables.variables import BooleanVariable, ArrayBooleanVariable
from core.variables.types import SegmentType
def test_boolean_inference():
print("Testing boolean type inference...")
# Test single boolean values
true_segment = build_segment(True)
false_segment = build_segment(False)
print(f"True value: {true_segment}")
print(f"Type: {type(true_segment)}")
print(f"Value type: {true_segment.value_type}")
print(f"Is BooleanSegment: {isinstance(true_segment, BooleanSegment)}")
print(f"\nFalse value: {false_segment}")
print(f"Type: {type(false_segment)}")
print(f"Value type: {false_segment.value_type}")
print(f"Is BooleanSegment: {isinstance(false_segment, BooleanSegment)}")
# Test array of booleans
bool_array_segment = build_segment([True, False, True])
print(f"\nBoolean array: {bool_array_segment}")
print(f"Type: {type(bool_array_segment)}")
print(f"Value type: {bool_array_segment.value_type}")
print(
f"Is ArrayBooleanSegment: {isinstance(bool_array_segment, ArrayBooleanSegment)}"
)
# Test empty boolean array
empty_bool_array = build_segment([])
print(f"\nEmpty array: {empty_bool_array}")
print(f"Type: {type(empty_bool_array)}")
print(f"Value type: {empty_bool_array.value_type}")
# Test segment to variable conversion
bool_var = segment_to_variable(
segment=true_segment, selector=["test", "bool_var"], name="test_boolean"
)
print(f"\nBoolean variable: {bool_var}")
print(f"Type: {type(bool_var)}")
print(f"Is BooleanVariable: {isinstance(bool_var, BooleanVariable)}")
array_bool_var = segment_to_variable(
segment=bool_array_segment,
selector=["test", "array_bool_var"],
name="test_array_boolean",
)
print(f"\nArray boolean variable: {array_bool_var}")
print(f"Type: {type(array_bool_var)}")
print(
f"Is ArrayBooleanVariable: {isinstance(array_bool_var, ArrayBooleanVariable)}"
)
# Test that bool comes before int (critical ordering)
print(f"\nTesting bool vs int precedence:")
print(f"True is instance of bool: {isinstance(True, bool)}")
print(f"True is instance of int: {isinstance(True, int)}")
print(f"False is instance of bool: {isinstance(False, bool)}")
print(f"False is instance of int: {isinstance(False, int)}")
# Verify that boolean values are correctly inferred as boolean, not int
assert true_segment.value_type == SegmentType.BOOLEAN, (
"True should be inferred as BOOLEAN"
)
assert false_segment.value_type == SegmentType.BOOLEAN, (
"False should be inferred as BOOLEAN"
)
assert bool_array_segment.value_type == SegmentType.ARRAY_BOOLEAN, (
"Boolean array should be inferred as ARRAY_BOOLEAN"
)
print("\n✅ All boolean inference tests passed!")
if __name__ == "__main__":
test_boolean_inference()
except ImportError as e:
print(f"Import error: {e}")
print("Make sure you're running this from the correct directory")
except Exception as e:
print(f"Error: {e}")
import traceback
traceback.print_exc()

View File

@ -0,0 +1,230 @@
#!/usr/bin/env python3
"""
Test script to verify boolean support in VariableAssigner node
"""
import sys
import os
# Add the api directory to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "api"))
from core.variables import SegmentType
from core.workflow.nodes.variable_assigner.v2.helpers import (
is_operation_supported,
is_constant_input_supported,
is_input_value_valid,
)
from core.workflow.nodes.variable_assigner.v2.enums import Operation
from core.workflow.nodes.variable_assigner.v2.constants import EMPTY_VALUE_MAPPING
def test_boolean_operation_support():
"""Test that boolean types support the correct operations"""
print("Testing boolean operation support...")
# Boolean should support SET, OVER_WRITE, and CLEAR
assert is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.SET
)
assert is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.OVER_WRITE
)
assert is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.CLEAR
)
# Boolean should NOT support arithmetic operations
assert not is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.ADD
)
assert not is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.SUBTRACT
)
assert not is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.MULTIPLY
)
assert not is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.DIVIDE
)
# Boolean should NOT support array operations
assert not is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.APPEND
)
assert not is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.EXTEND
)
print("✓ Boolean operation support tests passed")
def test_array_boolean_operation_support():
"""Test that array boolean types support the correct operations"""
print("Testing array boolean operation support...")
# Array boolean should support APPEND, EXTEND, SET, OVER_WRITE, CLEAR
assert is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.APPEND
)
assert is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.EXTEND
)
assert is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.OVER_WRITE
)
assert is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.CLEAR
)
assert is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.REMOVE_FIRST
)
assert is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.REMOVE_LAST
)
# Array boolean should NOT support arithmetic operations
assert not is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.ADD
)
assert not is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.SUBTRACT
)
assert not is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.MULTIPLY
)
assert not is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.DIVIDE
)
print("✓ Array boolean operation support tests passed")
def test_boolean_constant_input_support():
"""Test that boolean types support constant input for correct operations"""
print("Testing boolean constant input support...")
# Boolean should support constant input for SET and OVER_WRITE
assert is_constant_input_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.SET
)
assert is_constant_input_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.OVER_WRITE
)
# Boolean should NOT support constant input for arithmetic operations
assert not is_constant_input_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.ADD
)
print("✓ Boolean constant input support tests passed")
def test_boolean_input_validation():
"""Test that boolean input validation works correctly"""
print("Testing boolean input validation...")
# Boolean values should be valid for boolean type
assert is_input_value_valid(
variable_type=SegmentType.BOOLEAN, operation=Operation.SET, value=True
)
assert is_input_value_valid(
variable_type=SegmentType.BOOLEAN, operation=Operation.SET, value=False
)
assert is_input_value_valid(
variable_type=SegmentType.BOOLEAN, operation=Operation.OVER_WRITE, value=True
)
# Non-boolean values should be invalid for boolean type
assert not is_input_value_valid(
variable_type=SegmentType.BOOLEAN, operation=Operation.SET, value="true"
)
assert not is_input_value_valid(
variable_type=SegmentType.BOOLEAN, operation=Operation.SET, value=1
)
assert not is_input_value_valid(
variable_type=SegmentType.BOOLEAN, operation=Operation.SET, value=0
)
print("✓ Boolean input validation tests passed")
def test_array_boolean_input_validation():
"""Test that array boolean input validation works correctly"""
print("Testing array boolean input validation...")
# Boolean values should be valid for array boolean append
assert is_input_value_valid(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.APPEND, value=True
)
assert is_input_value_valid(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.APPEND, value=False
)
# Boolean arrays should be valid for extend/overwrite
assert is_input_value_valid(
variable_type=SegmentType.ARRAY_BOOLEAN,
operation=Operation.EXTEND,
value=[True, False, True],
)
assert is_input_value_valid(
variable_type=SegmentType.ARRAY_BOOLEAN,
operation=Operation.OVER_WRITE,
value=[False, False],
)
# Non-boolean values should be invalid
assert not is_input_value_valid(
variable_type=SegmentType.ARRAY_BOOLEAN,
operation=Operation.APPEND,
value="true",
)
assert not is_input_value_valid(
variable_type=SegmentType.ARRAY_BOOLEAN,
operation=Operation.EXTEND,
value=[True, "false"],
)
print("✓ Array boolean input validation tests passed")
def test_empty_value_mapping():
"""Test that empty value mapping includes boolean types"""
print("Testing empty value mapping...")
# Check that boolean types have correct empty values
assert SegmentType.BOOLEAN in EMPTY_VALUE_MAPPING
assert EMPTY_VALUE_MAPPING[SegmentType.BOOLEAN] is False
assert SegmentType.ARRAY_BOOLEAN in EMPTY_VALUE_MAPPING
assert EMPTY_VALUE_MAPPING[SegmentType.ARRAY_BOOLEAN] == []
print("✓ Empty value mapping tests passed")
def main():
"""Run all tests"""
print("Running VariableAssigner boolean support tests...\n")
try:
test_boolean_operation_support()
test_array_boolean_operation_support()
test_boolean_constant_input_support()
test_boolean_input_validation()
test_array_boolean_input_validation()
test_empty_value_mapping()
print(
"\n🎉 All tests passed! Boolean support has been successfully added to VariableAssigner."
)
except Exception as e:
print(f"\n❌ Test failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()

View File

@ -34,7 +34,7 @@ COPY --from=packages /app/web/ .
COPY . .
ENV NODE_OPTIONS="--max-old-space-size=4096"
RUN pnpm build:docker
RUN pnpm build
# production stage

View File

@ -8,7 +8,6 @@ import Header from '@/app/components/header'
import { EventEmitterContextProvider } from '@/context/event-emitter'
import { ProviderContextProvider } from '@/context/provider-context'
import { ModalContextProvider } from '@/context/modal-context'
import { ReadmePanelProvider } from '@/app/components/plugins/readme-panel/context'
import GotoAnything from '@/app/components/goto-anything'
const Layout = ({ children }: { children: ReactNode }) => {
@ -20,13 +19,11 @@ const Layout = ({ children }: { children: ReactNode }) => {
<EventEmitterContextProvider>
<ProviderContextProvider>
<ModalContextProvider>
<ReadmePanelProvider>
<HeaderWrapper>
<Header />
</HeaderWrapper>
{children}
<GotoAnything />
</ReadmePanelProvider>
<HeaderWrapper>
<Header />
</HeaderWrapper>
{children}
<GotoAnything />
</ModalContextProvider>
</ProviderContextProvider>
</EventEmitterContextProvider>

View File

@ -0,0 +1,24 @@
export const jsonObjectWrap = {
type: 'object',
properties: {},
required: [],
additionalProperties: true,
}
export const jsonConfigPlaceHolder = JSON.stringify(
{
foo: {
type: 'string',
},
bar: {
type: 'object',
properties: {
sub: {
type: 'number',
},
},
required: [],
additionalProperties: true,
},
}, null, 2,
)

View File

@ -2,21 +2,28 @@
import type { FC } from 'react'
import React from 'react'
import cn from '@/utils/classnames'
import { useTranslation } from 'react-i18next'
type Props = {
className?: string
title: string
isOptional?: boolean
children: React.JSX.Element
}
const Field: FC<Props> = ({
className,
title,
isOptional,
children,
}) => {
const { t } = useTranslation()
return (
<div className={cn(className)}>
<div className='system-sm-semibold leading-8 text-text-secondary'>{title}</div>
<div className='system-sm-semibold leading-8 text-text-secondary'>
{title}
{isOptional && <span className='system-xs-regular ml-1 text-text-tertiary'>({t('appDebug.variableConfig.optional')})</span>}
</div>
<div>{children}</div>
</div>
)

View File

@ -1,13 +1,12 @@
'use client'
import type { ChangeEvent, FC } from 'react'
import React, { useCallback, useEffect, useRef, useState } from 'react'
import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import produce from 'immer'
import ModalFoot from '../modal-foot'
import ConfigSelect from '../config-select'
import ConfigString from '../config-string'
import SelectTypeItem from '../select-type-item'
import Field from './field'
import Input from '@/app/components/base/input'
import Toast from '@/app/components/base/toast'
@ -20,7 +19,13 @@ import FileUploadSetting from '@/app/components/workflow/nodes/_base/components/
import Checkbox from '@/app/components/base/checkbox'
import { DEFAULT_FILE_UPLOAD_SETTING } from '@/app/components/workflow/constants'
import { DEFAULT_VALUE_MAX_LEN } from '@/config'
import type { Item as SelectItem } from './type-select'
import TypeSelector from './type-select'
import { SimpleSelect } from '@/app/components/base/select'
import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor'
import { CodeLanguage } from '@/app/components/workflow/nodes/code/types'
import { jsonConfigPlaceHolder, jsonObjectWrap } from './config'
import { useStore as useAppStore } from '@/app/components/app/store'
import Textarea from '@/app/components/base/textarea'
import { FileUploaderInAttachmentWrapper } from '@/app/components/base/file-uploader'
import { TransferMethod } from '@/types/app'
@ -51,6 +56,20 @@ const ConfigModal: FC<IConfigModalProps> = ({
const [tempPayload, setTempPayload] = useState<InputVar>(payload || getNewVarInWorkflow('') as any)
const { type, label, variable, options, max_length } = tempPayload
const modalRef = useRef<HTMLDivElement>(null)
const appDetail = useAppStore(state => state.appDetail)
const isBasicApp = appDetail?.mode !== 'advanced-chat' && appDetail?.mode !== 'workflow'
const isSupportJSON = false
const jsonSchemaStr = useMemo(() => {
const isJsonObject = type === InputVarType.jsonObject
if (!isJsonObject || !tempPayload.json_schema)
return ''
try {
return JSON.stringify(JSON.parse(tempPayload.json_schema).properties, null, 2)
}
catch (_e) {
return ''
}
}, [tempPayload.json_schema])
useEffect(() => {
// To fix the first input element auto focus, then directly close modal will raise error
if (isShow)
@ -82,25 +101,74 @@ const ConfigModal: FC<IConfigModalProps> = ({
}
}, [])
const handleTypeChange = useCallback((type: InputVarType) => {
return () => {
const newPayload = produce(tempPayload, (draft) => {
draft.type = type
// Clear default value when switching types
draft.default = undefined
if ([InputVarType.singleFile, InputVarType.multiFiles].includes(type)) {
(Object.keys(DEFAULT_FILE_UPLOAD_SETTING)).forEach((key) => {
if (key !== 'max_length')
(draft as any)[key] = (DEFAULT_FILE_UPLOAD_SETTING as any)[key]
})
if (type === InputVarType.multiFiles)
draft.max_length = DEFAULT_FILE_UPLOAD_SETTING.max_length
}
if (type === InputVarType.paragraph)
draft.max_length = DEFAULT_VALUE_MAX_LEN
})
setTempPayload(newPayload)
const handleJSONSchemaChange = useCallback((value: string) => {
try {
const v = JSON.parse(value)
const res = {
...jsonObjectWrap,
properties: v,
}
handlePayloadChange('json_schema')(JSON.stringify(res, null, 2))
}
catch (_e) {
return null
}
}, [handlePayloadChange])
const selectOptions: SelectItem[] = [
{
name: t('appDebug.variableConfig.text-input'),
value: InputVarType.textInput,
},
{
name: t('appDebug.variableConfig.paragraph'),
value: InputVarType.paragraph,
},
{
name: t('appDebug.variableConfig.select'),
value: InputVarType.select,
},
{
name: t('appDebug.variableConfig.number'),
value: InputVarType.number,
},
{
name: t('appDebug.variableConfig.checkbox'),
value: InputVarType.checkbox,
},
...(supportFile ? [
{
name: t('appDebug.variableConfig.single-file'),
value: InputVarType.singleFile,
},
{
name: t('appDebug.variableConfig.multi-files'),
value: InputVarType.multiFiles,
},
] : []),
...((!isBasicApp && isSupportJSON) ? [{
name: t('appDebug.variableConfig.json'),
value: InputVarType.jsonObject,
}] : []),
]
const handleTypeChange = useCallback((item: SelectItem) => {
const type = item.value as InputVarType
const newPayload = produce(tempPayload, (draft) => {
draft.type = type
if ([InputVarType.singleFile, InputVarType.multiFiles].includes(type)) {
(Object.keys(DEFAULT_FILE_UPLOAD_SETTING)).forEach((key) => {
if (key !== 'max_length')
(draft as any)[key] = (DEFAULT_FILE_UPLOAD_SETTING as any)[key]
})
if (type === InputVarType.multiFiles)
draft.max_length = DEFAULT_FILE_UPLOAD_SETTING.max_length
}
if (type === InputVarType.paragraph)
draft.max_length = DEFAULT_VALUE_MAX_LEN
})
setTempPayload(newPayload)
}, [tempPayload])
const handleVarKeyBlur = useCallback((e: any) => {
@ -142,15 +210,6 @@ const ConfigModal: FC<IConfigModalProps> = ({
if (!isVariableNameValid)
return
// TODO: check if key already exists. should the consider the edit case
// if (varKeys.map(key => key?.trim()).includes(tempPayload.variable.trim())) {
// Toast.notify({
// type: 'error',
// message: t('appDebug.varKeyError.keyAlreadyExists', { key: tempPayload.variable }),
// })
// return
// }
if (!tempPayload.label) {
Toast.notify({ type: 'error', message: t('appDebug.variableConfig.errorMsg.labelNameRequired') })
return
@ -204,18 +263,8 @@ const ConfigModal: FC<IConfigModalProps> = ({
>
<div className='mb-8' ref={modalRef} tabIndex={-1}>
<div className='space-y-2'>
<Field title={t('appDebug.variableConfig.fieldType')}>
<div className='grid grid-cols-3 gap-2'>
<SelectTypeItem type={InputVarType.textInput} selected={type === InputVarType.textInput} onClick={handleTypeChange(InputVarType.textInput)} />
<SelectTypeItem type={InputVarType.paragraph} selected={type === InputVarType.paragraph} onClick={handleTypeChange(InputVarType.paragraph)} />
<SelectTypeItem type={InputVarType.select} selected={type === InputVarType.select} onClick={handleTypeChange(InputVarType.select)} />
<SelectTypeItem type={InputVarType.number} selected={type === InputVarType.number} onClick={handleTypeChange(InputVarType.number)} />
{supportFile && <>
<SelectTypeItem type={InputVarType.singleFile} selected={type === InputVarType.singleFile} onClick={handleTypeChange(InputVarType.singleFile)} />
<SelectTypeItem type={InputVarType.multiFiles} selected={type === InputVarType.multiFiles} onClick={handleTypeChange(InputVarType.multiFiles)} />
</>}
</div>
<TypeSelector value={type} items={selectOptions} onSelect={handleTypeChange} />
</Field>
<Field title={t('appDebug.variableConfig.varName')}>
@ -330,6 +379,21 @@ const ConfigModal: FC<IConfigModalProps> = ({
</>
)}
{type === InputVarType.jsonObject && (
<Field title={t('appDebug.variableConfig.jsonSchema')} isOptional>
<CodeEditor
language={CodeLanguage.json}
value={jsonSchemaStr}
onChange={handleJSONSchemaChange}
noWrapper
className='bg h-[80px] overflow-y-auto rounded-[10px] bg-components-input-bg-normal p-1'
placeholder={
<div className='whitespace-pre'>{jsonConfigPlaceHolder}</div>
}
/>
</Field>
)}
<div className='!mt-5 flex h-6 items-center space-x-2'>
<Checkbox checked={tempPayload.required} disabled={tempPayload.hide} onCheck={() => handlePayloadChange('required')(!tempPayload.required)} />
<span className='system-sm-semibold text-text-secondary'>{t('appDebug.variableConfig.required')}</span>

View File

@ -0,0 +1,97 @@
'use client'
import type { FC } from 'react'
import React, { useState } from 'react'
import { ChevronDownIcon } from '@heroicons/react/20/solid'
import classNames from '@/utils/classnames'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import InputVarTypeIcon from '@/app/components/workflow/nodes/_base/components/input-var-type-icon'
import type { InputVarType } from '@/app/components/workflow/types'
import cn from '@/utils/classnames'
import Badge from '@/app/components/base/badge'
import { inputVarTypeToVarType } from '@/app/components/workflow/nodes/_base/components/variable/utils'
export type Item = {
value: InputVarType
name: string
}
type Props = {
value: string | number
onSelect: (value: Item) => void
items: Item[]
popupClassName?: string
popupInnerClassName?: string
readonly?: boolean
hideChecked?: boolean
}
const TypeSelector: FC<Props> = ({
value,
onSelect,
items,
popupInnerClassName,
readonly,
}) => {
const [open, setOpen] = useState(false)
const selectedItem = value ? items.find(item => item.value === value) : undefined
return (
<PortalToFollowElem
open={open}
onOpenChange={setOpen}
placement='bottom-start'
offset={4}
>
<PortalToFollowElemTrigger onClick={() => !readonly && setOpen(v => !v)} className='w-full'>
<div
className={classNames(`group flex h-9 items-center justify-between rounded-lg border-0 bg-components-input-bg-normal px-2 text-sm hover:bg-state-base-hover-alt ${readonly ? 'cursor-not-allowed' : 'cursor-pointer'}`)}
title={selectedItem?.name}
>
<div className='flex items-center'>
<InputVarTypeIcon type={selectedItem?.value as InputVarType} className='size-4 shrink-0 text-text-secondary' />
<span
className={`
ml-1.5 ${!selectedItem?.name && 'text-components-input-text-placeholder'}
`}
>
{selectedItem?.name}
</span>
</div>
<div className='flex items-center space-x-1'>
<Badge uppercase={false}>{inputVarTypeToVarType(selectedItem?.value as InputVarType)}</Badge>
<ChevronDownIcon className={cn('h-4 w-4 shrink-0 text-text-quaternary group-hover:text-text-secondary', open && 'text-text-secondary')} />
</div>
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-[61]'>
<div
className={classNames('w-[432px] rounded-md border-[0.5px] border-components-panel-border bg-components-panel-bg px-1 py-1 text-base shadow-lg focus:outline-none sm:text-sm', popupInnerClassName)}
>
{items.map((item: Item) => (
<div
key={item.value}
className={'flex h-9 cursor-pointer items-center justify-between rounded-lg px-2 text-text-secondary hover:bg-state-base-hover'}
title={item.name}
onClick={() => {
onSelect(item)
setOpen(false)
}}
>
<div className='flex items-center space-x-2'>
<InputVarTypeIcon type={item.value} className='size-4 shrink-0 text-text-secondary' />
<span title={item.name}>{item.name}</span>
</div>
<Badge uppercase={false}>{inputVarTypeToVarType(item.value)}</Badge>
</div>
))}
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
)
}
export default TypeSelector

View File

@ -12,7 +12,7 @@ import SelectVarType from './select-var-type'
import Tooltip from '@/app/components/base/tooltip'
import type { PromptVariable } from '@/models/debug'
import { DEFAULT_VALUE_MAX_LEN } from '@/config'
import { getNewVar } from '@/utils/var'
import { getNewVar, hasDuplicateStr } from '@/utils/var'
import Toast from '@/app/components/base/toast'
import Confirm from '@/app/components/base/confirm'
import ConfigContext from '@/context/debug-configuration'
@ -80,7 +80,28 @@ const ConfigVar: FC<IConfigVarProps> = ({ promptVariables, readonly, onPromptVar
delete draft[currIndex].options
})
const newList = newPromptVariables
let errorMsgKey = ''
let typeName = ''
if (hasDuplicateStr(newList.map(item => item.key))) {
errorMsgKey = 'appDebug.varKeyError.keyAlreadyExists'
typeName = 'appDebug.variableConfig.varName'
}
else if (hasDuplicateStr(newList.map(item => item.name as string))) {
errorMsgKey = 'appDebug.varKeyError.keyAlreadyExists'
typeName = 'appDebug.variableConfig.labelName'
}
if (errorMsgKey) {
Toast.notify({
type: 'error',
message: t(errorMsgKey, { key: t(typeName) }),
})
return false
}
onPromptVariablesChange?.(newPromptVariables)
return true
}
const { setShowExternalDataToolModal } = useModalContext()
@ -190,7 +211,7 @@ const ConfigVar: FC<IConfigVarProps> = ({ promptVariables, readonly, onPromptVar
const handleConfig = ({ key, type, index, name, config, icon, icon_background }: ExternalDataToolParams) => {
// setCurrKey(key)
setCurrIndex(index)
if (type !== 'string' && type !== 'paragraph' && type !== 'select' && type !== 'number') {
if (type !== 'string' && type !== 'paragraph' && type !== 'select' && type !== 'number' && type !== 'checkbox') {
handleOpenExternalDataToolModal({ key, type, index, name, config, icon, icon_background }, promptVariables)
return
}
@ -245,7 +266,8 @@ const ConfigVar: FC<IConfigVarProps> = ({ promptVariables, readonly, onPromptVar
isShow={isShowEditModal}
onClose={hideEditModal}
onConfirm={(item) => {
updatePromptVariableItem(item)
const isValid = updatePromptVariableItem(item)
if (!isValid) return
hideEditModal()
}}
varKeys={promptVariables.map(v => v.key)}

View File

@ -65,6 +65,7 @@ const SelectVarType: FC<Props> = ({
<SelectItem type={InputVarType.paragraph} value='paragraph' text={t('appDebug.variableConfig.paragraph')} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.select} value='select' text={t('appDebug.variableConfig.select')} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.number} value='number' text={t('appDebug.variableConfig.number')} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.checkbox} value='checkbox' text={t('appDebug.variableConfig.checkbox')} onClick={handleChange}></SelectItem>
</div>
<div className='h-px border-t border-components-panel-border'></div>
<div className='p-1'>

View File

@ -28,7 +28,6 @@ import {
AuthCategory,
PluginAuthInAgent,
} from '@/app/components/plugins/plugin-auth'
import { ReadmeEntrance } from '@/app/components/plugins/readme-panel/entrance'
type Props = {
showBackButton?: boolean
@ -121,6 +120,8 @@ const SettingBuiltInTool: FC<Props> = ({
return t('tools.setBuiltInTools.number')
if (type === 'text-input')
return t('tools.setBuiltInTools.string')
if (type === 'checkbox')
return 'boolean'
if (type === 'file')
return t('tools.setBuiltInTools.file')
return type
@ -213,7 +214,6 @@ const SettingBuiltInTool: FC<Props> = ({
pluginPayload={{
provider: collection.name,
category: AuthCategory.tool,
detail: collection as any,
}}
credentialId={credentialId}
onAuthorizationItemClick={onAuthorizationItemClick}
@ -243,14 +243,13 @@ const SettingBuiltInTool: FC<Props> = ({
)}
<div className='h-0 grow overflow-y-auto px-4'>
{isInfoActive ? infoUI : settingUI}
{!readonly && !isInfoActive && (
<div className='flex shrink-0 justify-end space-x-2 rounded-b-[10px] bg-components-panel-bg py-2'>
<Button className='flex h-8 items-center !px-3 !text-[13px] font-medium ' onClick={onHide}>{t('common.operation.cancel')}</Button>
<Button className='flex h-8 items-center !px-3 !text-[13px] font-medium' variant='primary' disabled={!isValid} onClick={() => onSave?.(addDefaultValue(tempSetting, formSchemas))}>{t('common.operation.save')}</Button>
</div>
)}
</div>
<ReadmeEntrance detail={collection as any} className='mt-auto' />
{!readonly && !isInfoActive && (
<div className='mt-2 flex shrink-0 justify-end space-x-2 rounded-b-[10px] border-t border-divider-regular bg-components-panel-bg px-6 py-4'>
<Button className='flex h-8 items-center !px-3 !text-[13px] font-medium ' onClick={onHide}>{t('common.operation.cancel')}</Button>
<Button className='flex h-8 items-center !px-3 !text-[13px] font-medium' variant='primary' disabled={!isValid} onClick={() => onSave?.(addDefaultValue(tempSetting, formSchemas))}>{t('common.operation.save')}</Button>
</div>
)}
</div>
</div>
</>

View File

@ -8,6 +8,7 @@ import Textarea from '@/app/components/base/textarea'
import { DEFAULT_VALUE_MAX_LEN } from '@/config'
import type { Inputs } from '@/models/debug'
import cn from '@/utils/classnames'
import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input'
type Props = {
inputs: Inputs
@ -31,7 +32,7 @@ const ChatUserInput = ({
return obj
})()
const handleInputValueChange = (key: string, value: string) => {
const handleInputValueChange = (key: string, value: string | boolean) => {
if (!(key in promptVariableObj))
return
@ -55,10 +56,12 @@ const ChatUserInput = ({
className='mb-4 last-of-type:mb-0'
>
<div>
{type !== 'checkbox' && (
<div className='system-sm-semibold mb-1 flex h-6 items-center gap-1 text-text-secondary'>
<div className='truncate'>{name || key}</div>
{!required && <span className='system-xs-regular text-text-tertiary'>{t('workflow.panel.optional')}</span>}
</div>
)}
<div className='grow'>
{type === 'string' && (
<Input
@ -96,6 +99,14 @@ const ChatUserInput = ({
maxLength={max_length || DEFAULT_VALUE_MAX_LEN}
/>
)}
{type === 'checkbox' && (
<BoolInput
name={name || key}
value={!!inputs[key]}
required={required}
onChange={(value) => { handleInputValueChange(key, value) }}
/>
)}
</div>
</div>
</div>

View File

@ -34,7 +34,7 @@ import { RefreshCcw01 } from '@/app/components/base/icons/src/vender/line/arrows
import TooltipPlus from '@/app/components/base/tooltip'
import ActionButton, { ActionButtonState } from '@/app/components/base/action-button'
import type { ModelConfig as BackendModelConfig, VisionFile, VisionSettings } from '@/types/app'
import { promptVariablesToUserInputsForm } from '@/utils/model-config'
import { formatBooleanInputs, promptVariablesToUserInputsForm } from '@/utils/model-config'
import TextGeneration from '@/app/components/app/text-generate/item'
import { IS_CE_EDITION } from '@/config'
import type { Inputs } from '@/models/debug'
@ -259,7 +259,7 @@ const Debug: FC<IDebug> = ({
}
const data: Record<string, any> = {
inputs,
inputs: formatBooleanInputs(modelConfig.configs.prompt_variables, inputs),
model_config: postModelConfig,
}

View File

@ -60,7 +60,6 @@ import {
useModelListAndDefaultModelAndCurrentProviderAndModel,
useTextGenerationCurrentProviderAndModelAndModelList,
} from '@/app/components/header/account-setting/model-provider-page/hooks'
import { fetchCollectionList } from '@/service/tools'
import type { Collection } from '@/app/components/tools/types'
import { useStore as useAppStore } from '@/app/components/app/store'
import {
@ -82,6 +81,7 @@ import { supportFunctionCall } from '@/utils/tool-call'
import { MittProvider } from '@/context/mitt-context'
import { fetchAndMergeValidCompletionParams } from '@/utils/completion-params'
import Toast from '@/app/components/base/toast'
import { fetchCollectionList } from '@/service/tools'
import { useAppContext } from '@/context/app-context'
type PublishConfig = {

View File

@ -22,6 +22,7 @@ import type { VisionFile, VisionSettings } from '@/types/app'
import { DEFAULT_VALUE_MAX_LEN } from '@/config'
import { useStore as useAppStore } from '@/app/components/app/store'
import cn from '@/utils/classnames'
import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input'
export type IPromptValuePanelProps = {
appType: AppType
@ -66,7 +67,7 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
else { return !modelConfig.configs.prompt_template }
}, [chatPromptConfig.prompt, completionPromptConfig.prompt?.text, isAdvancedMode, mode, modelConfig.configs.prompt_template, modelModeType])
const handleInputValueChange = (key: string, value: string) => {
const handleInputValueChange = (key: string, value: string | boolean) => {
if (!(key in promptVariableObj))
return
@ -109,10 +110,12 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
className='mb-4 last-of-type:mb-0'
>
<div>
<div className='system-sm-semibold mb-1 flex h-6 items-center gap-1 text-text-secondary'>
<div className='truncate'>{name || key}</div>
{!required && <span className='system-xs-regular text-text-tertiary'>{t('workflow.panel.optional')}</span>}
</div>
{type !== 'checkbox' && (
<div className='system-sm-semibold mb-1 flex h-6 items-center gap-1 text-text-secondary'>
<div className='truncate'>{name || key}</div>
{!required && <span className='system-xs-regular text-text-tertiary'>{t('workflow.panel.optional')}</span>}
</div>
)}
<div className='grow'>
{type === 'string' && (
<Input
@ -151,6 +154,14 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
maxLength={max_length || DEFAULT_VALUE_MAX_LEN}
/>
)}
{type === 'checkbox' && (
<BoolInput
name={name || key}
value={!!inputs[key]}
required={required}
onChange={(value) => { handleInputValueChange(key, value) }}
/>
)}
</div>
</div>
</div>

View File

@ -23,6 +23,7 @@ import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested
import { Markdown } from '@/app/components/base/markdown'
import cn from '@/utils/classnames'
import type { FileEntity } from '../../file-uploader/types'
import { formatBooleanInputs } from '@/utils/model-config'
import Avatar from '../../avatar'
const ChatWrapper = () => {
@ -89,7 +90,7 @@ const ChatWrapper = () => {
let hasEmptyInput = ''
let fileIsUploading = false
const requiredVars = inputsForms.filter(({ required }) => required)
const requiredVars = inputsForms.filter(({ required, type }) => required && type !== InputVarType.checkbox)
if (requiredVars.length) {
requiredVars.forEach(({ variable, label, type }) => {
if (hasEmptyInput)
@ -131,7 +132,7 @@ const ChatWrapper = () => {
const data: any = {
query: message,
files,
inputs: currentConversationId ? currentConversationInputs : newConversationInputs,
inputs: formatBooleanInputs(inputsForms, currentConversationId ? currentConversationInputs : newConversationInputs),
conversation_id: currentConversationId,
parent_message_id: (isRegenerate ? parentAnswer?.id : getLastAnswer(chatList)?.id) || null,
}

View File

@ -222,6 +222,14 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
type: 'number',
}
}
if(item.checkbox) {
return {
...item.checkbox,
default: false,
type: 'checkbox',
}
}
if (item.select) {
const isInputInOptions = item.select.options.includes(initInputs[item.select.variable])
return {
@ -245,6 +253,13 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
}
}
if (item.json_object) {
return {
...item.json_object,
type: 'json_object',
}
}
let value = initInputs[item['text-input'].variable]
if (value && item['text-input'].max_length && value.length > item['text-input'].max_length)
value = value.slice(0, item['text-input'].max_length)
@ -340,7 +355,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
let hasEmptyInput = ''
let fileIsUploading = false
const requiredVars = inputsForms.filter(({ required }) => required)
const requiredVars = inputsForms.filter(({ required, type }) => required && type !== InputVarType.checkbox)
if (requiredVars.length) {
requiredVars.forEach(({ variable, label, type }) => {
if (hasEmptyInput)

View File

@ -6,6 +6,9 @@ import Textarea from '@/app/components/base/textarea'
import { PortalSelect } from '@/app/components/base/select'
import { FileUploaderInAttachmentWrapper } from '@/app/components/base/file-uploader'
import { InputVarType } from '@/app/components/workflow/types'
import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input'
import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor'
import { CodeLanguage } from '@/app/components/workflow/nodes/code/types'
type Props = {
showTip?: boolean
@ -42,12 +45,14 @@ const InputsFormContent = ({ showTip }: Props) => {
<div className='space-y-4'>
{visibleInputsForms.map(form => (
<div key={form.variable} className='space-y-1'>
<div className='flex h-6 items-center gap-1'>
<div className='system-md-semibold text-text-secondary'>{form.label}</div>
{!form.required && (
<div className='system-xs-regular text-text-tertiary'>{t('appDebug.variableTable.optional')}</div>
)}
</div>
{form.type !== InputVarType.checkbox && (
<div className='flex h-6 items-center gap-1'>
<div className='system-md-semibold text-text-secondary'>{form.label}</div>
{!form.required && (
<div className='system-xs-regular text-text-tertiary'>{t('appDebug.variableTable.optional')}</div>
)}
</div>
)}
{form.type === InputVarType.textInput && (
<Input
value={inputsFormValue?.[form.variable] || ''}
@ -70,6 +75,14 @@ const InputsFormContent = ({ showTip }: Props) => {
placeholder={form.label}
/>
)}
{form.type === InputVarType.checkbox && (
<BoolInput
name={form.label}
value={!!inputsFormValue?.[form.variable]}
required={form.required}
onChange={value => handleFormChange(form.variable, value)}
/>
)}
{form.type === InputVarType.select && (
<PortalSelect
popupClassName='w-[200px]'
@ -105,6 +118,18 @@ const InputsFormContent = ({ showTip }: Props) => {
}}
/>
)}
{form.type === InputVarType.jsonObject && (
<CodeEditor
language={CodeLanguage.json}
value={inputsFormValue?.[form.variable] || ''}
onChange={v => handleFormChange(form.variable, v)}
noWrapper
className='bg h-[80px] overflow-y-auto rounded-[10px] bg-components-input-bg-normal p-1'
placeholder={
<div className='whitespace-pre'>{form.json_schema}</div>
}
/>
)}
</div>
))}
{showTip && (

View File

@ -12,7 +12,7 @@ export const useCheckInputsForms = () => {
const checkInputsForm = useCallback((inputs: Record<string, any>, inputsForm: InputForm[]) => {
let hasEmptyInput = ''
let fileIsUploading = false
const requiredVars = inputsForm.filter(({ required }) => required)
const requiredVars = inputsForm.filter(({ required, type }) => required && type !== InputVarType.checkbox) // boolean can be not checked
if (requiredVars?.length) {
requiredVars.forEach(({ variable, label, type }) => {

View File

@ -31,6 +31,12 @@ export const getProcessedInputs = (inputs: Record<string, any>, inputsForm: Inpu
inputsForm.forEach((item) => {
const inputValue = inputs[item.variable]
// set boolean type default value
if(item.type === InputVarType.checkbox) {
processedInputs[item.variable] = !!inputValue
return
}
if (!inputValue)
return

View File

@ -90,7 +90,7 @@ const ChatWrapper = () => {
let hasEmptyInput = ''
let fileIsUploading = false
const requiredVars = inputsForms.filter(({ required }) => required)
const requiredVars = inputsForms.filter(({ required, type }) => required && type !== InputVarType.checkbox) // boolean can be not checked
if (requiredVars.length) {
requiredVars.forEach(({ variable, label, type }) => {
if (hasEmptyInput)

View File

@ -195,6 +195,13 @@ export const useEmbeddedChatbot = () => {
type: 'number',
}
}
if (item.checkbox) {
return {
...item.checkbox,
default: false,
type: 'checkbox',
}
}
if (item.select) {
const isInputInOptions = item.select.options.includes(initInputs[item.select.variable])
return {
@ -218,6 +225,13 @@ export const useEmbeddedChatbot = () => {
}
}
if (item.json_object) {
return {
...item.json_object,
type: 'json_object',
}
}
let value = initInputs[item['text-input'].variable]
if (value && item['text-input'].max_length && value.length > item['text-input'].max_length)
value = value.slice(0, item['text-input'].max_length)
@ -312,7 +326,7 @@ export const useEmbeddedChatbot = () => {
let hasEmptyInput = ''
let fileIsUploading = false
const requiredVars = inputsForms.filter(({ required }) => required)
const requiredVars = inputsForms.filter(({ required, type }) => required && type !== InputVarType.checkbox)
if (requiredVars.length) {
requiredVars.forEach(({ variable, label, type }) => {
if (hasEmptyInput)

View File

@ -6,6 +6,9 @@ import Textarea from '@/app/components/base/textarea'
import { PortalSelect } from '@/app/components/base/select'
import { FileUploaderInAttachmentWrapper } from '@/app/components/base/file-uploader'
import { InputVarType } from '@/app/components/workflow/types'
import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input'
import { CodeLanguage } from '@/app/components/workflow/nodes/code/types'
import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor'
type Props = {
showTip?: boolean
@ -42,12 +45,14 @@ const InputsFormContent = ({ showTip }: Props) => {
<div className='space-y-4'>
{visibleInputsForms.map(form => (
<div key={form.variable} className='space-y-1'>
{form.type !== InputVarType.checkbox && (
<div className='flex h-6 items-center gap-1'>
<div className='system-md-semibold text-text-secondary'>{form.label}</div>
{!form.required && (
<div className='system-xs-regular text-text-tertiary'>{t('appDebug.variableTable.optional')}</div>
)}
</div>
)}
{form.type === InputVarType.textInput && (
<Input
value={inputsFormValue?.[form.variable] || ''}
@ -70,6 +75,14 @@ const InputsFormContent = ({ showTip }: Props) => {
placeholder={form.label}
/>
)}
{form.type === InputVarType.checkbox && (
<BoolInput
name={form.label}
value={inputsFormValue?.[form.variable]}
required={form.required}
onChange={value => handleFormChange(form.variable, value)}
/>
)}
{form.type === InputVarType.select && (
<PortalSelect
popupClassName='w-[200px]'
@ -105,6 +118,18 @@ const InputsFormContent = ({ showTip }: Props) => {
}}
/>
)}
{form.type === InputVarType.jsonObject && (
<CodeEditor
language={CodeLanguage.json}
value={inputsFormValue?.[form.variable] || ''}
onChange={v => handleFormChange(form.variable, v)}
noWrapper
className='bg h-[80px] overflow-y-auto rounded-[10px] bg-components-input-bg-normal p-1'
placeholder={
<div className='whitespace-pre'>{form.json_schema}</div>
}
/>
)}
</div>
))}
{showTip && (

View File

@ -10,7 +10,6 @@ export type IDrawerProps = {
description?: string
dialogClassName?: string
dialogBackdropClassName?: string
containerClassName?: string
panelClassName?: string
children: React.ReactNode
footer?: React.ReactNode
@ -23,7 +22,6 @@ export type IDrawerProps = {
onCancel?: () => void
onOk?: () => void
unmount?: boolean
noOverlay?: boolean
}
export default function Drawer({
@ -31,7 +29,6 @@ export default function Drawer({
description = '',
dialogClassName = '',
dialogBackdropClassName = '',
containerClassName = '',
panelClassName = '',
children,
footer,
@ -44,7 +41,6 @@ export default function Drawer({
onCancel,
onOk,
unmount = false,
noOverlay = false,
}: IDrawerProps) {
const { t } = useTranslation()
return (
@ -54,14 +50,14 @@ export default function Drawer({
onClose={() => !clickOutsideNotOpen && onClose()}
className={cn('fixed inset-0 z-[30] overflow-y-auto', dialogClassName)}
>
<div className={cn('flex h-screen w-screen justify-end', positionCenter && '!justify-center', containerClassName)}>
<div className={cn('flex h-screen w-screen justify-end', positionCenter && '!justify-center')}>
{/* mask */}
{!noOverlay && <DialogBackdrop
<DialogBackdrop
className={cn('fixed inset-0 z-[40]', mask && 'bg-black/30', dialogBackdropClassName)}
onClick={() => {
!clickOutsideNotOpen && onClose()
}}
/>}
/>
<div className={cn('relative z-[50] flex w-full max-w-sm flex-col justify-between overflow-hidden bg-components-panel-bg p-6 text-left align-middle shadow-xl', panelClassName)}>
<>
<div className='flex justify-between'>

View File

@ -24,7 +24,7 @@ export enum FormTypeEnum {
secretInput = 'secret-input',
select = 'select',
radio = 'radio',
boolean = 'boolean',
checkbox = 'checkbox',
files = 'files',
file = 'file',
modelSelector = 'model-selector',

View File

@ -3,34 +3,11 @@
* Extracted from the main markdown renderer for modularity.
* Uses the ImageGallery component to display images.
*/
import React, { useEffect, useMemo } from 'react'
import React from 'react'
import ImageGallery from '@/app/components/base/image-gallery'
import { getMarkdownImageURL } from './utils'
import { usePluginReadmeAsset } from '@/service/use-plugins'
const Img = ({ src, pluginUniqueIdentifier }: { src: string, pluginUniqueIdentifier?: string }) => {
const imgURL = getMarkdownImageURL(src, pluginUniqueIdentifier)
const { data: asset } = usePluginReadmeAsset({ plugin_unique_identifier: pluginUniqueIdentifier, file_name: src })
const blobUrl = useMemo(() => {
if (asset)
return URL.createObjectURL(asset)
return imgURL
}, [asset, imgURL])
useEffect(() => {
return () => {
if (blobUrl && asset)
URL.revokeObjectURL(blobUrl)
}
}, [blobUrl])
return (
<div className='markdown-img-wrapper'>
<ImageGallery srcs={[blobUrl]} />
</div>
)
const Img = ({ src }: any) => {
return <div className="markdown-img-wrapper"><ImageGallery srcs={[src]} /></div>
}
export default Img

View File

@ -3,43 +3,25 @@
* Extracted from the main markdown renderer for modularity.
* Handles special rendering for paragraphs that directly contain an image.
*/
import React, { useEffect, useMemo } from 'react'
import React from 'react'
import ImageGallery from '@/app/components/base/image-gallery'
// import { getMarkdownImageURL } from './utils'
import { usePluginReadmeAsset } from '@/service/use-plugins'
const Paragraph = (props: { pluginUniqueIdentifier?: string, node?: any, children?: any }) => {
const { node, pluginUniqueIdentifier, children } = props
const Paragraph = (paragraph: any) => {
const { node }: any = paragraph
const children_node = node.children
const { data: asset } = usePluginReadmeAsset({ plugin_unique_identifier: pluginUniqueIdentifier, file_name: children_node[0].properties?.src })
const blobUrl = useMemo(() => {
if (asset)
return URL.createObjectURL(asset)
return ''
}, [asset])
useEffect(() => {
return () => {
if (blobUrl && asset)
URL.revokeObjectURL(blobUrl)
}
}, [blobUrl])
if (children_node?.[0]?.tagName === 'img') {
// const imageURL = getMarkdownImageURL(children_node[0].properties?.src, pluginUniqueIdentifier)
if (children_node && children_node[0] && 'tagName' in children_node[0] && children_node[0].tagName === 'img') {
return (
<div className="markdown-img-wrapper">
<ImageGallery srcs={[blobUrl]} />
<ImageGallery srcs={[children_node[0].properties.src]} />
{
Array.isArray(children) && children.length > 1 && (
<div className="mt-2">{children.slice(1)}</div>
Array.isArray(paragraph.children) && paragraph.children.length > 1 && (
<div className="mt-2">{paragraph.children.slice(1)}</div>
)
}
</div>
)
}
return <p>{children}</p>
return <p>{paragraph.children}</p>
}
export default Paragraph

View File

@ -1,14 +1,7 @@
import { ALLOW_UNSAFE_DATA_SCHEME, MARKETPLACE_API_PREFIX } from '@/config'
import { ALLOW_UNSAFE_DATA_SCHEME } from '@/config'
export const isValidUrl = (url: string): boolean => {
const validPrefixes = ['http:', 'https:', '//', 'mailto:']
if (ALLOW_UNSAFE_DATA_SCHEME) validPrefixes.push('data:')
return validPrefixes.some(prefix => url.startsWith(prefix))
}
export const getMarkdownImageURL = (url: string, pathname?: string) => {
const regex = /(^\.\/_assets|^_assets)/
if (regex.test(url))
return `${MARKETPLACE_API_PREFIX}${pathname ?? ''}${url.replace(regex, '/_assets')}`
return url
}

View File

@ -33,11 +33,10 @@ export type MarkdownProps = {
className?: string
customDisallowedElements?: string[]
customComponents?: Record<string, React.ComponentType<any>>
pluginUniqueIdentifier?: string
}
export const Markdown = (props: MarkdownProps) => {
const { customComponents = {}, pluginUniqueIdentifier } = props
const { customComponents = {} } = props
const latexContent = flow([
preprocessThinkTag,
preprocessLaTeX,
@ -77,11 +76,11 @@ export const Markdown = (props: MarkdownProps) => {
disallowedElements={['iframe', 'head', 'html', 'meta', 'link', 'style', 'body', ...(props.customDisallowedElements || [])]}
components={{
code: CodeBlock,
img: props => <Img {...{ pluginUniqueIdentifier, ...props, src: props.src as string }} />,
img: Img,
video: VideoBlock,
audio: AudioBlock,
a: Link,
p: props => <Paragraph {...{ pluginUniqueIdentifier, ...props }} />,
p: Paragraph,
button: MarkdownButton,
form: MarkdownForm,
script: ScriptBlock as any,

View File

@ -8,7 +8,6 @@ import { noop } from 'lodash-es'
type IModal = {
className?: string
wrapperClassName?: string
containerClassName?: string
isShow: boolean
onClose?: () => void
title?: React.ReactNode
@ -17,13 +16,11 @@ type IModal = {
closable?: boolean
overflowVisible?: boolean
highPriority?: boolean // For modals that need to appear above dropdowns
noOverlay?: boolean
}
export default function Modal({
className,
wrapperClassName,
containerClassName,
isShow,
onClose = noop,
title,
@ -32,19 +29,18 @@ export default function Modal({
closable = false,
overflowVisible = false,
highPriority = false,
noOverlay = false,
}: IModal) {
return (
<Transition appear show={isShow} as={Fragment}>
<Dialog as="div" className={classNames('relative', highPriority ? 'z-[1100]' : 'z-[60]', wrapperClassName)} onClose={onClose}>
{!noOverlay && <TransitionChild>
<TransitionChild>
<div className={classNames(
'fixed inset-0 bg-background-overlay',
'duration-300 ease-in data-[closed]:opacity-0',
'data-[enter]:opacity-100',
'data-[leave]:opacity-0',
)} />
</TransitionChild>}
</TransitionChild>
<div
className="fixed inset-0 overflow-y-auto"
@ -53,7 +49,7 @@ export default function Modal({
e.stopPropagation()
}}
>
<div className={classNames('flex min-h-full items-center justify-center p-4 text-center', containerClassName)}>
<div className="flex min-h-full items-center justify-center p-4 text-center">
<TransitionChild>
<DialogPanel className={classNames(
'relative w-full max-w-[480px] rounded-2xl bg-components-panel-bg p-6 text-left align-middle shadow-xl transition-all',

View File

@ -26,7 +26,6 @@ type ModalProps = {
footerSlot?: React.ReactNode
bottomSlot?: React.ReactNode
disabled?: boolean
containerClassName?: string
}
const Modal = ({
onClose,
@ -45,7 +44,6 @@ const Modal = ({
footerSlot,
bottomSlot,
disabled,
containerClassName,
}: ModalProps) => {
const { t } = useTranslation()
@ -81,7 +79,7 @@ const Modal = ({
</div>
{
children && (
<div className={cn('px-6 py-3', containerClassName)}>{children}</div>
<div className='px-6 py-3'>{children}</div>
)
}
<div className='flex justify-between p-6 pt-5'>

View File

@ -53,7 +53,6 @@ const CurrentBlockReplacementBlock = ({
return mergeRegister(
editor.registerNodeTransform(CustomTextNode, textNode => decoratorTransform(textNode, getMatch, createCurrentBlockNode)),
)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
return null

View File

@ -52,7 +52,6 @@ const ErrorMessageBlockReplacementBlock = ({
return mergeRegister(
editor.registerNodeTransform(CustomTextNode, textNode => decoratorTransform(textNode, getMatch, createErrorMessageBlockNode)),
)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
return null

View File

@ -52,7 +52,6 @@ const LastRunReplacementBlock = ({
return mergeRegister(
editor.registerNodeTransform(CustomTextNode, textNode => decoratorTransform(textNode, getMatch, createLastRunBlockNode)),
)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
return null

View File

@ -66,7 +66,7 @@ const TagsFilter = ({
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-[1000]'>
<div className='w-[240px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg backdrop-blur-sm'>
<div className='w-[240px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg'>
<div className='p-2 pb-1'>
<Input
showLeftIcon

View File

@ -20,8 +20,6 @@ import {
useGetPluginCredentialSchemaHook,
useUpdatePluginCredentialHook,
} from '../hooks/use-credential'
import { ReadmeEntrance } from '../../readme-panel/entrance'
import { ReadmeShowType } from '../../readme-panel/context'
export type ApiKeyModalProps = {
pluginPayload: PluginPayload
@ -144,7 +142,6 @@ const ApiKeyModal = ({
onExtraButtonClick={onRemove}
disabled={disabled || isLoading || doingAction}
>
<ReadmeEntrance detail={pluginPayload.detail} showType={ReadmeShowType.modal} />
{
isLoading && (
<div className='flex h-40 items-center justify-center'>

View File

@ -23,8 +23,6 @@ import type {
} from '@/app/components/base/form/types'
import { useToastContext } from '@/app/components/base/toast'
import Button from '@/app/components/base/button'
import { ReadmeEntrance } from '../../readme-panel/entrance'
import { ReadmeShowType } from '../../readme-panel/context'
type OAuthClientSettingsProps = {
pluginPayload: PluginPayload
@ -156,16 +154,16 @@ const OAuthClientSettings = ({
</div>
)
}
containerClassName='pt-0'
>
<ReadmeEntrance detail={pluginPayload.detail} showType={ReadmeShowType.modal} />
<AuthForm
formFromProps={form}
ref={formRef}
formSchemas={schemas}
defaultValues={editValues || defaultValues}
disabled={disabled}
/>
<>
<AuthForm
formFromProps={form}
ref={formRef}
formSchemas={schemas}
defaultValues={editValues || defaultValues}
disabled={disabled}
/>
</>
</Modal>
)
}

View File

@ -1,5 +1,3 @@
import type { PluginDetail } from '../types'
export enum AuthCategory {
tool = 'tool',
datasource = 'datasource',
@ -9,7 +7,6 @@ export enum AuthCategory {
export type PluginPayload = {
category: AuthCategory
provider: string
detail: PluginDetail
}
export enum CredentialTypeEnum {

View File

@ -77,6 +77,13 @@ const AppInputsPanel = ({
required: false,
}
}
if(item.checkbox) {
return {
...item.checkbox,
type: 'checkbox',
required: false,
}
}
if (item.select) {
return {
...item.select,
@ -103,6 +110,13 @@ const AppInputsPanel = ({
}
}
if (item.json_object) {
return {
...item.json_object,
type: 'json_object',
}
}
return {
...item['text-input'],
type: 'text-input',

View File

@ -61,7 +61,7 @@ const DetailHeader = ({
onUpdate,
}: Props) => {
const { t } = useTranslation()
const { userProfile: { timezone } } = useAppContext()
const { userProfile: { timezone } } = useAppContext()
const { theme } = useTheme()
const locale = useGetLanguage()
@ -128,13 +128,13 @@ const DetailHeader = ({
return false
if (!autoUpgradeInfo || !isFromMarketplace)
return false
if (autoUpgradeInfo.strategy_setting === 'disabled')
if(autoUpgradeInfo.strategy_setting === 'disabled')
return false
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.update_all)
if(autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.update_all)
return true
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.partial && autoUpgradeInfo.include_plugins.includes(plugin_id))
if(autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.partial && autoUpgradeInfo.include_plugins.includes(plugin_id))
return true
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.exclude && !autoUpgradeInfo.exclude_plugins.includes(plugin_id))
if(autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.exclude && !autoUpgradeInfo.exclude_plugins.includes(plugin_id))
return true
return false
}, [autoUpgradeInfo, plugin_id, isFromMarketplace])
@ -331,7 +331,6 @@ const DetailHeader = ({
pluginPayload={{
provider: provider?.name || '',
category: AuthCategory.tool,
detail,
}}
/>
)

View File

@ -3,7 +3,7 @@ import { useTranslation } from 'react-i18next'
import { useBoolean } from 'ahooks'
import copy from 'copy-to-clipboard'
import { RiClipboardLine, RiDeleteBinLine, RiEditLine, RiLoginCircleLine } from '@remixicon/react'
import type { EndpointListItem, PluginDetail } from '../types'
import type { EndpointListItem } from '../types'
import EndpointModal from './endpoint-modal'
import { NAME_FIELD } from './utils'
import { addDefaultValue, toolCredentialToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
@ -22,13 +22,11 @@ import {
} from '@/service/use-endpoints'
type Props = {
pluginDetail: PluginDetail
data: EndpointListItem
handleChange: () => void
}
const EndpointCard = ({
pluginDetail,
data,
handleChange,
}: Props) => {
@ -212,7 +210,6 @@ const EndpointCard = ({
defaultValues={formValue}
onCancel={hideEndpointModalConfirm}
onSaved={handleUpdate}
pluginDetail={pluginDetail}
/>
)}
</div>

View File

@ -102,7 +102,6 @@ const EndpointList = ({ detail }: Props) => {
key={index}
data={item}
handleChange={() => invalidateEndpointList(detail.plugin_id)}
pluginDetail={detail}
/>
))}
</div>
@ -111,7 +110,6 @@ const EndpointList = ({ detail }: Props) => {
formSchemas={formSchemas}
onCancel={hideEndpointModal}
onSaved={handleCreate}
pluginDetail={detail}
/>
)}
</div>

View File

@ -10,15 +10,12 @@ import Form from '@/app/components/header/account-setting/model-provider-page/mo
import Toast from '@/app/components/base/toast'
import { useRenderI18nObject } from '@/hooks/use-i18n'
import cn from '@/utils/classnames'
import { ReadmeEntrance } from '../readme-panel/entrance'
import type { PluginDetail } from '../types'
type Props = {
formSchemas: any
defaultValues?: any
onCancel: () => void
onSaved: (value: Record<string, any>) => void
pluginDetail: PluginDetail
}
const extractDefaultValues = (schemas: any[]) => {
@ -35,7 +32,6 @@ const EndpointModal: FC<Props> = ({
defaultValues = {},
onCancel,
onSaved,
pluginDetail,
}) => {
const getValueFromI18nObject = useRenderI18nObject()
const { t } = useTranslation()
@ -59,9 +55,9 @@ const EndpointModal: FC<Props> = ({
const value = processedCredential[field.name]
if (typeof value === 'string')
processedCredential[field.name] = value === 'true' || value === '1' || value === 'True'
else if (typeof value === 'number')
else if (typeof value === 'number')
processedCredential[field.name] = value === 1
else if (typeof value === 'boolean')
else if (typeof value === 'boolean')
processedCredential[field.name] = value
}
})
@ -88,7 +84,6 @@ const EndpointModal: FC<Props> = ({
</ActionButton>
</div>
<div className='system-xs-regular mt-0.5 text-text-tertiary'>{t('plugin.detailPanel.endpointModalDesc')}</div>
<ReadmeEntrance detail={pluginDetail} className='px-0 pt-3' />
</div>
<div className='grow overflow-y-auto'>
<div className='px-4 py-2'>

View File

@ -1,5 +1,5 @@
'use client'
import React, { useCallback } from 'react'
import React from 'react'
import type { FC } from 'react'
import DetailHeader from './detail-header'
import EndpointList from './endpoint-list'
@ -9,7 +9,6 @@ import AgentStrategyList from './agent-strategy-list'
import Drawer from '@/app/components/base/drawer'
import type { PluginDetail } from '@/app/components/plugins/types'
import cn from '@/utils/classnames'
import { ReadmeEntrance } from '../readme-panel/entrance'
type Props = {
detail?: PluginDetail
@ -22,11 +21,11 @@ const PluginDetailPanel: FC<Props> = ({
onUpdate,
onHide,
}) => {
const handleUpdate = useCallback((isDelete = false) => {
const handleUpdate = (isDelete = false) => {
if (isDelete)
onHide()
onUpdate()
}, [onHide, onUpdate])
}
if (!detail)
return null
@ -43,17 +42,16 @@ const PluginDetailPanel: FC<Props> = ({
>
{detail && (
<>
<DetailHeader detail={detail} onUpdate={handleUpdate} onHide={onHide} />
<DetailHeader
detail={detail}
onHide={onHide}
onUpdate={handleUpdate}
/>
<div className='grow overflow-y-auto'>
<div className='flex min-h-full flex-col'>
<div className='flex-1'>
{!!detail.declaration.tool && <ActionList detail={detail} />}
{!!detail.declaration.agent_strategy && <AgentStrategyList detail={detail} />}
{!!detail.declaration.endpoint && <EndpointList detail={detail} />}
{!!detail.declaration.model && <ModelList detail={detail} />}
</div>
<ReadmeEntrance detail={detail} className='mt-auto' />
</div>
{!!detail.declaration.tool && <ActionList detail={detail} />}
{!!detail.declaration.agent_strategy && <AgentStrategyList detail={detail} />}
{!!detail.declaration.endpoint && <EndpointList detail={detail} />}
{!!detail.declaration.model && <ModelList detail={detail} />}
</div>
</>
)}

View File

@ -63,6 +63,8 @@ const StrategyDetail: FC<Props> = ({
return t('tools.setBuiltInTools.number')
if (type === 'text-input')
return t('tools.setBuiltInTools.string')
if (type === 'checkbox')
return 'boolean'
if (type === 'file')
return t('tools.setBuiltInTools.file')
if (type === 'array[tools]')

View File

@ -40,7 +40,6 @@ import {
AuthCategory,
PluginAuthInAgent,
} from '@/app/components/plugins/plugin-auth'
import { ReadmeEntrance } from '../../readme-panel/entrance'
type Props = {
disabled?: boolean
@ -273,10 +272,7 @@ const ToolSelector: FC<Props> = ({
{/* base form */}
<div className='flex flex-col gap-3 px-4 py-2'>
<div className='flex flex-col gap-1'>
<div className='system-sm-semibold flex h-6 items-center justify-between text-text-secondary'>
{t('plugin.detailPanel.toolSelector.toolLabel')}
<ReadmeEntrance detail={currentProvider as any} showShortTip className='pb-0' />
</div>
<div className='system-sm-semibold flex h-6 items-center text-text-secondary'>{t('plugin.detailPanel.toolSelector.toolLabel')}</div>
<ToolPicker
placement='bottom'
offset={offset}
@ -318,7 +314,6 @@ const ToolSelector: FC<Props> = ({
pluginPayload={{
provider: currentProvider.name,
category: AuthCategory.tool,
detail: currentProvider as any,
}}
credentialId={value?.credential_id}
onAuthorizationItemClick={handleAuthorizationItemClick}

View File

@ -90,7 +90,7 @@ const CategoriesFilter = ({
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-10'>
<div className='w-[240px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg backdrop-blur-sm'>
<div className='w-[240px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg'>
<div className='p-2 pb-1'>
<Input
showLeftIcon

View File

@ -85,7 +85,7 @@ const TagsFilter = ({
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-10'>
<div className='w-[240px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg backdrop-blur-sm'>
<div className='w-[240px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg'>
<div className='p-2 pb-1'>
<Input
showLeftIcon

View File

@ -1,119 +0,0 @@
'use client'
import React from 'react'
import type { FC } from 'react'
import { useTranslation } from 'react-i18next'
import {
RiBugLine,
RiHardDrive3Line,
RiVerifiedBadgeLine,
} from '@remixicon/react'
import { BoxSparkleFill } from '@/app/components/base/icons/src/vender/plugin'
import { Github } from '@/app/components/base/icons/src/public/common'
import Tooltip from '@/app/components/base/tooltip'
import Badge from '@/app/components/base/badge'
import { API_PREFIX } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
import type { PluginDetail } from '@/app/components/plugins/types'
import { PluginSource } from '@/app/components/plugins/types'
import OrgInfo from '@/app/components/plugins/card/base/org-info'
import Icon from '@/app/components/plugins/card/base/card-icon'
type PluginInfoProps = {
detail: PluginDetail
size?: 'default' | 'large'
}
const PluginInfo: FC<PluginInfoProps> = ({
detail,
size = 'default',
}) => {
const { t } = useTranslation()
const { currentWorkspace } = useAppContext()
const locale = useLanguage()
const tenant_id = currentWorkspace?.id
const {
version,
source,
} = detail
const {
icon,
label,
author,
name,
verified,
} = detail.declaration || detail
const isLarge = size === 'large'
const iconSize = isLarge ? 'h-10 w-10' : 'h-8 w-8'
const titleSize = isLarge ? 'text-sm' : 'text-xs'
return (
<div className={`flex items-center ${isLarge ? 'gap-3' : 'gap-2'}`}>
{/* Plugin Icon */}
<div className={`shrink-0 overflow-hidden rounded-lg border border-components-panel-border-subtle ${iconSize}`}>
<Icon src={`${API_PREFIX}/workspaces/current/plugin/icon?tenant_id=${tenant_id}&filename=${icon}`} />
</div>
{/* Plugin Details */}
<div className="min-w-0 flex-1">
{/* Name and Version */}
<div className="mb-0.5 flex items-center gap-1">
<h3 className={`truncate font-semibold text-text-secondary ${titleSize}`}>
{label[locale]}
</h3>
{verified && <RiVerifiedBadgeLine className="h-3 w-3 shrink-0 text-text-accent" />}
<Badge
className="mx-1"
uppercase={false}
text={version}
/>
</div>
{/* Organization and Source */}
<div className="flex items-center text-xs">
<OrgInfo
packageNameClassName="w-auto"
orgName={author}
packageName={name}
/>
<div className="ml-1 mr-0.5 text-text-quaternary">·</div>
{/* Source Icon */}
{source === PluginSource.marketplace && (
<Tooltip popupContent={t('plugin.detailPanel.categoryTip.marketplace')}>
<div>
<BoxSparkleFill className="h-3.5 w-3.5 text-text-tertiary hover:text-text-accent" />
</div>
</Tooltip>
)}
{source === PluginSource.github && (
<Tooltip popupContent={t('plugin.detailPanel.categoryTip.github')}>
<div>
<Github className="h-3.5 w-3.5 text-text-secondary hover:text-text-primary" />
</div>
</Tooltip>
)}
{source === PluginSource.local && (
<Tooltip popupContent={t('plugin.detailPanel.categoryTip.local')}>
<div>
<RiHardDrive3Line className="h-3.5 w-3.5 text-text-tertiary" />
</div>
</Tooltip>
)}
{source === PluginSource.debugging && (
<Tooltip popupContent={t('plugin.detailPanel.categoryTip.debugging')}>
<div>
<RiBugLine className="h-3.5 w-3.5 text-text-tertiary hover:text-text-warning" />
</div>
</Tooltip>
)}
</div>
</div>
</div>
)
}
export default PluginInfo

View File

@ -1,6 +0,0 @@
export const BUILTIN_TOOLS_ARRAY = [
'code',
'audio',
'time',
'webscraper',
]

Some files were not shown because too many files have changed in this diff Show More