Merge branch 'main' into feat/rag-2

This commit is contained in:
twwu
2025-08-27 11:16:27 +08:00
438 changed files with 17986 additions and 7846 deletions

View File

@ -3,6 +3,17 @@ import re
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
from core.external_data_tool.factory import ExternalDataToolFactory
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
[
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.EXTERNAL_DATA_TOOL,
VariableEntityType.CHECKBOX,
]
)
class BasicVariablesConfigManager:
@classmethod
@ -47,6 +58,7 @@ class BasicVariablesConfigManager:
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.SELECT,
VariableEntityType.CHECKBOX,
}:
variable = variables[variable_type]
variable_entities.append(
@ -96,8 +108,17 @@ class BasicVariablesConfigManager:
variables = []
for item in config["user_input_form"]:
key = list(item.keys())[0]
if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
# if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
if key not in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.EXTERNAL_DATA_TOOL,
VariableEntityType.CHECKBOX,
}:
allowed_keys = ", ".join(i.value for i in _ALLOWED_VARIABLE_ENTITY_TYPE)
raise ValueError(f"Keys in user_input_form list can only be {allowed_keys}")
form_item = item[key]
if "label" not in form_item:

View File

@ -8,6 +8,8 @@ from core.app.entities.task_entities import AppBlockingResponse, AppStreamRespon
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
logger = logging.getLogger(__name__)
class AppGenerateResponseConverter(ABC):
_blocking_response_type: type[AppBlockingResponse]
@ -120,7 +122,7 @@ class AppGenerateResponseConverter(ABC):
if data:
data.setdefault("message", getattr(e, "description", str(e)))
else:
logging.error(e)
logger.error(e)
data = {
"code": "internal_server_error",
"message": "Internal Server Error, please contact support.",

View File

@ -103,18 +103,23 @@ class BaseAppGenerator:
f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string"
)
if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str):
# handle empty string case
if not value.strip():
return None
# may raise ValueError if user_input_value is not a valid number
try:
if "." in value:
return float(value)
else:
return int(value)
except ValueError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
if variable_entity.type == VariableEntityType.NUMBER:
if isinstance(value, (int, float)):
return value
elif isinstance(value, str):
# handle empty string case
if not value.strip():
return None
# may raise ValueError if user_input_value is not a valid number
try:
if "." in value:
return float(value)
else:
return int(value)
except ValueError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
else:
raise TypeError(f"expected value type int, float or str, got {type(value)}, value: {value}")
match variable_entity.type:
case VariableEntityType.SELECT:
@ -144,6 +149,11 @@ class BaseAppGenerator:
raise ValueError(
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
)
case VariableEntityType.CHECKBOX:
if not isinstance(value, bool):
raise ValueError(f"{variable_entity.variable} in input form must be a valid boolean value")
case _:
raise AssertionError("this statement should be unreachable.")
return value

View File

@ -32,6 +32,8 @@ from extensions.ext_database import db
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
from services.annotation_service import AppAnnotationService
logger = logging.getLogger(__name__)
class MessageCycleManager:
def __init__(
@ -98,7 +100,7 @@ class MessageCycleManager:
conversation.name = name
except Exception as e:
if dify_config.DEBUG:
logging.exception("generate conversation name failed, conversation_id: %s", conversation_id)
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
pass
db.session.merge(conversation)

View File

@ -19,6 +19,7 @@ class ModelStatus(Enum):
QUOTA_EXCEEDED = "quota-exceeded"
NO_PERMISSION = "no-permission"
DISABLED = "disabled"
CREDENTIAL_REMOVED = "credential-removed"
class SimpleModelProviderEntity(BaseModel):
@ -54,6 +55,7 @@ class ProviderModelWithStatusEntity(ProviderModel):
status: ModelStatus
load_balancing_enabled: bool = False
has_invalid_load_balancing_configs: bool = False
def raise_for_status(self) -> None:
"""

File diff suppressed because it is too large Load Diff

View File

@ -69,6 +69,15 @@ class QuotaConfiguration(BaseModel):
restrict_models: list[RestrictModel] = []
class CredentialConfiguration(BaseModel):
"""
Model class for credential configuration.
"""
credential_id: str
credential_name: str
class SystemConfiguration(BaseModel):
"""
Model class for provider system configuration.
@ -86,6 +95,9 @@ class CustomProviderConfiguration(BaseModel):
"""
credentials: dict
current_credential_id: Optional[str] = None
current_credential_name: Optional[str] = None
available_credentials: list[CredentialConfiguration] = []
class CustomModelConfiguration(BaseModel):
@ -95,7 +107,10 @@ class CustomModelConfiguration(BaseModel):
model: str
model_type: ModelType
credentials: dict
credentials: dict | None
current_credential_id: Optional[str] = None
current_credential_name: Optional[str] = None
available_model_credentials: list[CredentialConfiguration] = []
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@ -118,6 +133,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
id: str
name: str
credentials: dict
credential_source_type: str | None = None
class ModelSettings(BaseModel):

View File

@ -10,6 +10,8 @@ from pydantic import BaseModel
from core.helper.position_helper import sort_to_dict_by_position_map
logger = logging.getLogger(__name__)
class ExtensionModule(enum.Enum):
MODERATION = "moderation"
@ -66,7 +68,7 @@ class Extensible:
# Check for extension module file
if (extension_name + ".py") not in file_names:
logging.warning("Missing %s.py file in %s, Skip.", extension_name, subdir_path)
logger.warning("Missing %s.py file in %s, Skip.", extension_name, subdir_path)
continue
# Check for builtin flag and position
@ -95,7 +97,7 @@ class Extensible:
break
if not extension_class:
logging.warning("Missing subclass of %s in %s, Skip.", cls.__name__, module_name)
logger.warning("Missing subclass of %s in %s, Skip.", cls.__name__, module_name)
continue
# Load schema if not builtin
@ -103,7 +105,7 @@ class Extensible:
if not builtin:
json_path = os.path.join(subdir_path, "schema.json")
if not os.path.exists(json_path):
logging.warning("Missing schema.json file in %s, Skip.", subdir_path)
logger.warning("Missing schema.json file in %s, Skip.", subdir_path)
continue
with open(json_path, encoding="utf-8") as f:
@ -122,7 +124,7 @@ class Extensible:
)
except Exception as e:
logging.exception("Error scanning extensions")
logger.exception("Error scanning extensions")
raise
# Sort extensions by position

View File

@ -17,6 +17,7 @@ def encrypt_token(tenant_id: str, token: str):
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
raise ValueError(f"Tenant with id {tenant_id} not found")
assert tenant.encrypt_public_key is not None
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()

View File

@ -4,6 +4,8 @@ import sys
from types import ModuleType
from typing import AnyStr
logger = logging.getLogger(__name__)
def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
"""
@ -30,7 +32,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
spec.loader.exec_module(module)
return module
except Exception as e:
logging.exception("Failed to load module %s from script file '%s'", module_name, repr(py_file_path))
logger.exception("Failed to load module %s from script file '%s'", module_name, repr(py_file_path))
raise e

View File

@ -9,6 +9,8 @@ import httpx
from configs import dify_config
logger = logging.getLogger(__name__)
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True
@ -73,12 +75,12 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if response.status_code not in STATUS_FORCELIST:
return response
else:
logging.warning(
logger.warning(
"Received status code %s for URL %s which is in the force list", response.status_code, url
)
except httpx.RequestError as e:
logging.warning("Request to URL %s failed on attempt %s: %s", url, retries + 1, e)
logger.warning("Request to URL %s failed on attempt %s: %s", url, retries + 1, e)
if max_retries == 0:
raise

View File

@ -39,6 +39,8 @@ from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
class IndexingRunner:
def __init__(self):
@ -90,9 +92,9 @@ class IndexingRunner:
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
except ObjectDeletedError:
logging.warning("Document deleted, document id: %s", dataset_document.id)
logger.warning("Document deleted, document id: %s", dataset_document.id)
except Exception as e:
logging.exception("consume document failed")
logger.exception("consume document failed")
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
dataset_document.stopped_at = naive_utc_now()
@ -153,7 +155,7 @@ class IndexingRunner:
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
logger.exception("consume document failed")
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
dataset_document.stopped_at = naive_utc_now()
@ -228,7 +230,7 @@ class IndexingRunner:
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
logger.exception("consume document failed")
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
dataset_document.stopped_at = naive_utc_now()
@ -321,7 +323,7 @@ class IndexingRunner:
try:
storage.delete(image_file.key)
except Exception:
logging.exception(
logger.exception(
"Delete image_files failed while indexing_estimate, \
image_upload_file_is: %s",
upload_file_id,

View File

@ -31,6 +31,8 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.graph_engine.entities.event import AgentLogEvent
from models import App, Message, WorkflowNodeExecutionModel, db
logger = logging.getLogger(__name__)
class LLMGenerator:
@classmethod
@ -68,7 +70,7 @@ class LLMGenerator:
result_dict = json.loads(cleaned_answer)
answer = result_dict["Your Output"]
except json.JSONDecodeError as e:
logging.exception("Failed to generate name after answer, use query instead")
logger.exception("Failed to generate name after answer, use query instead")
answer = query
name = answer.strip()
@ -125,7 +127,7 @@ class LLMGenerator:
except InvokeError:
questions = []
except Exception:
logging.exception("Failed to generate suggested questions after answer")
logger.exception("Failed to generate suggested questions after answer")
questions = []
return questions
@ -173,7 +175,7 @@ class LLMGenerator:
error = str(e)
error_step = "generate rule config"
except Exception as e:
logging.exception("Failed to generate rule config, model: %s", model_config.get("name"))
logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
rule_config["error"] = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@ -270,7 +272,7 @@ class LLMGenerator:
error_step = "generate conversation opener"
except Exception as e:
logging.exception("Failed to generate rule config, model: %s", model_config.get("name"))
logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
rule_config["error"] = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@ -319,7 +321,7 @@ class LLMGenerator:
error = str(e)
return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logging.exception(
logger.exception(
"Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language
)
return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
@ -392,7 +394,7 @@ class LLMGenerator:
error = str(e)
return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
except Exception as e:
logging.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
@staticmethod
@ -570,5 +572,5 @@ class LLMGenerator:
error = str(e)
return {"error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logging.exception("Failed to invoke LLM model, model: " + json.dumps(model_config.get("name")), exc_info=e)
logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=e)
return {"error": f"An unexpected error occurred: {str(e)}"}

View File

@ -152,7 +152,7 @@ class MCPClient:
# ExitStack will handle proper cleanup of all managed context managers
self._exit_stack.close()
except Exception as e:
logging.exception("Error during cleanup")
logger.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")
finally:
self._session = None

View File

@ -31,6 +31,9 @@ from core.mcp.types import (
SessionMessage,
)
logger = logging.getLogger(__name__)
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
@ -366,7 +369,7 @@ class BaseSession(
self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logging.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)
logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)
else: # Response or error
response_queue = self._response_streams.get(message.message.root.id)
if response_queue is not None:
@ -376,7 +379,7 @@ class BaseSession(
except queue.Empty:
continue
except Exception:
logging.exception("Error in message processing loop")
logger.exception("Error in message processing loop")
raise
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:

View File

@ -201,7 +201,7 @@ class ModelProviderFactory:
return filtered_credentials
def get_model_schema(
self, *, provider: str, model_type: ModelType, model: str, credentials: dict
self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None
) -> AIModelEntity | None:
"""
Get model schema

View File

@ -100,14 +100,14 @@ class Moderation(Extensible, ABC):
if not inputs_config.get("preset_response"):
raise ValueError("inputs_config.preset_response is required")
if len(inputs_config.get("preset_response", 0)) > 100:
if len(inputs_config.get("preset_response", "0")) > 100:
raise ValueError("inputs_config.preset_response must be less than 100 characters")
if outputs_config_enabled:
if not outputs_config.get("preset_response"):
raise ValueError("outputs_config.preset_response is required")
if len(outputs_config.get("preset_response", 0)) > 100:
if len(outputs_config.get("preset_response", "0")) > 100:
raise ValueError("outputs_config.preset_response must be less than 100 characters")

View File

@ -306,7 +306,7 @@ class AliyunDataTrace(BaseTraceInstance):
node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution)
return node_span
except Exception as e:
logging.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
logger.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
return None
def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status:

View File

@ -37,6 +37,8 @@ from models.model import App, AppModelConfig, Conversation, Message, MessageFile
from models.workflow import WorkflowAppLog, WorkflowRun
from tasks.ops_trace_task import process_trace_tasks
logger = logging.getLogger(__name__)
class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
def __getitem__(self, provider: str) -> dict[str, Any]:
@ -287,7 +289,7 @@ class OpsTraceManager:
# create new tracing_instance and update the cache if it absent
tracing_instance = trace_instance(config_class(**decrypt_trace_config))
cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance
logging.info("new tracing_instance for app_id: %s", app_id)
logger.info("new tracing_instance for app_id: %s", app_id)
return tracing_instance
@classmethod
@ -849,7 +851,7 @@ class TraceQueueManager:
trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task)
except Exception as e:
logging.exception("Error adding trace task, trace_type %s", trace_task.trace_type)
logger.exception("Error adding trace task, trace_type %s", trace_task.trace_type)
finally:
self.start_timer()
@ -868,7 +870,7 @@ class TraceQueueManager:
if tasks:
self.send_to_celery(tasks)
except Exception as e:
logging.exception("Error processing trace tasks")
logger.exception("Error processing trace tasks")
def start_timer(self):
global trace_manager_timer

View File

@ -154,7 +154,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
workflow = app.workflow
if not workflow:
raise ValueError("")
raise ValueError("unexpected app type")
return WorkflowAppGenerator().generate(
app_model=app,

View File

@ -8,6 +8,7 @@ from core.plugin.entities.plugin_daemon import (
)
from core.plugin.entities.request import PluginInvokeContext
from core.plugin.impl.base import BasePluginClient
from core.plugin.utils.chunk_merger import merge_blob_chunks
class PluginAgentClient(BasePluginClient):
@ -113,4 +114,4 @@ class PluginAgentClient(BasePluginClient):
"Content-Type": "application/json",
},
)
return response
return merge_blob_chunks(response)

View File

@ -141,11 +141,11 @@ class BasePluginClient:
response.raise_for_status()
except HTTPError as e:
msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}"
logging.exception(msg)
logger.exception(msg)
raise e
except Exception as e:
msg = f"Failed to request plugin daemon, url: {path}"
logging.exception(msg)
logger.exception(msg)
raise ValueError(msg) from e
try:
@ -158,7 +158,7 @@ class BasePluginClient:
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}],"
f" url: {path}"
)
logging.exception(msg)
logger.exception(msg)
raise ValueError(msg)
if rep.code != 0:

View File

@ -9,6 +9,7 @@ from core.plugin.entities.plugin_daemon import (
PluginToolProviderEntity,
)
from core.plugin.impl.base import BasePluginClient
from core.plugin.utils.chunk_merger import merge_blob_chunks
from core.schemas.resolver import resolve_dify_schema_refs
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
@ -123,61 +124,7 @@ class PluginToolManager(BasePluginClient):
},
)
class FileChunk:
"""
Only used for internal processing.
"""
bytes_written: int
total_length: int
data: bytearray
def __init__(self, total_length: int):
self.bytes_written = 0
self.total_length = total_length
self.data = bytearray(total_length)
files: dict[str, FileChunk] = {}
for resp in response:
if resp.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
assert isinstance(resp.message, ToolInvokeMessage.BlobChunkMessage)
# Get blob chunk information
chunk_id = resp.message.id
total_length = resp.message.total_length
blob_data = resp.message.blob
is_end = resp.message.end
# Initialize buffer for this file if it doesn't exist
if chunk_id not in files:
files[chunk_id] = FileChunk(total_length)
# If this is the final chunk, yield a complete blob message
if is_end:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB,
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data),
meta=resp.meta,
)
else:
# Check if file is too large (30MB limit)
if files[chunk_id].bytes_written + len(blob_data) > 30 * 1024 * 1024:
# Delete the file if it's too large
del files[chunk_id]
# Skip yielding this message
raise ValueError("File is too large which reached the limit of 30MB")
# Check if single chunk is too large (8KB limit)
if len(blob_data) > 8192:
# Skip yielding this message
raise ValueError("File chunk is too large which reached the limit of 8KB")
# Append the blob data to the buffer
files[chunk_id].data[
files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data)
] = blob_data
files[chunk_id].bytes_written += len(blob_data)
else:
yield resp
return merge_blob_chunks(response)
def validate_provider_credentials(
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]

View File

@ -0,0 +1,92 @@
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import TypeVar, Union, cast
from core.agent.entities import AgentInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage
MessageType = TypeVar("MessageType", bound=Union[ToolInvokeMessage, AgentInvokeMessage])
@dataclass
class FileChunk:
"""
Buffer for accumulating file chunks during streaming.
"""
total_length: int
bytes_written: int = field(default=0, init=False)
data: bytearray = field(init=False)
def __post_init__(self) -> None:
self.data = bytearray(self.total_length)
def merge_blob_chunks(
response: Generator[MessageType, None, None],
max_file_size: int = 30 * 1024 * 1024,
max_chunk_size: int = 8192,
) -> Generator[MessageType, None, None]:
"""
Merge streaming blob chunks into complete blob messages.
This function processes a stream of plugin invoke messages, accumulating
BLOB_CHUNK messages by their ID until the final chunk is received,
then yielding a single complete BLOB message.
Args:
response: Generator yielding messages that may include blob chunks
max_file_size: Maximum allowed file size in bytes (default: 30MB)
max_chunk_size: Maximum allowed chunk size in bytes (default: 8KB)
Yields:
Messages from the response stream, with blob chunks merged into complete blobs
Raises:
ValueError: If file size exceeds max_file_size or chunk size exceeds max_chunk_size
"""
files: dict[str, FileChunk] = {}
for resp in response:
if resp.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
assert isinstance(resp.message, ToolInvokeMessage.BlobChunkMessage)
# Get blob chunk information
chunk_id = resp.message.id
total_length = resp.message.total_length
blob_data = resp.message.blob
is_end = resp.message.end
# Initialize buffer for this file if it doesn't exist
if chunk_id not in files:
files[chunk_id] = FileChunk(total_length)
# Check if file is too large (before appending)
if files[chunk_id].bytes_written + len(blob_data) > max_file_size:
# Delete the file if it's too large
del files[chunk_id]
raise ValueError(f"File is too large which reached the limit of {max_file_size / 1024 / 1024}MB")
# Check if single chunk is too large
if len(blob_data) > max_chunk_size:
raise ValueError(f"File chunk is too large which reached the limit of {max_chunk_size / 1024}KB")
# Append the blob data to the buffer
files[chunk_id].data[files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data)] = (
blob_data
)
files[chunk_id].bytes_written += len(blob_data)
# If this is the final chunk, yield a complete blob message
if is_end:
# Create the appropriate message type based on the response type
message_class = type(resp)
merged_message = message_class(
type=ToolInvokeMessage.MessageType.BLOB,
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]),
meta=resp.meta,
)
yield cast(MessageType, merged_message)
# Clean up the buffer
del files[chunk_id]
else:
yield resp

View File

@ -12,6 +12,7 @@ from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
from core.entities.provider_entities import (
CredentialConfiguration,
CustomConfiguration,
CustomModelConfiguration,
CustomProviderConfiguration,
@ -40,7 +41,9 @@ from extensions.ext_redis import redis_client
from models.provider import (
LoadBalancingModelConfig,
Provider,
ProviderCredential,
ProviderModel,
ProviderModelCredential,
ProviderModelSetting,
ProviderType,
TenantDefaultModel,
@ -488,6 +491,61 @@ class ProviderManager:
return provider_name_to_provider_load_balancing_model_configs_dict
@staticmethod
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
"""
Get provider all credentials.
:param tenant_id: workspace id
:param provider_name: provider name
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(ProviderCredential)
.where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name)
.order_by(ProviderCredential.created_at.desc())
)
available_credentials = session.scalars(stmt).all()
return [
CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name)
for credential in available_credentials
]
@staticmethod
def get_provider_model_available_credentials(
tenant_id: str, provider_name: str, model_name: str, model_type: str
) -> list[CredentialConfiguration]:
"""
Get provider custom model all credentials.
:param tenant_id: workspace id
:param provider_name: provider name
:param model_name: model name
:param model_type: model type
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(ProviderModelCredential)
.where(
ProviderModelCredential.tenant_id == tenant_id,
ProviderModelCredential.provider_name == provider_name,
ProviderModelCredential.model_name == model_name,
ProviderModelCredential.model_type == model_type,
)
.order_by(ProviderModelCredential.created_at.desc())
)
available_credentials = session.scalars(stmt).all()
return [
CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name)
for credential in available_credentials
]
@staticmethod
def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
@ -590,9 +648,6 @@ class ProviderManager:
if provider_record.provider_type == ProviderType.SYSTEM.value:
continue
if not provider_record.encrypted_config:
continue
custom_provider_record = provider_record
# Get custom provider credentials
@ -611,8 +666,8 @@ class ProviderManager:
try:
# fix origin data
if custom_provider_record.encrypted_config is None:
raise ValueError("No credentials found")
if not custom_provider_record.encrypted_config.startswith("{"):
provider_credentials = {}
elif not custom_provider_record.encrypted_config.startswith("{"):
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
else:
provider_credentials = json.loads(custom_provider_record.encrypted_config)
@ -637,7 +692,14 @@ class ProviderManager:
else:
provider_credentials = cached_provider_credentials
custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials)
custom_provider_configuration = CustomProviderConfiguration(
credentials=provider_credentials,
current_credential_name=custom_provider_record.credential_name,
current_credential_id=custom_provider_record.credential_id,
available_credentials=self.get_provider_available_credentials(
tenant_id, custom_provider_record.provider_name
),
)
# Get provider model credential secret variables
model_credential_secret_variables = self._extract_secret_variables(
@ -649,8 +711,12 @@ class ProviderManager:
# Get custom provider model credentials
custom_model_configurations = []
for provider_model_record in provider_model_records:
if not provider_model_record.encrypted_config:
continue
available_model_credentials = self.get_provider_model_available_credentials(
tenant_id,
provider_model_record.provider_name,
provider_model_record.model_name,
provider_model_record.model_type,
)
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
@ -659,7 +725,7 @@ class ProviderManager:
# Get cached provider model credentials
cached_provider_model_credentials = provider_model_credentials_cache.get()
if not cached_provider_model_credentials:
if not cached_provider_model_credentials and provider_model_record.encrypted_config:
try:
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
except JSONDecodeError:
@ -688,6 +754,9 @@ class ProviderManager:
model=provider_model_record.model_name,
model_type=ModelType.value_of(provider_model_record.model_type),
credentials=provider_model_credentials,
current_credential_id=provider_model_record.credential_id,
current_credential_name=provider_model_record.credential_name,
available_model_credentials=available_model_credentials,
)
)
@ -899,6 +968,18 @@ class ProviderManager:
load_balancing_model_config.model_name == provider_model_setting.model_name
and load_balancing_model_config.model_type == provider_model_setting.model_type
):
if load_balancing_model_config.name == "__delete__":
# to calculate current model whether has invalidate lb configs
load_balancing_configs.append(
ModelLoadBalancingConfiguration(
id=load_balancing_model_config.id,
name=load_balancing_model_config.name,
credentials={},
credential_source_type=load_balancing_model_config.credential_source_type,
)
)
continue
if not load_balancing_model_config.enabled:
continue
@ -955,6 +1036,7 @@ class ProviderManager:
id=load_balancing_model_config.id,
name=load_balancing_model_config.name,
credentials=provider_model_credentials,
credential_source_type=load_balancing_model_config.credential_source_type,
)
)

View File

@ -259,8 +259,16 @@ class MilvusVector(BaseVector):
"""
Search for documents by full-text search (if hybrid search is enabled).
"""
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
if not self._hybrid_search_enabled:
logger.warning(
"Full-text search is disabled: set MILVUS_ENABLE_HYBRID_SEARCH=true (requires Milvus >= 2.5.0)."
)
return []
if not self.field_exists(Field.SPARSE_VECTOR.value):
logger.warning(
"Full-text search unavailable: collection missing 'sparse_vector' field; "
"recreate the collection after enabling MILVUS_ENABLE_HYBRID_SEARCH to add BM25 sparse index."
)
return []
document_ids_filter = kwargs.get("document_ids_filter")
filter = ""

View File

@ -15,6 +15,8 @@ from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class MyScaleConfig(BaseModel):
host: str
@ -53,7 +55,7 @@ class MyScaleVector(BaseVector):
return self.add_texts(documents=texts, embeddings=embeddings, **kwargs)
def _create_collection(self, dimension: int):
logging.info("create MyScale collection %s with dimension %s", self._collection_name, dimension)
logger.info("create MyScale collection %s with dimension %s", self._collection_name, dimension)
self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}")
fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else ""
sql = f"""
@ -151,7 +153,7 @@ class MyScaleVector(BaseVector):
for r in self._client.query(sql).named_results()
]
except Exception as e:
logging.exception("\033[91m\033[1m%s\033[0m \033[95m%s\033[0m", type(e), str(e)) # noqa:TRY401
logger.exception("\033[91m\033[1m%s\033[0m \033[95m%s\033[0m", type(e), str(e)) # noqa:TRY401
return []
def delete(self) -> None:

View File

@ -188,14 +188,17 @@ class OracleVector(BaseVector):
def text_exists(self, id: str) -> bool:
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = :1", (id,))
return cur.fetchone() is not None
conn.close()
def get_by_ids(self, ids: list[str]) -> list[Document]:
if not ids:
return []
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
docs = []
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
@ -208,14 +211,15 @@ class OracleVector(BaseVector):
return
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids)
conn.commit()
conn.close()
def delete_by_metadata_field(self, key: str, value: str) -> None:
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
cur.execute(f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$." + key + "') = :1", (value,))
conn.commit()
conn.close()
@ -227,12 +231,20 @@ class OracleVector(BaseVector):
:param top_k: The number of nearest neighbors to return, default is 5.
:return: List of Documents that are nearest to the query vector.
"""
# Validate and sanitize top_k to prevent SQL injection
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
top_k = 4 # Use default if invalid
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
params = [numpy.array(query_vector)]
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
placeholders = ", ".join(f":{i + 2}" for i in range(len(document_ids_filter)))
where_clause = f"WHERE JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
params.extend(document_ids_filter)
with self._get_connection() as conn:
conn.inputtypehandler = self.input_type_handler
conn.outputtypehandler = self.output_type_handler
@ -241,7 +253,7 @@ class OracleVector(BaseVector):
f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
AS distance FROM {self.table_name}
{where_clause} ORDER BY distance fetch first {top_k} rows only""",
[numpy.array(query_vector)],
params,
)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
@ -259,7 +271,10 @@ class OracleVector(BaseVector):
import nltk # type: ignore
from nltk.corpus import stopwords # type: ignore
# Validate and sanitize top_k to prevent SQL injection
top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
top_k = 5 # Use default if invalid
# just not implement fetch by score_threshold now, may be later
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if len(query) > 0:
@ -297,14 +312,21 @@ class OracleVector(BaseVector):
with conn.cursor() as cur:
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
params: dict[str, Any] = {"kk": " ACCUM ".join(entities)}
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
placeholders = []
for i, doc_id in enumerate(document_ids_filter):
param_name = f"doc_id_{i}"
placeholders.append(f":{param_name}")
params[param_name] = doc_id
where_clause = f" AND JSON_VALUE(meta, '$.document_id') IN ({', '.join(placeholders)}) "
cur.execute(
f"""select meta, text, embedding FROM {self.table_name}
WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
order by score(1) desc fetch first {top_k} rows only""",
kk=" ACCUM ".join(entities),
params,
)
docs = []
for record in cur:

View File

@ -19,6 +19,8 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class PGVectorConfig(BaseModel):
host: str
@ -155,7 +157,7 @@ class PGVector(BaseVector):
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
except psycopg2.errors.UndefinedTable:
# table not exists
logging.warning("Table %s not found, skipping delete operation.", self.table_name)
logger.warning("Table %s not found, skipping delete operation.", self.table_name)
return
except Exception as e:
raise e

View File

@ -17,6 +17,8 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models import Dataset
logger = logging.getLogger(__name__)
class TableStoreConfig(BaseModel):
access_key_id: Optional[str] = None
@ -145,7 +147,7 @@ class TableStoreVector(BaseVector):
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logging.info("Collection %s already exists.", self._collection_name)
logger.info("Collection %s already exists.", self._collection_name)
return
self._create_table_if_not_exist()
@ -155,7 +157,7 @@ class TableStoreVector(BaseVector):
def _create_table_if_not_exist(self) -> None:
table_list = self._tablestore_client.list_table()
if self._table_name in table_list:
logging.info("Tablestore system table[%s] already exists", self._table_name)
logger.info("Tablestore system table[%s] already exists", self._table_name)
return None
schema_of_primary_key = [("id", "STRING")]
@ -163,12 +165,12 @@ class TableStoreVector(BaseVector):
table_options = tablestore.TableOptions()
reserved_throughput = tablestore.ReservedThroughput(tablestore.CapacityUnit(0, 0))
self._tablestore_client.create_table(table_meta, table_options, reserved_throughput)
logging.info("Tablestore create table[%s] successfully.", self._table_name)
logger.info("Tablestore create table[%s] successfully.", self._table_name)
def _create_search_index_if_not_exist(self, dimension: int) -> None:
search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
if self._index_name in [t[1] for t in search_index_list]:
logging.info("Tablestore system index[%s] already exists", self._index_name)
logger.info("Tablestore system index[%s] already exists", self._index_name)
return None
field_schemas = [
@ -206,20 +208,20 @@ class TableStoreVector(BaseVector):
index_meta = tablestore.SearchIndexMeta(field_schemas)
self._tablestore_client.create_search_index(self._table_name, self._index_name, index_meta)
logging.info("Tablestore create system index[%s] successfully.", self._index_name)
logger.info("Tablestore create system index[%s] successfully.", self._index_name)
def _delete_table_if_exist(self):
search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
for resp_tuple in search_index_list:
self._tablestore_client.delete_search_index(resp_tuple[0], resp_tuple[1])
logging.info("Tablestore delete index[%s] successfully.", self._index_name)
logger.info("Tablestore delete index[%s] successfully.", self._index_name)
self._tablestore_client.delete_table(self._table_name)
logging.info("Tablestore delete system table[%s] successfully.", self._index_name)
logger.info("Tablestore delete system table[%s] successfully.", self._index_name)
def _delete_search_index(self) -> None:
self._tablestore_client.delete_search_index(self._table_name, self._index_name)
logging.info("Tablestore delete index[%s] successfully.", self._index_name)
logger.info("Tablestore delete index[%s] successfully.", self._index_name)
def _write_row(self, primary_key: str, attributes: dict[str, Any]) -> None:
pk = [("id", primary_key)]

View File

@ -83,14 +83,14 @@ class TiDBVector(BaseVector):
self._dimension = 1536
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
logger.info("create collection and add texts, collection_name: " + self._collection_name)
logger.info("create collection and add texts, collection_name: %s", self._collection_name)
self._create_collection(len(embeddings[0]))
self.add_texts(texts, embeddings)
self._dimension = len(embeddings[0])
pass
def _create_collection(self, dimension: int):
logger.info("_create_collection, collection_name " + self._collection_name)
logger.info("_create_collection, collection_name %s", self._collection_name)
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"

View File

@ -75,7 +75,7 @@ class CacheEmbedding(Embeddings):
except IntegrityError:
db.session.rollback()
except Exception:
logging.exception("Failed transform embedding")
logger.exception("Failed transform embedding")
cache_embeddings = []
try:
for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
@ -95,7 +95,7 @@ class CacheEmbedding(Embeddings):
db.session.rollback()
except Exception as ex:
db.session.rollback()
logger.exception("Failed to embed documents: %s")
logger.exception("Failed to embed documents")
raise ex
return text_embeddings
@ -122,7 +122,7 @@ class CacheEmbedding(Embeddings):
raise ValueError("Normalized embedding is nan please try again")
except Exception as ex:
if dify_config.DEBUG:
logging.exception("Failed to embed query text '%s...(%s chars)'", text[:10], len(text))
logger.exception("Failed to embed query text '%s...(%s chars)'", text[:10], len(text))
raise ex
try:
@ -136,7 +136,7 @@ class CacheEmbedding(Embeddings):
redis_client.setex(embedding_cache_key, 600, encoded_str)
except Exception as ex:
if dify_config.DEBUG:
logging.exception(
logger.exception(
"Failed to add embedding to redis for the text '%s...(%s chars)'", text[:10], len(text)
)
raise ex

View File

@ -26,6 +26,8 @@ from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
logger = logging.getLogger(__name__)
class QAIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
@ -215,7 +217,7 @@ class QAIndexProcessor(BaseIndexProcessor):
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
logging.exception("Failed to format qa document")
logger.exception("Failed to format qa document")
all_qa_documents.extend(format_documents)

View File

@ -39,9 +39,16 @@ class WeightRerankRunner(BaseRerankRunner):
unique_documents = []
doc_ids = set()
for document in documents:
if document.metadata is not None and document.metadata["doc_id"] not in doc_ids:
if (
document.provider == "dify"
and document.metadata is not None
and document.metadata["doc_id"] not in doc_ids
):
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
else:
if document not in unique_documents:
unique_documents.append(document)
documents = unique_documents

View File

@ -275,35 +275,30 @@ class ApiTool(Tool):
if files:
headers.pop("Content-Type", None)
if method in {
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
}:
response: httpx.Response = getattr(ssrf_proxy, method.lower())(
url,
params=params,
headers=headers,
cookies=cookies,
data=body,
files=files,
timeout=API_TOOL_DEFAULT_TIMEOUT,
follow_redirects=True,
)
return response
else:
_METHOD_MAP = {
"get": ssrf_proxy.get,
"head": ssrf_proxy.head,
"post": ssrf_proxy.post,
"put": ssrf_proxy.put,
"delete": ssrf_proxy.delete,
"patch": ssrf_proxy.patch,
}
method_lc = method.lower()
if method_lc not in _METHOD_MAP:
raise ValueError(f"Invalid http method {method}")
response: httpx.Response = _METHOD_MAP[
method_lc
]( # https://discuss.python.org/t/type-inference-for-function-return-types/42926
url,
params=params,
headers=headers,
cookies=cookies,
data=body,
files=files,
timeout=API_TOOL_DEFAULT_TIMEOUT,
follow_redirects=True,
)
return response
def _convert_body_property_any_of(
self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10

View File

@ -280,7 +280,7 @@ class ToolEngine:
mimetype = "image/jpeg"
yield ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "image/jpeg"),
mimetype=response.meta.get("mime_type", mimetype),
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
)
elif response.type == ToolInvokeMessage.MessageType.BLOB:

View File

@ -151,6 +151,11 @@ class FileSegment(Segment):
return ""
class BooleanSegment(Segment):
value_type: SegmentType = SegmentType.BOOLEAN
value: bool
class ArrayAnySegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_ANY
value: Sequence[Any]
@ -198,6 +203,11 @@ class ArrayFileSegment(ArraySegment):
return ""
class ArrayBooleanSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
value: Sequence[bool]
def get_segment_discriminator(v: Any) -> SegmentType | None:
if isinstance(v, Segment):
return v.value_type
@ -231,11 +241,13 @@ SegmentUnion: TypeAlias = Annotated[
| Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
| Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
| Annotated[FileSegment, Tag(SegmentType.FILE)]
| Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)]
| Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
),
Discriminator(get_segment_discriminator),
]

View File

@ -6,7 +6,12 @@ from core.file.models import File
class ArrayValidation(StrEnum):
"""Strategy for validating array elements"""
"""Strategy for validating array elements.
Note:
The `NONE` and `FIRST` strategies are primarily for compatibility purposes.
Avoid using them in new code whenever possible.
"""
# Skip element validation (only check array container)
NONE = "none"
@ -27,12 +32,14 @@ class SegmentType(StrEnum):
SECRET = "secret"
FILE = "file"
BOOLEAN = "boolean"
ARRAY_ANY = "array[any]"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILE = "array[file]"
ARRAY_BOOLEAN = "array[boolean]"
NONE = "none"
@ -76,12 +83,18 @@ class SegmentType(StrEnum):
return SegmentType.ARRAY_FILE
case SegmentType.NONE:
return SegmentType.ARRAY_ANY
case SegmentType.BOOLEAN:
return SegmentType.ARRAY_BOOLEAN
case _:
# This should be unreachable.
raise ValueError(f"not supported value {value}")
if value is None:
return SegmentType.NONE
elif isinstance(value, int) and not isinstance(value, bool):
# Important: The check for `bool` must precede the check for `int`,
# as `bool` is a subclass of `int` in Python's type hierarchy.
elif isinstance(value, bool):
return SegmentType.BOOLEAN
elif isinstance(value, int):
return SegmentType.INTEGER
elif isinstance(value, float):
return SegmentType.FLOAT
@ -111,7 +124,7 @@ class SegmentType(StrEnum):
else:
return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value)
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.ALL) -> bool:
"""
Check if a value matches the segment type.
Users of `SegmentType` should call this method, instead of using
@ -126,6 +139,10 @@ class SegmentType(StrEnum):
"""
if self.is_array_type():
return self._validate_array(value, array_validation)
# Important: The check for `bool` must precede the check for `int`,
# as `bool` is a subclass of `int` in Python's type hierarchy.
elif self == SegmentType.BOOLEAN:
return isinstance(value, bool)
elif self in [SegmentType.INTEGER, SegmentType.FLOAT, SegmentType.NUMBER]:
return isinstance(value, (int, float))
elif self == SegmentType.STRING:
@ -141,6 +158,27 @@ class SegmentType(StrEnum):
else:
raise AssertionError("this statement should be unreachable.")
@staticmethod
def cast_value(value: Any, type_: "SegmentType") -> Any:
# Cast Python's `bool` type to `int` when the runtime type requires
# an integer or number.
#
# This ensures compatibility with existing workflows that may use `bool` as
# `int`, since in Python's type system, `bool` is a subtype of `int`.
#
# This function exists solely to maintain compatibility with existing workflows.
# It should not be used to compromise the integrity of the runtime type system.
# No additional casting rules should be introduced to this function.
if type_ in (
SegmentType.INTEGER,
SegmentType.NUMBER,
) and isinstance(value, bool):
return int(value)
if type_ == SegmentType.ARRAY_NUMBER and all(isinstance(i, bool) for i in value):
return [int(i) for i in value]
return value
def exposed_type(self) -> "SegmentType":
"""Returns the type exposed to the frontend.
@ -150,6 +188,20 @@ class SegmentType(StrEnum):
return SegmentType.NUMBER
return self
def element_type(self) -> "SegmentType | None":
"""Return the element type of the current segment type, or `None` if the element type is undefined.
Raises:
ValueError: If the current segment type is not an array type.
Note:
For certain array types, such as `SegmentType.ARRAY_ANY`, their element types are not defined
by the runtime system. In such cases, this method will return `None`.
"""
if not self.is_array_type():
raise ValueError(f"element_type is only supported by array type, got {self}")
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
# ARRAY_ANY does not have corresponding element type.
@ -157,6 +209,7 @@ _ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
SegmentType.ARRAY_FILE: SegmentType.FILE,
SegmentType.ARRAY_BOOLEAN: SegmentType.BOOLEAN,
}
_ARRAY_TYPES = frozenset(

View File

@ -8,11 +8,13 @@ from core.helper import encrypter
from .segments import (
ArrayAnySegment,
ArrayBooleanSegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArraySegment,
ArrayStringSegment,
BooleanSegment,
FileSegment,
FloatSegment,
IntegerSegment,
@ -96,10 +98,18 @@ class FileVariable(FileSegment, Variable):
pass
class BooleanVariable(BooleanSegment, Variable):
pass
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass
class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
pass
class RAGPipelineVariable(BaseModel):
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
@ -143,11 +153,13 @@ VariableUnion: TypeAlias = Annotated[
| Annotated[IntegerVariable, Tag(SegmentType.INTEGER)]
| Annotated[ObjectVariable, Tag(SegmentType.OBJECT)]
| Annotated[FileVariable, Tag(SegmentType.FILE)]
| Annotated[BooleanVariable, Tag(SegmentType.BOOLEAN)]
| Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
),
Discriminator(get_segment_discriminator),

View File

@ -8,6 +8,7 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@ -119,6 +120,14 @@ class CodeNode(BaseNode):
return value.replace("\x00", "")
def _check_boolean(self, value: bool | None, variable: str) -> bool | None:
if value is None:
return None
if not isinstance(value, bool):
raise OutputValidationError(f"Output variable `{variable}` must be a boolean")
return value
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
"""
Check number
@ -173,6 +182,8 @@ class CodeNode(BaseNode):
prefix=f"{prefix}.{output_name}" if prefix else output_name,
depth=depth + 1,
)
elif isinstance(output_value, bool):
self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name)
elif isinstance(output_value, int | float):
self._check_number(
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
@ -232,7 +243,7 @@ class CodeNode(BaseNode):
if output_name not in result:
raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
if output_config.type == "object":
if output_config.type == SegmentType.OBJECT:
# check if output is object
if not isinstance(result.get(output_name), dict):
if result[output_name] is None:
@ -249,18 +260,28 @@ class CodeNode(BaseNode):
prefix=f"{prefix}.{output_name}",
depth=depth + 1,
)
elif output_config.type == "number":
elif output_config.type == SegmentType.NUMBER:
# check if number available
transformed_result[output_name] = self._check_number(
value=result[output_name], variable=f"{prefix}{dot}{output_name}"
)
elif output_config.type == "string":
checked = self._check_number(value=result[output_name], variable=f"{prefix}{dot}{output_name}")
# If the output is a boolean and the output schema specifies a NUMBER type,
# convert the boolean value to an integer.
#
# This ensures compatibility with existing workflows that may use
# `True` and `False` as values for NUMBER type outputs.
transformed_result[output_name] = self._convert_boolean_to_int(checked)
elif output_config.type == SegmentType.STRING:
# check if string available
transformed_result[output_name] = self._check_string(
value=result[output_name],
variable=f"{prefix}{dot}{output_name}",
)
elif output_config.type == "array[number]":
elif output_config.type == SegmentType.BOOLEAN:
transformed_result[output_name] = self._check_boolean(
value=result[output_name],
variable=f"{prefix}{dot}{output_name}",
)
elif output_config.type == SegmentType.ARRAY_NUMBER:
# check if array of number available
if not isinstance(result[output_name], list):
if result[output_name] is None:
@ -278,10 +299,17 @@ class CodeNode(BaseNode):
)
transformed_result[output_name] = [
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
# If the element is a boolean and the output schema specifies a `array[number]` type,
# convert the boolean value to an integer.
#
# This ensures compatibility with existing workflows that may use
# `True` and `False` as values for NUMBER type outputs.
self._convert_boolean_to_int(
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]"),
)
for i, value in enumerate(result[output_name])
]
elif output_config.type == "array[string]":
elif output_config.type == SegmentType.ARRAY_STRING:
# check if array of string available
if not isinstance(result[output_name], list):
if result[output_name] is None:
@ -302,7 +330,7 @@ class CodeNode(BaseNode):
self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
for i, value in enumerate(result[output_name])
]
elif output_config.type == "array[object]":
elif output_config.type == SegmentType.ARRAY_OBJECT:
# check if array of object available
if not isinstance(result[output_name], list):
if result[output_name] is None:
@ -340,6 +368,22 @@ class CodeNode(BaseNode):
)
for i, value in enumerate(result[output_name])
]
elif output_config.type == SegmentType.ARRAY_BOOLEAN:
# check if array of object available
if not isinstance(result[output_name], list):
if result[output_name] is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
)
else:
transformed_result[output_name] = [
self._check_boolean(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
for i, value in enumerate(result[output_name])
]
else:
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
@ -374,3 +418,16 @@ class CodeNode(BaseNode):
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
@staticmethod
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:
"""This function convert boolean to integers when the output schema specifies a NUMBER type.
This ensures compatibility with existing workflows that may use
`True` and `False` as values for NUMBER type outputs.
"""
if value is None:
return None
if isinstance(value, bool):
return int(value)
return value

View File

@ -1,11 +1,31 @@
from typing import Literal, Optional
from typing import Annotated, Literal, Optional
from pydantic import BaseModel
from pydantic import AfterValidator, BaseModel
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.types import SegmentType
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
[
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.OBJECT,
SegmentType.BOOLEAN,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_BOOLEAN,
]
)
def _validate_type(segment_type: SegmentType) -> SegmentType:
if segment_type not in _ALLOWED_OUTPUT_FROM_CODE:
raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}")
return segment_type
class CodeNodeData(BaseNodeData):
"""
@ -13,7 +33,7 @@ class CodeNodeData(BaseNodeData):
"""
class Output(BaseModel):
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
type: Annotated[SegmentType, AfterValidator(_validate_type)]
children: Optional[dict[str, "CodeNodeData.Output"]] = None
class Dependency(BaseModel):

View File

@ -1,36 +1,43 @@
from collections.abc import Sequence
from typing import Literal
from enum import StrEnum
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
_Condition = Literal[
class FilterOperator(StrEnum):
# string conditions
"contains",
"start with",
"end with",
"is",
"in",
"empty",
"not contains",
"is not",
"not in",
"not empty",
CONTAINS = "contains"
START_WITH = "start with"
END_WITH = "end with"
IS = "is"
IN = "in"
EMPTY = "empty"
NOT_CONTAINS = "not contains"
IS_NOT = "is not"
NOT_IN = "not in"
NOT_EMPTY = "not empty"
# number conditions
"=",
"",
"<",
">",
"",
"",
]
EQUAL = "="
NOT_EQUAL = ""
LESS_THAN = "<"
GREATER_THAN = ">"
GREATER_THAN_OR_EQUAL = ""
LESS_THAN_OR_EQUAL = ""
class Order(StrEnum):
ASC = "asc"
DESC = "desc"
class FilterCondition(BaseModel):
key: str = ""
comparison_operator: _Condition = "contains"
value: str | Sequence[str] = ""
comparison_operator: FilterOperator = FilterOperator.CONTAINS
# the value is bool if the filter operator is comparing with
# a boolean constant.
value: str | Sequence[str] | bool = ""
class FilterBy(BaseModel):
@ -38,10 +45,10 @@ class FilterBy(BaseModel):
conditions: Sequence[FilterCondition] = Field(default_factory=list)
class OrderBy(BaseModel):
class OrderByConfig(BaseModel):
enabled: bool = False
key: str = ""
value: Literal["asc", "desc"] = "asc"
value: Order = Order.ASC
class Limit(BaseModel):
@ -57,6 +64,6 @@ class ExtractConfig(BaseModel):
class ListOperatorNodeData(BaseNodeData):
variable: Sequence[str] = Field(default_factory=list)
filter_by: FilterBy
order_by: OrderBy
order_by: OrderByConfig
limit: Limit
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)

View File

@ -1,18 +1,40 @@
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal, Optional, Union
from typing import Any, Optional, TypeAlias, TypeVar
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .entities import ListOperatorNodeData
from .entities import FilterOperator, ListOperatorNodeData, Order
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
_SUPPORTED_TYPES_TUPLE = (
ArrayFileSegment,
ArrayNumberSegment,
ArrayStringSegment,
ArrayBooleanSegment,
)
_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment
_T = TypeVar("_T")
def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
"""Returns the negation of a given filter function. If the original filter
returns `True` for a value, the negated filter will return `False`, and vice versa.
"""
def wrapper(value: _T) -> bool:
return not filter_(value)
return wrapper
class ListOperatorNode(BaseNode):
_node_type = NodeType.LIST_OPERATOR
@ -69,11 +91,8 @@ class ListOperatorNode(BaseNode):
process_data=process_data,
outputs=outputs,
)
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = (
f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment"
)
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
@ -122,9 +141,7 @@ class ListOperatorNode(BaseNode):
outputs=outputs,
)
def _apply_filter(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
filter_func: Callable[[Any], bool]
result: list[Any] = []
for condition in self._node_data.filter_by.conditions:
@ -154,33 +171,35 @@ class ListOperatorNode(BaseNode):
)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayBooleanSegment):
if not isinstance(condition.value, bool):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
else:
raise AssertionError("this statment should be unreachable.")
return variable
def _apply_order(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self._node_data.order_by.value, array=variable.value)
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
result = sorted(variable.value, reverse=self._node_data.order_by == Order.DESC)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
else:
raise AssertionError("this statement should be unreachable")
return variable
def _apply_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
result = variable.value[: self._node_data.limit.size]
return variable.model_copy(update={"value": result})
def _extract_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
@ -232,11 +251,11 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
case "empty":
return lambda x: x == ""
case "not contains":
return lambda x: not _contains(value)(x)
return _negation(_contains(value))
case "is not":
return lambda x: not _is(value)(x)
return _negation(_is(value))
case "not in":
return lambda x: not _in(value)(x)
return _negation(_in(value))
case "not empty":
return lambda x: x != ""
case _:
@ -248,7 +267,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
case "in":
return _in(value)
case "not in":
return lambda x: not _in(value)(x)
return _negation(_in(value))
case _:
raise InvalidConditionError(f"Invalid condition: {condition}")
@ -271,6 +290,16 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]:
match condition:
case FilterOperator.IS:
return _is(value)
case FilterOperator.IS_NOT:
return _negation(_is(value))
case _:
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
extract_func: Callable[[File], Any]
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
@ -298,7 +327,7 @@ def _endswith(value: str) -> Callable[[str], bool]:
return lambda x: x.endswith(value)
def _is(value: str) -> Callable[[str], bool]:
def _is(value: _T) -> Callable[[_T], bool]:
return lambda x: x == value
@ -330,21 +359,13 @@ def _ge(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x >= value
def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]):
return sorted(array, key=lambda x: x, reverse=order == "desc")
def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
return sorted(array, key=lambda x: x, reverse=order == "desc")
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]):
extract_func: Callable[[File], Any]
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
extract_func = _get_file_extract_string_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
elif order_by == "size":
extract_func = _get_file_extract_number_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
else:
raise InvalidKeyError(f"Invalid order key: {order_by}")

View File

@ -3,7 +3,7 @@ import io
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
@ -55,7 +55,6 @@ from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
@ -90,6 +89,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.event import InNodeEvent
logger = logging.getLogger(__name__)
@ -161,7 +161,7 @@ class LLMNode(BaseNode):
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def _run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
node_inputs: Optional[dict[str, Any]] = None
process_data = None
result_text = ""
@ -737,7 +737,7 @@ class LLMNode(BaseNode):
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))

View File

@ -12,9 +12,11 @@ _VALID_VAR_TYPE = frozenset(
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.OBJECT,
SegmentType.BOOLEAN,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_BOOLEAN,
]
)

View File

@ -404,11 +404,11 @@ class LoopNode(BaseNode):
for node_id in loop_graph.node_ids:
variable_pool.remove([node_id])
_outputs = {}
_outputs: dict[str, Segment | int | None] = {}
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
_loop_variable_segment = variable_pool.get(loop_variable_selector)
if _loop_variable_segment:
_outputs[loop_variable_key] = _loop_variable_segment.value
_outputs[loop_variable_key] = _loop_variable_segment
else:
_outputs[loop_variable_key] = None
@ -522,21 +522,30 @@ class LoopNode(BaseNode):
return variable_mapping
@staticmethod
def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""
if var_type in ["array[string]", "array[number]", "array[object]"]:
if value and isinstance(value, str):
value = json.loads(value)
if var_type in [
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
]:
if original_value and isinstance(original_value, str):
value = json.loads(original_value)
else:
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
value = []
elif var_type == SegmentType.ARRAY_BOOLEAN:
value = original_value
else:
raise AssertionError("this statement should be unreachable.")
try:
return build_segment_with_type(var_type, value)
return build_segment_with_type(var_type, value=value)
except TypeMismatchError as type_exc:
# Attempt to parse the value as a JSON-encoded string, if applicable.
if not isinstance(value, str):
if not isinstance(original_value, str):
raise
try:
value = json.loads(value)
value = json.loads(original_value)
except ValueError:
raise type_exc
return build_segment_with_type(var_type, value)

View File

@ -1,10 +1,46 @@
from typing import Any, Literal, Optional
from typing import Annotated, Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
from pydantic import (
BaseModel,
BeforeValidator,
Field,
field_validator,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
_OLD_BOOL_TYPE_NAME = "bool"
_OLD_SELECT_TYPE_NAME = "select"
_VALID_PARAMETER_TYPES = frozenset(
[
SegmentType.STRING, # "string",
SegmentType.NUMBER, # "number",
SegmentType.BOOLEAN,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_BOOLEAN,
_OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node
_OLD_SELECT_TYPE_NAME, # string type with enumeration choices.
]
)
def _validate_type(parameter_type: str) -> SegmentType:
if not isinstance(parameter_type, str):
raise TypeError(f"type should be str, got {type(parameter_type)}, value={parameter_type}")
if parameter_type not in _VALID_PARAMETER_TYPES:
raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.")
if parameter_type == _OLD_BOOL_TYPE_NAME:
return SegmentType.BOOLEAN
elif parameter_type == _OLD_SELECT_TYPE_NAME:
return SegmentType.STRING
return SegmentType(parameter_type)
class _ParameterConfigError(Exception):
@ -17,7 +53,7 @@ class ParameterConfig(BaseModel):
"""
name: str
type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"]
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
options: Optional[list[str]] = None
description: str
required: bool
@ -32,17 +68,20 @@ class ParameterConfig(BaseModel):
return str(value)
def is_array_type(self) -> bool:
return self.type in ("array[string]", "array[number]", "array[object]")
return self.type.is_array_type()
def element_type(self) -> Literal["string", "number", "object"]:
if self.type == "array[number]":
return "number"
elif self.type == "array[string]":
return "string"
elif self.type == "array[object]":
return "object"
else:
raise _ParameterConfigError(f"{self.type} is not array type.")
def element_type(self) -> SegmentType:
"""Return the element type of the parameter.
Raises a ValueError if the parameter's type is not an array type.
"""
element_type = self.type.element_type()
# At this point, self.type is guaranteed to be one of `ARRAY_STRING`,
# `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`.
#
# See: _VALID_PARAMETER_TYPES for reference.
assert element_type is not None, f"the element type should not be None, {self.type=}"
return element_type
class ParameterExtractorNodeData(BaseNodeData):
@ -74,16 +113,18 @@ class ParameterExtractorNodeData(BaseNodeData):
for parameter in self.parameters:
parameter_schema: dict[str, Any] = {"description": parameter.description}
if parameter.type in {"string", "select"}:
if parameter.type == SegmentType.STRING:
parameter_schema["type"] = "string"
elif parameter.type.startswith("array"):
elif parameter.type.is_array_type():
parameter_schema["type"] = "array"
nested_type = parameter.type[6:-1]
parameter_schema["items"] = {"type": nested_type}
element_type = parameter.type.element_type()
if element_type is None:
raise AssertionError("element type should not be None.")
parameter_schema["items"] = {"type": element_type.value}
else:
parameter_schema["type"] = parameter.type
if parameter.type == "select":
if parameter.options:
parameter_schema["enum"] = parameter.options
parameters["properties"][parameter.name] = parameter_schema

View File

@ -1,3 +1,8 @@
from typing import Any
from core.variables.types import SegmentType
class ParameterExtractorNodeError(ValueError):
"""Base error for ParameterExtractorNode."""
@ -48,3 +53,23 @@ class InvalidArrayValueError(ParameterExtractorNodeError):
class InvalidModelModeError(ParameterExtractorNodeError):
"""Raised when the model mode is invalid."""
class InvalidValueTypeError(ParameterExtractorNodeError):
def __init__(
self,
/,
parameter_name: str,
expected_type: SegmentType,
actual_type: SegmentType | None,
value: Any,
) -> None:
message = (
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"
)
super().__init__(message)
self.parameter_name = parameter_name
self.expected_type = expected_type
self.actual_type = actual_type
self.value = value

View File

@ -26,7 +26,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import SegmentType
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -39,16 +39,13 @@ from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData
from .exc import (
InvalidArrayValueError,
InvalidBoolValueError,
InvalidInvokeResultError,
InvalidModelModeError,
InvalidModelTypeError,
InvalidNumberOfParametersError,
InvalidNumberValueError,
InvalidSelectValueError,
InvalidStringValueError,
InvalidTextContentTypeError,
InvalidValueTypeError,
ModelSchemaNotFoundError,
ParameterExtractorNodeError,
RequiredParameterMissingError,
@ -549,9 +546,6 @@ class ParameterExtractorNode(BaseNode):
return prompt_messages
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
"""
Validate result.
"""
if len(data.parameters) != len(result):
raise InvalidNumberOfParametersError("Invalid number of parameters")
@ -559,101 +553,106 @@ class ParameterExtractorNode(BaseNode):
if parameter.required and parameter.name not in result:
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
if parameter.type.startswith("array"):
parameters = result.get(parameter.name)
if not isinstance(parameters, list):
raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
nested_type = parameter.type[6:-1]
for item in parameters:
if nested_type == "number" and not isinstance(item, int | float):
raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
if nested_type == "string" and not isinstance(item, str):
raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
if nested_type == "object" and not isinstance(item, dict):
raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
param_value = result.get(parameter.name)
if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL):
inferred_type = SegmentType.infer_segment_type(param_value)
raise InvalidValueTypeError(
parameter_name=parameter.name,
expected_type=parameter.type,
actual_type=inferred_type,
value=param_value,
)
if parameter.type == SegmentType.STRING and parameter.options:
if param_value not in parameter.options:
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
return result
@staticmethod
def _transform_number(value: int | float | str | bool) -> int | float | None:
"""
Attempts to transform the input into an integer or float.
Returns:
int or float: The transformed number if the conversion is successful.
None: If the transformation fails.
Note:
Boolean values `True` and `False` are converted to integers `1` and `0`, respectively.
This behavior ensures compatibility with existing workflows that may use boolean types as integers.
"""
if isinstance(value, bool):
return int(value)
elif isinstance(value, (int, float)):
return value
elif not isinstance(value, str):
return None
if "." in value:
try:
return float(value)
except ValueError:
return None
else:
try:
return int(value)
except ValueError:
return None
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
"""
Transform result into standard format.
"""
transformed_result = {}
transformed_result: dict[str, Any] = {}
for parameter in data.parameters:
if parameter.name in result:
param_value = result[parameter.name]
# transform value
if parameter.type == "number":
if isinstance(result[parameter.name], int | float):
transformed_result[parameter.name] = result[parameter.name]
elif isinstance(result[parameter.name], str):
try:
if "." in result[parameter.name]:
result[parameter.name] = float(result[parameter.name])
else:
result[parameter.name] = int(result[parameter.name])
except ValueError:
pass
else:
pass
# TODO: bool is not supported in the current version
# elif parameter.type == 'bool':
# if isinstance(result[parameter.name], bool):
# transformed_result[parameter.name] = bool(result[parameter.name])
# elif isinstance(result[parameter.name], str):
# if result[parameter.name].lower() in ['true', 'false']:
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true')
# elif isinstance(result[parameter.name], int):
# transformed_result[parameter.name] = bool(result[parameter.name])
elif parameter.type in {"string", "select"}:
if isinstance(result[parameter.name], str):
transformed_result[parameter.name] = result[parameter.name]
if parameter.type == SegmentType.NUMBER:
transformed = self._transform_number(param_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
elif parameter.type == SegmentType.BOOLEAN:
if isinstance(result[parameter.name], (bool, int)):
transformed_result[parameter.name] = bool(result[parameter.name])
# elif isinstance(result[parameter.name], str):
# if result[parameter.name].lower() in ["true", "false"]:
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true")
elif parameter.type == SegmentType.STRING:
if isinstance(param_value, str):
transformed_result[parameter.name] = param_value
elif parameter.is_array_type():
if isinstance(result[parameter.name], list):
if isinstance(param_value, list):
nested_type = parameter.element_type()
assert nested_type is not None
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
transformed_result[parameter.name] = segment_value
for item in result[parameter.name]:
if nested_type == "number":
if isinstance(item, int | float):
segment_value.value.append(item)
elif isinstance(item, str):
try:
if "." in item:
segment_value.value.append(float(item))
else:
segment_value.value.append(int(item))
except ValueError:
pass
elif nested_type == "string":
for item in param_value:
if nested_type == SegmentType.NUMBER:
transformed = self._transform_number(item)
if transformed is not None:
segment_value.value.append(transformed)
elif nested_type == SegmentType.STRING:
if isinstance(item, str):
segment_value.value.append(item)
elif nested_type == "object":
elif nested_type == SegmentType.OBJECT:
if isinstance(item, dict):
segment_value.value.append(item)
elif nested_type == SegmentType.BOOLEAN:
if isinstance(item, bool):
segment_value.value.append(item)
if parameter.name not in transformed_result:
if parameter.type == "number":
transformed_result[parameter.name] = 0
elif parameter.type == "bool":
transformed_result[parameter.name] = False
elif parameter.type in {"string", "select"}:
transformed_result[parameter.name] = ""
elif parameter.type.startswith("array"):
if parameter.type.is_array_type():
transformed_result[parameter.name] = build_segment_with_type(
segment_type=SegmentType(parameter.type), value=[]
)
elif parameter.type in (SegmentType.STRING, SegmentType.SECRET):
transformed_result[parameter.name] = ""
elif parameter.type == SegmentType.NUMBER:
transformed_result[parameter.name] = 0
elif parameter.type == SegmentType.BOOLEAN:
transformed_result[parameter.name] = False
else:
raise AssertionError("this statement should be unreachable.")
return transformed_result

View File

@ -2,6 +2,7 @@ from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
from core.variables import SegmentType, Variable
from core.variables.segments import BooleanSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
@ -158,8 +159,8 @@ class VariableAssignerNode(BaseNode):
def get_zero_value(t: SegmentType):
# TODO(QuantumGhost): this should be a method of `SegmentType`.
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return variable_factory.build_segment([])
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN:
return variable_factory.build_segment_with_type(t, [])
case SegmentType.OBJECT:
return variable_factory.build_segment({})
case SegmentType.STRING:
@ -170,5 +171,7 @@ def get_zero_value(t: SegmentType):
return variable_factory.build_segment(0.0)
case SegmentType.NUMBER:
return variable_factory.build_segment(0)
case SegmentType.BOOLEAN:
return BooleanSegment(value=False)
case _:
raise VariableOperatorNodeError(f"unsupported variable type: {t}")

View File

@ -4,9 +4,11 @@ from core.variables import SegmentType
EMPTY_VALUE_MAPPING = {
SegmentType.STRING: "",
SegmentType.NUMBER: 0,
SegmentType.BOOLEAN: False,
SegmentType.OBJECT: {},
SegmentType.ARRAY_ANY: [],
SegmentType.ARRAY_STRING: [],
SegmentType.ARRAY_NUMBER: [],
SegmentType.ARRAY_OBJECT: [],
SegmentType.ARRAY_BOOLEAN: [],
}

View File

@ -16,28 +16,15 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
SegmentType.NUMBER,
SegmentType.INTEGER,
SegmentType.FLOAT,
SegmentType.BOOLEAN,
}
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
# Only number variable can be added, subtracted, multiplied or divided
return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
case Operation.APPEND | Operation.EXTEND:
case Operation.APPEND | Operation.EXTEND | Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
# Only array variable can be appended or extended
return variable_type in {
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_FILE,
}
case Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
# Only array variable can have elements removed
return variable_type in {
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_FILE,
}
return variable_type.is_array_type()
case _:
return False
@ -50,7 +37,7 @@ def is_variable_input_supported(*, operation: Operation):
def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation):
match variable_type:
case SegmentType.STRING | SegmentType.OBJECT:
case SegmentType.STRING | SegmentType.OBJECT | SegmentType.BOOLEAN:
return operation in {Operation.OVER_WRITE, Operation.SET}
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return operation in {
@ -72,6 +59,9 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
case SegmentType.STRING:
return isinstance(value, str)
case SegmentType.BOOLEAN:
return isinstance(value, bool)
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
if not isinstance(value, int | float):
return False
@ -91,6 +81,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
return isinstance(value, int | float)
case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND:
return isinstance(value, dict)
case SegmentType.ARRAY_BOOLEAN if operation == Operation.APPEND:
return isinstance(value, bool)
# Array & Extend / Overwrite
case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
@ -101,6 +93,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
return isinstance(value, list) and all(isinstance(item, int | float) for item in value)
case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, dict) for item in value)
case SegmentType.ARRAY_BOOLEAN if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, bool) for item in value)
case _:
return False

View File

@ -45,5 +45,5 @@ class SubVariableCondition(BaseModel):
class Condition(BaseModel):
variable_selector: list[str]
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None = None
value: str | Sequence[str] | bool | None = None
sub_variable_condition: SubVariableCondition | None = None

View File

@ -1,13 +1,27 @@
import json
from collections.abc import Sequence
from typing import Any, Literal
from typing import Any, Literal, Union
from core.file import FileAttribute, file_manager
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayBooleanSegment, BooleanSegment
from core.workflow.entities.variable_pool import VariablePool
from .entities import Condition, SubCondition, SupportedComparisonOperator
def _convert_to_bool(value: Any) -> bool:
if isinstance(value, int):
return bool(value)
if isinstance(value, str):
loaded = json.loads(value)
if isinstance(loaded, (int, bool)):
return bool(loaded)
raise TypeError(f"unexpected value: type={type(value)}, value={value}")
class ConditionProcessor:
def process_conditions(
self,
@ -48,9 +62,16 @@ class ConditionProcessor:
)
else:
actual_value = variable.value if variable else None
expected_value = condition.value
expected_value: str | Sequence[str] | bool | list[bool] | None = condition.value
if isinstance(expected_value, str):
expected_value = variable_pool.convert_template(expected_value).text
# Here we need to explicit convet the input string to boolean.
if isinstance(variable, (BooleanSegment, ArrayBooleanSegment)) and expected_value is not None:
# The following two lines is for compatibility with existing workflows.
if isinstance(expected_value, list):
expected_value = [_convert_to_bool(i) for i in expected_value]
else:
expected_value = _convert_to_bool(expected_value)
input_conditions.append(
{
"actual_value": actual_value,
@ -77,7 +98,7 @@ def _evaluate_condition(
*,
operator: SupportedComparisonOperator,
value: Any,
expected: str | Sequence[str] | None,
expected: Union[str, Sequence[str], bool | Sequence[bool], None],
) -> bool:
match operator:
case "contains":
@ -130,7 +151,7 @@ def _assert_contains(*, value: Any, expected: Any) -> bool:
if not value:
return False
if not isinstance(value, str | list):
if not isinstance(value, (str, list)):
raise ValueError("Invalid actual value type: string or array")
if expected not in value:
@ -142,7 +163,7 @@ def _assert_not_contains(*, value: Any, expected: Any) -> bool:
if not value:
return True
if not isinstance(value, str | list):
if not isinstance(value, (str, list)):
raise ValueError("Invalid actual value type: string or array")
if expected in value:
@ -178,8 +199,8 @@ def _assert_is(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, str):
raise ValueError("Invalid actual value type: string")
if not isinstance(value, (str, bool)):
raise ValueError("Invalid actual value type: string or boolean")
if value != expected:
return False
@ -190,8 +211,8 @@ def _assert_is_not(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, str):
raise ValueError("Invalid actual value type: string")
if not isinstance(value, (str, bool)):
raise ValueError("Invalid actual value type: string or boolean")
if value == expected:
return False
@ -214,10 +235,13 @@ def _assert_equal(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
raise ValueError("Invalid actual value type: number")
if not isinstance(value, (int, float, bool)):
raise ValueError("Invalid actual value type: number or boolean")
if isinstance(value, int):
# Handle boolean comparison
if isinstance(value, bool):
expected = bool(expected)
elif isinstance(value, int):
expected = int(expected)
else:
expected = float(expected)
@ -231,10 +255,13 @@ def _assert_not_equal(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
raise ValueError("Invalid actual value type: number")
if not isinstance(value, (int, float, bool)):
raise ValueError("Invalid actual value type: number or boolean")
if isinstance(value, int):
# Handle boolean comparison
if isinstance(value, bool):
expected = bool(expected)
elif isinstance(value, int):
expected = int(expected)
else:
expected = float(expected)
@ -248,7 +275,7 @@ def _assert_greater_than(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
if not isinstance(value, (int, float)):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
@ -265,7 +292,7 @@ def _assert_less_than(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
if not isinstance(value, (int, float)):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
@ -282,7 +309,7 @@ def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
if not isinstance(value, (int, float)):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
@ -299,7 +326,7 @@ def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
if not isinstance(value, (int, float)):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):