mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
Merge branch 'main' into feat/rag-2
This commit is contained in:
@ -3,6 +3,17 @@ import re
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
|
||||
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
|
||||
[
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
VariableEntityType.NUMBER,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
VariableEntityType.CHECKBOX,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class BasicVariablesConfigManager:
|
||||
@classmethod
|
||||
@ -47,6 +58,7 @@ class BasicVariablesConfigManager:
|
||||
VariableEntityType.PARAGRAPH,
|
||||
VariableEntityType.NUMBER,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.CHECKBOX,
|
||||
}:
|
||||
variable = variables[variable_type]
|
||||
variable_entities.append(
|
||||
@ -96,8 +108,17 @@ class BasicVariablesConfigManager:
|
||||
variables = []
|
||||
for item in config["user_input_form"]:
|
||||
key = list(item.keys())[0]
|
||||
if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
|
||||
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
|
||||
# if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
|
||||
if key not in {
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
VariableEntityType.NUMBER,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
VariableEntityType.CHECKBOX,
|
||||
}:
|
||||
allowed_keys = ", ".join(i.value for i in _ALLOWED_VARIABLE_ENTITY_TYPE)
|
||||
raise ValueError(f"Keys in user_input_form list can only be {allowed_keys}")
|
||||
|
||||
form_item = item[key]
|
||||
if "label" not in form_item:
|
||||
|
||||
@ -8,6 +8,8 @@ from core.app.entities.task_entities import AppBlockingResponse, AppStreamRespon
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppGenerateResponseConverter(ABC):
|
||||
_blocking_response_type: type[AppBlockingResponse]
|
||||
@ -120,7 +122,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
if data:
|
||||
data.setdefault("message", getattr(e, "description", str(e)))
|
||||
else:
|
||||
logging.error(e)
|
||||
logger.error(e)
|
||||
data = {
|
||||
"code": "internal_server_error",
|
||||
"message": "Internal Server Error, please contact support.",
|
||||
|
||||
@ -103,18 +103,23 @@ class BaseAppGenerator:
|
||||
f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string"
|
||||
)
|
||||
|
||||
if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str):
|
||||
# handle empty string case
|
||||
if not value.strip():
|
||||
return None
|
||||
# may raise ValueError if user_input_value is not a valid number
|
||||
try:
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
|
||||
if variable_entity.type == VariableEntityType.NUMBER:
|
||||
if isinstance(value, (int, float)):
|
||||
return value
|
||||
elif isinstance(value, str):
|
||||
# handle empty string case
|
||||
if not value.strip():
|
||||
return None
|
||||
# may raise ValueError if user_input_value is not a valid number
|
||||
try:
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
|
||||
else:
|
||||
raise TypeError(f"expected value type int, float or str, got {type(value)}, value: {value}")
|
||||
|
||||
match variable_entity.type:
|
||||
case VariableEntityType.SELECT:
|
||||
@ -144,6 +149,11 @@ class BaseAppGenerator:
|
||||
raise ValueError(
|
||||
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
|
||||
)
|
||||
case VariableEntityType.CHECKBOX:
|
||||
if not isinstance(value, bool):
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a valid boolean value")
|
||||
case _:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
@ -32,6 +32,8 @@ from extensions.ext_database import db
|
||||
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageCycleManager:
|
||||
def __init__(
|
||||
@ -98,7 +100,7 @@ class MessageCycleManager:
|
||||
conversation.name = name
|
||||
except Exception as e:
|
||||
if dify_config.DEBUG:
|
||||
logging.exception("generate conversation name failed, conversation_id: %s", conversation_id)
|
||||
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
|
||||
pass
|
||||
|
||||
db.session.merge(conversation)
|
||||
|
||||
@ -19,6 +19,7 @@ class ModelStatus(Enum):
|
||||
QUOTA_EXCEEDED = "quota-exceeded"
|
||||
NO_PERMISSION = "no-permission"
|
||||
DISABLED = "disabled"
|
||||
CREDENTIAL_REMOVED = "credential-removed"
|
||||
|
||||
|
||||
class SimpleModelProviderEntity(BaseModel):
|
||||
@ -54,6 +55,7 @@ class ProviderModelWithStatusEntity(ProviderModel):
|
||||
|
||||
status: ModelStatus
|
||||
load_balancing_enabled: bool = False
|
||||
has_invalid_load_balancing_configs: bool = False
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
"""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -69,6 +69,15 @@ class QuotaConfiguration(BaseModel):
|
||||
restrict_models: list[RestrictModel] = []
|
||||
|
||||
|
||||
class CredentialConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for credential configuration.
|
||||
"""
|
||||
|
||||
credential_id: str
|
||||
credential_name: str
|
||||
|
||||
|
||||
class SystemConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider system configuration.
|
||||
@ -86,6 +95,9 @@ class CustomProviderConfiguration(BaseModel):
|
||||
"""
|
||||
|
||||
credentials: dict
|
||||
current_credential_id: Optional[str] = None
|
||||
current_credential_name: Optional[str] = None
|
||||
available_credentials: list[CredentialConfiguration] = []
|
||||
|
||||
|
||||
class CustomModelConfiguration(BaseModel):
|
||||
@ -95,7 +107,10 @@ class CustomModelConfiguration(BaseModel):
|
||||
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict
|
||||
credentials: dict | None
|
||||
current_credential_id: Optional[str] = None
|
||||
current_credential_name: Optional[str] = None
|
||||
available_model_credentials: list[CredentialConfiguration] = []
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
@ -118,6 +133,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
credentials: dict
|
||||
credential_source_type: str | None = None
|
||||
|
||||
|
||||
class ModelSettings(BaseModel):
|
||||
|
||||
@ -10,6 +10,8 @@ from pydantic import BaseModel
|
||||
|
||||
from core.helper.position_helper import sort_to_dict_by_position_map
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExtensionModule(enum.Enum):
|
||||
MODERATION = "moderation"
|
||||
@ -66,7 +68,7 @@ class Extensible:
|
||||
|
||||
# Check for extension module file
|
||||
if (extension_name + ".py") not in file_names:
|
||||
logging.warning("Missing %s.py file in %s, Skip.", extension_name, subdir_path)
|
||||
logger.warning("Missing %s.py file in %s, Skip.", extension_name, subdir_path)
|
||||
continue
|
||||
|
||||
# Check for builtin flag and position
|
||||
@ -95,7 +97,7 @@ class Extensible:
|
||||
break
|
||||
|
||||
if not extension_class:
|
||||
logging.warning("Missing subclass of %s in %s, Skip.", cls.__name__, module_name)
|
||||
logger.warning("Missing subclass of %s in %s, Skip.", cls.__name__, module_name)
|
||||
continue
|
||||
|
||||
# Load schema if not builtin
|
||||
@ -103,7 +105,7 @@ class Extensible:
|
||||
if not builtin:
|
||||
json_path = os.path.join(subdir_path, "schema.json")
|
||||
if not os.path.exists(json_path):
|
||||
logging.warning("Missing schema.json file in %s, Skip.", subdir_path)
|
||||
logger.warning("Missing schema.json file in %s, Skip.", subdir_path)
|
||||
continue
|
||||
|
||||
with open(json_path, encoding="utf-8") as f:
|
||||
@ -122,7 +124,7 @@ class Extensible:
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("Error scanning extensions")
|
||||
logger.exception("Error scanning extensions")
|
||||
raise
|
||||
|
||||
# Sort extensions by position
|
||||
|
||||
@ -17,6 +17,7 @@ def encrypt_token(tenant_id: str, token: str):
|
||||
|
||||
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
|
||||
raise ValueError(f"Tenant with id {tenant_id} not found")
|
||||
assert tenant.encrypt_public_key is not None
|
||||
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
|
||||
return base64.b64encode(encrypted_token).decode()
|
||||
|
||||
|
||||
@ -4,6 +4,8 @@ import sys
|
||||
from types import ModuleType
|
||||
from typing import AnyStr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
|
||||
"""
|
||||
@ -30,7 +32,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
except Exception as e:
|
||||
logging.exception("Failed to load module %s from script file '%s'", module_name, repr(py_file_path))
|
||||
logger.exception("Failed to load module %s from script file '%s'", module_name, repr(py_file_path))
|
||||
raise e
|
||||
|
||||
|
||||
|
||||
@ -9,6 +9,8 @@ import httpx
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
|
||||
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True
|
||||
@ -73,12 +75,12 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
if response.status_code not in STATUS_FORCELIST:
|
||||
return response
|
||||
else:
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"Received status code %s for URL %s which is in the force list", response.status_code, url
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logging.warning("Request to URL %s failed on attempt %s: %s", url, retries + 1, e)
|
||||
logger.warning("Request to URL %s failed on attempt %s: %s", url, retries + 1, e)
|
||||
if max_retries == 0:
|
||||
raise
|
||||
|
||||
|
||||
@ -39,6 +39,8 @@ from models.dataset import Document as DatasetDocument
|
||||
from models.model import UploadFile
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IndexingRunner:
|
||||
def __init__(self):
|
||||
@ -90,9 +92,9 @@ class IndexingRunner:
|
||||
dataset_document.stopped_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
except ObjectDeletedError:
|
||||
logging.warning("Document deleted, document id: %s", dataset_document.id)
|
||||
logger.warning("Document deleted, document id: %s", dataset_document.id)
|
||||
except Exception as e:
|
||||
logging.exception("consume document failed")
|
||||
logger.exception("consume document failed")
|
||||
dataset_document.indexing_status = "error"
|
||||
dataset_document.error = str(e)
|
||||
dataset_document.stopped_at = naive_utc_now()
|
||||
@ -153,7 +155,7 @@ class IndexingRunner:
|
||||
dataset_document.stopped_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
logging.exception("consume document failed")
|
||||
logger.exception("consume document failed")
|
||||
dataset_document.indexing_status = "error"
|
||||
dataset_document.error = str(e)
|
||||
dataset_document.stopped_at = naive_utc_now()
|
||||
@ -228,7 +230,7 @@ class IndexingRunner:
|
||||
dataset_document.stopped_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
logging.exception("consume document failed")
|
||||
logger.exception("consume document failed")
|
||||
dataset_document.indexing_status = "error"
|
||||
dataset_document.error = str(e)
|
||||
dataset_document.stopped_at = naive_utc_now()
|
||||
@ -321,7 +323,7 @@ class IndexingRunner:
|
||||
try:
|
||||
storage.delete(image_file.key)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
logger.exception(
|
||||
"Delete image_files failed while indexing_estimate, \
|
||||
image_upload_file_is: %s",
|
||||
upload_file_id,
|
||||
|
||||
@ -31,6 +31,8 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
|
||||
from core.workflow.graph_engine.entities.event import AgentLogEvent
|
||||
from models import App, Message, WorkflowNodeExecutionModel, db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMGenerator:
|
||||
@classmethod
|
||||
@ -68,7 +70,7 @@ class LLMGenerator:
|
||||
result_dict = json.loads(cleaned_answer)
|
||||
answer = result_dict["Your Output"]
|
||||
except json.JSONDecodeError as e:
|
||||
logging.exception("Failed to generate name after answer, use query instead")
|
||||
logger.exception("Failed to generate name after answer, use query instead")
|
||||
answer = query
|
||||
name = answer.strip()
|
||||
|
||||
@ -125,7 +127,7 @@ class LLMGenerator:
|
||||
except InvokeError:
|
||||
questions = []
|
||||
except Exception:
|
||||
logging.exception("Failed to generate suggested questions after answer")
|
||||
logger.exception("Failed to generate suggested questions after answer")
|
||||
questions = []
|
||||
|
||||
return questions
|
||||
@ -173,7 +175,7 @@ class LLMGenerator:
|
||||
error = str(e)
|
||||
error_step = "generate rule config"
|
||||
except Exception as e:
|
||||
logging.exception("Failed to generate rule config, model: %s", model_config.get("name"))
|
||||
logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
|
||||
rule_config["error"] = str(e)
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
@ -270,7 +272,7 @@ class LLMGenerator:
|
||||
error_step = "generate conversation opener"
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("Failed to generate rule config, model: %s", model_config.get("name"))
|
||||
logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
|
||||
rule_config["error"] = str(e)
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
@ -319,7 +321,7 @@ class LLMGenerator:
|
||||
error = str(e)
|
||||
return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logging.exception(
|
||||
logger.exception(
|
||||
"Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language
|
||||
)
|
||||
return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
|
||||
@ -392,7 +394,7 @@ class LLMGenerator:
|
||||
error = str(e)
|
||||
return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
|
||||
except Exception as e:
|
||||
logging.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
|
||||
logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
|
||||
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@staticmethod
|
||||
@ -570,5 +572,5 @@ class LLMGenerator:
|
||||
error = str(e)
|
||||
return {"error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logging.exception("Failed to invoke LLM model, model: " + json.dumps(model_config.get("name")), exc_info=e)
|
||||
logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=e)
|
||||
return {"error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@ -152,7 +152,7 @@ class MCPClient:
|
||||
# ExitStack will handle proper cleanup of all managed context managers
|
||||
self._exit_stack.close()
|
||||
except Exception as e:
|
||||
logging.exception("Error during cleanup")
|
||||
logger.exception("Error during cleanup")
|
||||
raise ValueError(f"Error during cleanup: {e}")
|
||||
finally:
|
||||
self._session = None
|
||||
|
||||
@ -31,6 +31,9 @@ from core.mcp.types import (
|
||||
SessionMessage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
|
||||
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
|
||||
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
|
||||
@ -366,7 +369,7 @@ class BaseSession(
|
||||
self._handle_incoming(notification)
|
||||
except Exception as e:
|
||||
# For other validation errors, log and continue
|
||||
logging.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)
|
||||
logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)
|
||||
else: # Response or error
|
||||
response_queue = self._response_streams.get(message.message.root.id)
|
||||
if response_queue is not None:
|
||||
@ -376,7 +379,7 @@ class BaseSession(
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception:
|
||||
logging.exception("Error in message processing loop")
|
||||
logger.exception("Error in message processing loop")
|
||||
raise
|
||||
|
||||
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
|
||||
|
||||
@ -201,7 +201,7 @@ class ModelProviderFactory:
|
||||
return filtered_credentials
|
||||
|
||||
def get_model_schema(
|
||||
self, *, provider: str, model_type: ModelType, model: str, credentials: dict
|
||||
self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None
|
||||
) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema
|
||||
|
||||
@ -100,14 +100,14 @@ class Moderation(Extensible, ABC):
|
||||
if not inputs_config.get("preset_response"):
|
||||
raise ValueError("inputs_config.preset_response is required")
|
||||
|
||||
if len(inputs_config.get("preset_response", 0)) > 100:
|
||||
if len(inputs_config.get("preset_response", "0")) > 100:
|
||||
raise ValueError("inputs_config.preset_response must be less than 100 characters")
|
||||
|
||||
if outputs_config_enabled:
|
||||
if not outputs_config.get("preset_response"):
|
||||
raise ValueError("outputs_config.preset_response is required")
|
||||
|
||||
if len(outputs_config.get("preset_response", 0)) > 100:
|
||||
if len(outputs_config.get("preset_response", "0")) > 100:
|
||||
raise ValueError("outputs_config.preset_response must be less than 100 characters")
|
||||
|
||||
|
||||
|
||||
@ -306,7 +306,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
return node_span
|
||||
except Exception as e:
|
||||
logging.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
|
||||
logger.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
|
||||
return None
|
||||
|
||||
def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status:
|
||||
|
||||
@ -37,6 +37,8 @@ from models.model import App, AppModelConfig, Conversation, Message, MessageFile
|
||||
from models.workflow import WorkflowAppLog, WorkflowRun
|
||||
from tasks.ops_trace_task import process_trace_tasks
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
|
||||
def __getitem__(self, provider: str) -> dict[str, Any]:
|
||||
@ -287,7 +289,7 @@ class OpsTraceManager:
|
||||
# create new tracing_instance and update the cache if it absent
|
||||
tracing_instance = trace_instance(config_class(**decrypt_trace_config))
|
||||
cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance
|
||||
logging.info("new tracing_instance for app_id: %s", app_id)
|
||||
logger.info("new tracing_instance for app_id: %s", app_id)
|
||||
return tracing_instance
|
||||
|
||||
@classmethod
|
||||
@ -849,7 +851,7 @@ class TraceQueueManager:
|
||||
trace_task.app_id = self.app_id
|
||||
trace_manager_queue.put(trace_task)
|
||||
except Exception as e:
|
||||
logging.exception("Error adding trace task, trace_type %s", trace_task.trace_type)
|
||||
logger.exception("Error adding trace task, trace_type %s", trace_task.trace_type)
|
||||
finally:
|
||||
self.start_timer()
|
||||
|
||||
@ -868,7 +870,7 @@ class TraceQueueManager:
|
||||
if tasks:
|
||||
self.send_to_celery(tasks)
|
||||
except Exception as e:
|
||||
logging.exception("Error processing trace tasks")
|
||||
logger.exception("Error processing trace tasks")
|
||||
|
||||
def start_timer(self):
|
||||
global trace_manager_timer
|
||||
|
||||
@ -154,7 +154,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
workflow = app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("")
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
return WorkflowAppGenerator().generate(
|
||||
app_model=app,
|
||||
|
||||
@ -8,6 +8,7 @@ from core.plugin.entities.plugin_daemon import (
|
||||
)
|
||||
from core.plugin.entities.request import PluginInvokeContext
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.utils.chunk_merger import merge_blob_chunks
|
||||
|
||||
|
||||
class PluginAgentClient(BasePluginClient):
|
||||
@ -113,4 +114,4 @@ class PluginAgentClient(BasePluginClient):
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
return response
|
||||
return merge_blob_chunks(response)
|
||||
|
||||
@ -141,11 +141,11 @@ class BasePluginClient:
|
||||
response.raise_for_status()
|
||||
except HTTPError as e:
|
||||
msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}"
|
||||
logging.exception(msg)
|
||||
logger.exception(msg)
|
||||
raise e
|
||||
except Exception as e:
|
||||
msg = f"Failed to request plugin daemon, url: {path}"
|
||||
logging.exception(msg)
|
||||
logger.exception(msg)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
try:
|
||||
@ -158,7 +158,7 @@ class BasePluginClient:
|
||||
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}],"
|
||||
f" url: {path}"
|
||||
)
|
||||
logging.exception(msg)
|
||||
logger.exception(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
if rep.code != 0:
|
||||
|
||||
@ -9,6 +9,7 @@ from core.plugin.entities.plugin_daemon import (
|
||||
PluginToolProviderEntity,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.utils.chunk_merger import merge_blob_chunks
|
||||
from core.schemas.resolver import resolve_dify_schema_refs
|
||||
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
|
||||
|
||||
@ -123,61 +124,7 @@ class PluginToolManager(BasePluginClient):
|
||||
},
|
||||
)
|
||||
|
||||
class FileChunk:
|
||||
"""
|
||||
Only used for internal processing.
|
||||
"""
|
||||
|
||||
bytes_written: int
|
||||
total_length: int
|
||||
data: bytearray
|
||||
|
||||
def __init__(self, total_length: int):
|
||||
self.bytes_written = 0
|
||||
self.total_length = total_length
|
||||
self.data = bytearray(total_length)
|
||||
|
||||
files: dict[str, FileChunk] = {}
|
||||
for resp in response:
|
||||
if resp.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
|
||||
assert isinstance(resp.message, ToolInvokeMessage.BlobChunkMessage)
|
||||
# Get blob chunk information
|
||||
chunk_id = resp.message.id
|
||||
total_length = resp.message.total_length
|
||||
blob_data = resp.message.blob
|
||||
is_end = resp.message.end
|
||||
|
||||
# Initialize buffer for this file if it doesn't exist
|
||||
if chunk_id not in files:
|
||||
files[chunk_id] = FileChunk(total_length)
|
||||
|
||||
# If this is the final chunk, yield a complete blob message
|
||||
if is_end:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB,
|
||||
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data),
|
||||
meta=resp.meta,
|
||||
)
|
||||
else:
|
||||
# Check if file is too large (30MB limit)
|
||||
if files[chunk_id].bytes_written + len(blob_data) > 30 * 1024 * 1024:
|
||||
# Delete the file if it's too large
|
||||
del files[chunk_id]
|
||||
# Skip yielding this message
|
||||
raise ValueError("File is too large which reached the limit of 30MB")
|
||||
|
||||
# Check if single chunk is too large (8KB limit)
|
||||
if len(blob_data) > 8192:
|
||||
# Skip yielding this message
|
||||
raise ValueError("File chunk is too large which reached the limit of 8KB")
|
||||
|
||||
# Append the blob data to the buffer
|
||||
files[chunk_id].data[
|
||||
files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data)
|
||||
] = blob_data
|
||||
files[chunk_id].bytes_written += len(blob_data)
|
||||
else:
|
||||
yield resp
|
||||
return merge_blob_chunks(response)
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
|
||||
|
||||
92
api/core/plugin/utils/chunk_merger.py
Normal file
92
api/core/plugin/utils/chunk_merger.py
Normal file
@ -0,0 +1,92 @@
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypeVar, Union, cast
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
MessageType = TypeVar("MessageType", bound=Union[ToolInvokeMessage, AgentInvokeMessage])
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileChunk:
|
||||
"""
|
||||
Buffer for accumulating file chunks during streaming.
|
||||
"""
|
||||
|
||||
total_length: int
|
||||
bytes_written: int = field(default=0, init=False)
|
||||
data: bytearray = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.data = bytearray(self.total_length)
|
||||
|
||||
|
||||
def merge_blob_chunks(
|
||||
response: Generator[MessageType, None, None],
|
||||
max_file_size: int = 30 * 1024 * 1024,
|
||||
max_chunk_size: int = 8192,
|
||||
) -> Generator[MessageType, None, None]:
|
||||
"""
|
||||
Merge streaming blob chunks into complete blob messages.
|
||||
|
||||
This function processes a stream of plugin invoke messages, accumulating
|
||||
BLOB_CHUNK messages by their ID until the final chunk is received,
|
||||
then yielding a single complete BLOB message.
|
||||
|
||||
Args:
|
||||
response: Generator yielding messages that may include blob chunks
|
||||
max_file_size: Maximum allowed file size in bytes (default: 30MB)
|
||||
max_chunk_size: Maximum allowed chunk size in bytes (default: 8KB)
|
||||
|
||||
Yields:
|
||||
Messages from the response stream, with blob chunks merged into complete blobs
|
||||
|
||||
Raises:
|
||||
ValueError: If file size exceeds max_file_size or chunk size exceeds max_chunk_size
|
||||
"""
|
||||
files: dict[str, FileChunk] = {}
|
||||
|
||||
for resp in response:
|
||||
if resp.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
|
||||
assert isinstance(resp.message, ToolInvokeMessage.BlobChunkMessage)
|
||||
# Get blob chunk information
|
||||
chunk_id = resp.message.id
|
||||
total_length = resp.message.total_length
|
||||
blob_data = resp.message.blob
|
||||
is_end = resp.message.end
|
||||
|
||||
# Initialize buffer for this file if it doesn't exist
|
||||
if chunk_id not in files:
|
||||
files[chunk_id] = FileChunk(total_length)
|
||||
|
||||
# Check if file is too large (before appending)
|
||||
if files[chunk_id].bytes_written + len(blob_data) > max_file_size:
|
||||
# Delete the file if it's too large
|
||||
del files[chunk_id]
|
||||
raise ValueError(f"File is too large which reached the limit of {max_file_size / 1024 / 1024}MB")
|
||||
|
||||
# Check if single chunk is too large
|
||||
if len(blob_data) > max_chunk_size:
|
||||
raise ValueError(f"File chunk is too large which reached the limit of {max_chunk_size / 1024}KB")
|
||||
|
||||
# Append the blob data to the buffer
|
||||
files[chunk_id].data[files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data)] = (
|
||||
blob_data
|
||||
)
|
||||
files[chunk_id].bytes_written += len(blob_data)
|
||||
|
||||
# If this is the final chunk, yield a complete blob message
|
||||
if is_end:
|
||||
# Create the appropriate message type based on the response type
|
||||
message_class = type(resp)
|
||||
merged_message = message_class(
|
||||
type=ToolInvokeMessage.MessageType.BLOB,
|
||||
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]),
|
||||
meta=resp.meta,
|
||||
)
|
||||
yield cast(MessageType, merged_message)
|
||||
# Clean up the buffer
|
||||
del files[chunk_id]
|
||||
else:
|
||||
yield resp
|
||||
@ -12,6 +12,7 @@ from configs import dify_config
|
||||
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
||||
from core.entities.provider_entities import (
|
||||
CredentialConfiguration,
|
||||
CustomConfiguration,
|
||||
CustomModelConfiguration,
|
||||
CustomProviderConfiguration,
|
||||
@ -40,7 +41,9 @@ from extensions.ext_redis import redis_client
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
ProviderCredential,
|
||||
ProviderModel,
|
||||
ProviderModelCredential,
|
||||
ProviderModelSetting,
|
||||
ProviderType,
|
||||
TenantDefaultModel,
|
||||
@ -488,6 +491,61 @@ class ProviderManager:
|
||||
|
||||
return provider_name_to_provider_load_balancing_model_configs_dict
|
||||
|
||||
@staticmethod
|
||||
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
|
||||
"""
|
||||
Get provider all credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_name: provider name
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(ProviderCredential)
|
||||
.where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name)
|
||||
.order_by(ProviderCredential.created_at.desc())
|
||||
)
|
||||
|
||||
available_credentials = session.scalars(stmt).all()
|
||||
|
||||
return [
|
||||
CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name)
|
||||
for credential in available_credentials
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_provider_model_available_credentials(
|
||||
tenant_id: str, provider_name: str, model_name: str, model_type: str
|
||||
) -> list[CredentialConfiguration]:
|
||||
"""
|
||||
Get provider custom model all credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_name: provider name
|
||||
:param model_name: model name
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(ProviderModelCredential)
|
||||
.where(
|
||||
ProviderModelCredential.tenant_id == tenant_id,
|
||||
ProviderModelCredential.provider_name == provider_name,
|
||||
ProviderModelCredential.model_name == model_name,
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
.order_by(ProviderModelCredential.created_at.desc())
|
||||
)
|
||||
|
||||
available_credentials = session.scalars(stmt).all()
|
||||
|
||||
return [
|
||||
CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name)
|
||||
for credential in available_credentials
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _init_trial_provider_records(
|
||||
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
|
||||
@ -590,9 +648,6 @@ class ProviderManager:
|
||||
if provider_record.provider_type == ProviderType.SYSTEM.value:
|
||||
continue
|
||||
|
||||
if not provider_record.encrypted_config:
|
||||
continue
|
||||
|
||||
custom_provider_record = provider_record
|
||||
|
||||
# Get custom provider credentials
|
||||
@ -611,8 +666,8 @@ class ProviderManager:
|
||||
try:
|
||||
# fix origin data
|
||||
if custom_provider_record.encrypted_config is None:
|
||||
raise ValueError("No credentials found")
|
||||
if not custom_provider_record.encrypted_config.startswith("{"):
|
||||
provider_credentials = {}
|
||||
elif not custom_provider_record.encrypted_config.startswith("{"):
|
||||
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
|
||||
else:
|
||||
provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
||||
@ -637,7 +692,14 @@ class ProviderManager:
|
||||
else:
|
||||
provider_credentials = cached_provider_credentials
|
||||
|
||||
custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials)
|
||||
custom_provider_configuration = CustomProviderConfiguration(
|
||||
credentials=provider_credentials,
|
||||
current_credential_name=custom_provider_record.credential_name,
|
||||
current_credential_id=custom_provider_record.credential_id,
|
||||
available_credentials=self.get_provider_available_credentials(
|
||||
tenant_id, custom_provider_record.provider_name
|
||||
),
|
||||
)
|
||||
|
||||
# Get provider model credential secret variables
|
||||
model_credential_secret_variables = self._extract_secret_variables(
|
||||
@ -649,8 +711,12 @@ class ProviderManager:
|
||||
# Get custom provider model credentials
|
||||
custom_model_configurations = []
|
||||
for provider_model_record in provider_model_records:
|
||||
if not provider_model_record.encrypted_config:
|
||||
continue
|
||||
available_model_credentials = self.get_provider_model_available_credentials(
|
||||
tenant_id,
|
||||
provider_model_record.provider_name,
|
||||
provider_model_record.model_name,
|
||||
provider_model_record.model_type,
|
||||
)
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
|
||||
@ -659,7 +725,7 @@ class ProviderManager:
|
||||
# Get cached provider model credentials
|
||||
cached_provider_model_credentials = provider_model_credentials_cache.get()
|
||||
|
||||
if not cached_provider_model_credentials:
|
||||
if not cached_provider_model_credentials and provider_model_record.encrypted_config:
|
||||
try:
|
||||
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
@ -688,6 +754,9 @@ class ProviderManager:
|
||||
model=provider_model_record.model_name,
|
||||
model_type=ModelType.value_of(provider_model_record.model_type),
|
||||
credentials=provider_model_credentials,
|
||||
current_credential_id=provider_model_record.credential_id,
|
||||
current_credential_name=provider_model_record.credential_name,
|
||||
available_model_credentials=available_model_credentials,
|
||||
)
|
||||
)
|
||||
|
||||
@ -899,6 +968,18 @@ class ProviderManager:
|
||||
load_balancing_model_config.model_name == provider_model_setting.model_name
|
||||
and load_balancing_model_config.model_type == provider_model_setting.model_type
|
||||
):
|
||||
if load_balancing_model_config.name == "__delete__":
|
||||
# to calculate current model whether has invalidate lb configs
|
||||
load_balancing_configs.append(
|
||||
ModelLoadBalancingConfiguration(
|
||||
id=load_balancing_model_config.id,
|
||||
name=load_balancing_model_config.name,
|
||||
credentials={},
|
||||
credential_source_type=load_balancing_model_config.credential_source_type,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if not load_balancing_model_config.enabled:
|
||||
continue
|
||||
|
||||
@ -955,6 +1036,7 @@ class ProviderManager:
|
||||
id=load_balancing_model_config.id,
|
||||
name=load_balancing_model_config.name,
|
||||
credentials=provider_model_credentials,
|
||||
credential_source_type=load_balancing_model_config.credential_source_type,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -259,8 +259,16 @@ class MilvusVector(BaseVector):
|
||||
"""
|
||||
Search for documents by full-text search (if hybrid search is enabled).
|
||||
"""
|
||||
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
|
||||
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
|
||||
if not self._hybrid_search_enabled:
|
||||
logger.warning(
|
||||
"Full-text search is disabled: set MILVUS_ENABLE_HYBRID_SEARCH=true (requires Milvus >= 2.5.0)."
|
||||
)
|
||||
return []
|
||||
if not self.field_exists(Field.SPARSE_VECTOR.value):
|
||||
logger.warning(
|
||||
"Full-text search unavailable: collection missing 'sparse_vector' field; "
|
||||
"recreate the collection after enabling MILVUS_ENABLE_HYBRID_SEARCH to add BM25 sparse index."
|
||||
)
|
||||
return []
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = ""
|
||||
|
||||
@ -15,6 +15,8 @@ from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MyScaleConfig(BaseModel):
|
||||
host: str
|
||||
@ -53,7 +55,7 @@ class MyScaleVector(BaseVector):
|
||||
return self.add_texts(documents=texts, embeddings=embeddings, **kwargs)
|
||||
|
||||
def _create_collection(self, dimension: int):
|
||||
logging.info("create MyScale collection %s with dimension %s", self._collection_name, dimension)
|
||||
logger.info("create MyScale collection %s with dimension %s", self._collection_name, dimension)
|
||||
self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}")
|
||||
fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else ""
|
||||
sql = f"""
|
||||
@ -151,7 +153,7 @@ class MyScaleVector(BaseVector):
|
||||
for r in self._client.query(sql).named_results()
|
||||
]
|
||||
except Exception as e:
|
||||
logging.exception("\033[91m\033[1m%s\033[0m \033[95m%s\033[0m", type(e), str(e)) # noqa:TRY401
|
||||
logger.exception("\033[91m\033[1m%s\033[0m \033[95m%s\033[0m", type(e), str(e)) # noqa:TRY401
|
||||
return []
|
||||
|
||||
def delete(self) -> None:
|
||||
|
||||
@ -188,14 +188,17 @@ class OracleVector(BaseVector):
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
|
||||
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = :1", (id,))
|
||||
return cur.fetchone() is not None
|
||||
conn.close()
|
||||
|
||||
def get_by_ids(self, ids: list[str]) -> list[Document]:
|
||||
if not ids:
|
||||
return []
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
|
||||
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
|
||||
docs = []
|
||||
for record in cur:
|
||||
docs.append(Document(page_content=record[1], metadata=record[0]))
|
||||
@ -208,14 +211,15 @@ class OracleVector(BaseVector):
|
||||
return
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
|
||||
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$." + key + "') = :1", (value,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
@ -227,12 +231,20 @@ class OracleVector(BaseVector):
|
||||
:param top_k: The number of nearest neighbors to return, default is 5.
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
# Validate and sanitize top_k to prevent SQL injection
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
|
||||
top_k = 4 # Use default if invalid
|
||||
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
params = [numpy.array(query_vector)]
|
||||
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
|
||||
placeholders = ", ".join(f":{i + 2}" for i in range(len(document_ids_filter)))
|
||||
where_clause = f"WHERE JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
|
||||
params.extend(document_ids_filter)
|
||||
|
||||
with self._get_connection() as conn:
|
||||
conn.inputtypehandler = self.input_type_handler
|
||||
conn.outputtypehandler = self.output_type_handler
|
||||
@ -241,7 +253,7 @@ class OracleVector(BaseVector):
|
||||
f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
|
||||
AS distance FROM {self.table_name}
|
||||
{where_clause} ORDER BY distance fetch first {top_k} rows only""",
|
||||
[numpy.array(query_vector)],
|
||||
params,
|
||||
)
|
||||
docs = []
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
@ -259,7 +271,10 @@ class OracleVector(BaseVector):
|
||||
import nltk # type: ignore
|
||||
from nltk.corpus import stopwords # type: ignore
|
||||
|
||||
# Validate and sanitize top_k to prevent SQL injection
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
|
||||
top_k = 5 # Use default if invalid
|
||||
# just not implement fetch by score_threshold now, may be later
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if len(query) > 0:
|
||||
@ -297,14 +312,21 @@ class OracleVector(BaseVector):
|
||||
with conn.cursor() as cur:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
params: dict[str, Any] = {"kk": " ACCUM ".join(entities)}
|
||||
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
||||
placeholders = []
|
||||
for i, doc_id in enumerate(document_ids_filter):
|
||||
param_name = f"doc_id_{i}"
|
||||
placeholders.append(f":{param_name}")
|
||||
params[param_name] = doc_id
|
||||
where_clause = f" AND JSON_VALUE(meta, '$.document_id') IN ({', '.join(placeholders)}) "
|
||||
|
||||
cur.execute(
|
||||
f"""select meta, text, embedding FROM {self.table_name}
|
||||
WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
|
||||
order by score(1) desc fetch first {top_k} rows only""",
|
||||
kk=" ACCUM ".join(entities),
|
||||
params,
|
||||
)
|
||||
docs = []
|
||||
for record in cur:
|
||||
|
||||
@ -19,6 +19,8 @@ from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PGVectorConfig(BaseModel):
|
||||
host: str
|
||||
@ -155,7 +157,7 @@ class PGVector(BaseVector):
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||
except psycopg2.errors.UndefinedTable:
|
||||
# table not exists
|
||||
logging.warning("Table %s not found, skipping delete operation.", self.table_name)
|
||||
logger.warning("Table %s not found, skipping delete operation.", self.table_name)
|
||||
return
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@ -17,6 +17,8 @@ from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TableStoreConfig(BaseModel):
|
||||
access_key_id: Optional[str] = None
|
||||
@ -145,7 +147,7 @@ class TableStoreVector(BaseVector):
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
logging.info("Collection %s already exists.", self._collection_name)
|
||||
logger.info("Collection %s already exists.", self._collection_name)
|
||||
return
|
||||
|
||||
self._create_table_if_not_exist()
|
||||
@ -155,7 +157,7 @@ class TableStoreVector(BaseVector):
|
||||
def _create_table_if_not_exist(self) -> None:
|
||||
table_list = self._tablestore_client.list_table()
|
||||
if self._table_name in table_list:
|
||||
logging.info("Tablestore system table[%s] already exists", self._table_name)
|
||||
logger.info("Tablestore system table[%s] already exists", self._table_name)
|
||||
return None
|
||||
|
||||
schema_of_primary_key = [("id", "STRING")]
|
||||
@ -163,12 +165,12 @@ class TableStoreVector(BaseVector):
|
||||
table_options = tablestore.TableOptions()
|
||||
reserved_throughput = tablestore.ReservedThroughput(tablestore.CapacityUnit(0, 0))
|
||||
self._tablestore_client.create_table(table_meta, table_options, reserved_throughput)
|
||||
logging.info("Tablestore create table[%s] successfully.", self._table_name)
|
||||
logger.info("Tablestore create table[%s] successfully.", self._table_name)
|
||||
|
||||
def _create_search_index_if_not_exist(self, dimension: int) -> None:
|
||||
search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
|
||||
if self._index_name in [t[1] for t in search_index_list]:
|
||||
logging.info("Tablestore system index[%s] already exists", self._index_name)
|
||||
logger.info("Tablestore system index[%s] already exists", self._index_name)
|
||||
return None
|
||||
|
||||
field_schemas = [
|
||||
@ -206,20 +208,20 @@ class TableStoreVector(BaseVector):
|
||||
|
||||
index_meta = tablestore.SearchIndexMeta(field_schemas)
|
||||
self._tablestore_client.create_search_index(self._table_name, self._index_name, index_meta)
|
||||
logging.info("Tablestore create system index[%s] successfully.", self._index_name)
|
||||
logger.info("Tablestore create system index[%s] successfully.", self._index_name)
|
||||
|
||||
def _delete_table_if_exist(self):
|
||||
search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
|
||||
for resp_tuple in search_index_list:
|
||||
self._tablestore_client.delete_search_index(resp_tuple[0], resp_tuple[1])
|
||||
logging.info("Tablestore delete index[%s] successfully.", self._index_name)
|
||||
logger.info("Tablestore delete index[%s] successfully.", self._index_name)
|
||||
|
||||
self._tablestore_client.delete_table(self._table_name)
|
||||
logging.info("Tablestore delete system table[%s] successfully.", self._index_name)
|
||||
logger.info("Tablestore delete system table[%s] successfully.", self._index_name)
|
||||
|
||||
def _delete_search_index(self) -> None:
|
||||
self._tablestore_client.delete_search_index(self._table_name, self._index_name)
|
||||
logging.info("Tablestore delete index[%s] successfully.", self._index_name)
|
||||
logger.info("Tablestore delete index[%s] successfully.", self._index_name)
|
||||
|
||||
def _write_row(self, primary_key: str, attributes: dict[str, Any]) -> None:
|
||||
pk = [("id", primary_key)]
|
||||
|
||||
@ -83,14 +83,14 @@ class TiDBVector(BaseVector):
|
||||
self._dimension = 1536
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
logger.info("create collection and add texts, collection_name: " + self._collection_name)
|
||||
logger.info("create collection and add texts, collection_name: %s", self._collection_name)
|
||||
self._create_collection(len(embeddings[0]))
|
||||
self.add_texts(texts, embeddings)
|
||||
self._dimension = len(embeddings[0])
|
||||
pass
|
||||
|
||||
def _create_collection(self, dimension: int):
|
||||
logger.info("_create_collection, collection_name " + self._collection_name)
|
||||
logger.info("_create_collection, collection_name %s", self._collection_name)
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
|
||||
@ -75,7 +75,7 @@ class CacheEmbedding(Embeddings):
|
||||
except IntegrityError:
|
||||
db.session.rollback()
|
||||
except Exception:
|
||||
logging.exception("Failed transform embedding")
|
||||
logger.exception("Failed transform embedding")
|
||||
cache_embeddings = []
|
||||
try:
|
||||
for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
|
||||
@ -95,7 +95,7 @@ class CacheEmbedding(Embeddings):
|
||||
db.session.rollback()
|
||||
except Exception as ex:
|
||||
db.session.rollback()
|
||||
logger.exception("Failed to embed documents: %s")
|
||||
logger.exception("Failed to embed documents")
|
||||
raise ex
|
||||
|
||||
return text_embeddings
|
||||
@ -122,7 +122,7 @@ class CacheEmbedding(Embeddings):
|
||||
raise ValueError("Normalized embedding is nan please try again")
|
||||
except Exception as ex:
|
||||
if dify_config.DEBUG:
|
||||
logging.exception("Failed to embed query text '%s...(%s chars)'", text[:10], len(text))
|
||||
logger.exception("Failed to embed query text '%s...(%s chars)'", text[:10], len(text))
|
||||
raise ex
|
||||
|
||||
try:
|
||||
@ -136,7 +136,7 @@ class CacheEmbedding(Embeddings):
|
||||
redis_client.setex(embedding_cache_key, 600, encoded_str)
|
||||
except Exception as ex:
|
||||
if dify_config.DEBUG:
|
||||
logging.exception(
|
||||
logger.exception(
|
||||
"Failed to add embedding to redis for the text '%s...(%s chars)'", text[:10], len(text)
|
||||
)
|
||||
raise ex
|
||||
|
||||
@ -26,6 +26,8 @@ from models.dataset import Dataset
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QAIndexProcessor(BaseIndexProcessor):
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
@ -215,7 +217,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
qa_documents.append(qa_document)
|
||||
format_documents.extend(qa_documents)
|
||||
except Exception as e:
|
||||
logging.exception("Failed to format qa document")
|
||||
logger.exception("Failed to format qa document")
|
||||
|
||||
all_qa_documents.extend(format_documents)
|
||||
|
||||
|
||||
@ -39,9 +39,16 @@ class WeightRerankRunner(BaseRerankRunner):
|
||||
unique_documents = []
|
||||
doc_ids = set()
|
||||
for document in documents:
|
||||
if document.metadata is not None and document.metadata["doc_id"] not in doc_ids:
|
||||
if (
|
||||
document.provider == "dify"
|
||||
and document.metadata is not None
|
||||
and document.metadata["doc_id"] not in doc_ids
|
||||
):
|
||||
doc_ids.add(document.metadata["doc_id"])
|
||||
unique_documents.append(document)
|
||||
else:
|
||||
if document not in unique_documents:
|
||||
unique_documents.append(document)
|
||||
|
||||
documents = unique_documents
|
||||
|
||||
|
||||
@ -275,35 +275,30 @@ class ApiTool(Tool):
|
||||
if files:
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
if method in {
|
||||
"get",
|
||||
"head",
|
||||
"post",
|
||||
"put",
|
||||
"delete",
|
||||
"patch",
|
||||
"options",
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
"OPTIONS",
|
||||
}:
|
||||
response: httpx.Response = getattr(ssrf_proxy, method.lower())(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
data=body,
|
||||
files=files,
|
||||
timeout=API_TOOL_DEFAULT_TIMEOUT,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return response
|
||||
else:
|
||||
_METHOD_MAP = {
|
||||
"get": ssrf_proxy.get,
|
||||
"head": ssrf_proxy.head,
|
||||
"post": ssrf_proxy.post,
|
||||
"put": ssrf_proxy.put,
|
||||
"delete": ssrf_proxy.delete,
|
||||
"patch": ssrf_proxy.patch,
|
||||
}
|
||||
method_lc = method.lower()
|
||||
if method_lc not in _METHOD_MAP:
|
||||
raise ValueError(f"Invalid http method {method}")
|
||||
response: httpx.Response = _METHOD_MAP[
|
||||
method_lc
|
||||
]( # https://discuss.python.org/t/type-inference-for-function-return-types/42926
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
data=body,
|
||||
files=files,
|
||||
timeout=API_TOOL_DEFAULT_TIMEOUT,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return response
|
||||
|
||||
def _convert_body_property_any_of(
|
||||
self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
|
||||
|
||||
@ -280,7 +280,7 @@ class ToolEngine:
|
||||
mimetype = "image/jpeg"
|
||||
|
||||
yield ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", "image/jpeg"),
|
||||
mimetype=response.meta.get("mime_type", mimetype),
|
||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
|
||||
@ -151,6 +151,11 @@ class FileSegment(Segment):
|
||||
return ""
|
||||
|
||||
|
||||
class BooleanSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.BOOLEAN
|
||||
value: bool
|
||||
|
||||
|
||||
class ArrayAnySegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_ANY
|
||||
value: Sequence[Any]
|
||||
@ -198,6 +203,11 @@ class ArrayFileSegment(ArraySegment):
|
||||
return ""
|
||||
|
||||
|
||||
class ArrayBooleanSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
|
||||
value: Sequence[bool]
|
||||
|
||||
|
||||
def get_segment_discriminator(v: Any) -> SegmentType | None:
|
||||
if isinstance(v, Segment):
|
||||
return v.value_type
|
||||
@ -231,11 +241,13 @@ SegmentUnion: TypeAlias = Annotated[
|
||||
| Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
|
||||
| Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
|
||||
| Annotated[FileSegment, Tag(SegmentType.FILE)]
|
||||
| Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)]
|
||||
| Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
|
||||
| Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
|
||||
| Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
|
||||
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
|
||||
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
|
||||
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
|
||||
),
|
||||
Discriminator(get_segment_discriminator),
|
||||
]
|
||||
|
||||
@ -6,7 +6,12 @@ from core.file.models import File
|
||||
|
||||
|
||||
class ArrayValidation(StrEnum):
|
||||
"""Strategy for validating array elements"""
|
||||
"""Strategy for validating array elements.
|
||||
|
||||
Note:
|
||||
The `NONE` and `FIRST` strategies are primarily for compatibility purposes.
|
||||
Avoid using them in new code whenever possible.
|
||||
"""
|
||||
|
||||
# Skip element validation (only check array container)
|
||||
NONE = "none"
|
||||
@ -27,12 +32,14 @@ class SegmentType(StrEnum):
|
||||
SECRET = "secret"
|
||||
|
||||
FILE = "file"
|
||||
BOOLEAN = "boolean"
|
||||
|
||||
ARRAY_ANY = "array[any]"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILE = "array[file]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
|
||||
NONE = "none"
|
||||
|
||||
@ -76,12 +83,18 @@ class SegmentType(StrEnum):
|
||||
return SegmentType.ARRAY_FILE
|
||||
case SegmentType.NONE:
|
||||
return SegmentType.ARRAY_ANY
|
||||
case SegmentType.BOOLEAN:
|
||||
return SegmentType.ARRAY_BOOLEAN
|
||||
case _:
|
||||
# This should be unreachable.
|
||||
raise ValueError(f"not supported value {value}")
|
||||
if value is None:
|
||||
return SegmentType.NONE
|
||||
elif isinstance(value, int) and not isinstance(value, bool):
|
||||
# Important: The check for `bool` must precede the check for `int`,
|
||||
# as `bool` is a subclass of `int` in Python's type hierarchy.
|
||||
elif isinstance(value, bool):
|
||||
return SegmentType.BOOLEAN
|
||||
elif isinstance(value, int):
|
||||
return SegmentType.INTEGER
|
||||
elif isinstance(value, float):
|
||||
return SegmentType.FLOAT
|
||||
@ -111,7 +124,7 @@ class SegmentType(StrEnum):
|
||||
else:
|
||||
return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value)
|
||||
|
||||
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
|
||||
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.ALL) -> bool:
|
||||
"""
|
||||
Check if a value matches the segment type.
|
||||
Users of `SegmentType` should call this method, instead of using
|
||||
@ -126,6 +139,10 @@ class SegmentType(StrEnum):
|
||||
"""
|
||||
if self.is_array_type():
|
||||
return self._validate_array(value, array_validation)
|
||||
# Important: The check for `bool` must precede the check for `int`,
|
||||
# as `bool` is a subclass of `int` in Python's type hierarchy.
|
||||
elif self == SegmentType.BOOLEAN:
|
||||
return isinstance(value, bool)
|
||||
elif self in [SegmentType.INTEGER, SegmentType.FLOAT, SegmentType.NUMBER]:
|
||||
return isinstance(value, (int, float))
|
||||
elif self == SegmentType.STRING:
|
||||
@ -141,6 +158,27 @@ class SegmentType(StrEnum):
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
@staticmethod
|
||||
def cast_value(value: Any, type_: "SegmentType") -> Any:
|
||||
# Cast Python's `bool` type to `int` when the runtime type requires
|
||||
# an integer or number.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use `bool` as
|
||||
# `int`, since in Python's type system, `bool` is a subtype of `int`.
|
||||
#
|
||||
# This function exists solely to maintain compatibility with existing workflows.
|
||||
# It should not be used to compromise the integrity of the runtime type system.
|
||||
# No additional casting rules should be introduced to this function.
|
||||
|
||||
if type_ in (
|
||||
SegmentType.INTEGER,
|
||||
SegmentType.NUMBER,
|
||||
) and isinstance(value, bool):
|
||||
return int(value)
|
||||
if type_ == SegmentType.ARRAY_NUMBER and all(isinstance(i, bool) for i in value):
|
||||
return [int(i) for i in value]
|
||||
return value
|
||||
|
||||
def exposed_type(self) -> "SegmentType":
|
||||
"""Returns the type exposed to the frontend.
|
||||
|
||||
@ -150,6 +188,20 @@ class SegmentType(StrEnum):
|
||||
return SegmentType.NUMBER
|
||||
return self
|
||||
|
||||
def element_type(self) -> "SegmentType | None":
|
||||
"""Return the element type of the current segment type, or `None` if the element type is undefined.
|
||||
|
||||
Raises:
|
||||
ValueError: If the current segment type is not an array type.
|
||||
|
||||
Note:
|
||||
For certain array types, such as `SegmentType.ARRAY_ANY`, their element types are not defined
|
||||
by the runtime system. In such cases, this method will return `None`.
|
||||
"""
|
||||
if not self.is_array_type():
|
||||
raise ValueError(f"element_type is only supported by array type, got {self}")
|
||||
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
|
||||
|
||||
|
||||
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
|
||||
# ARRAY_ANY does not have corresponding element type.
|
||||
@ -157,6 +209,7 @@ _ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
|
||||
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
|
||||
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
|
||||
SegmentType.ARRAY_FILE: SegmentType.FILE,
|
||||
SegmentType.ARRAY_BOOLEAN: SegmentType.BOOLEAN,
|
||||
}
|
||||
|
||||
_ARRAY_TYPES = frozenset(
|
||||
|
||||
@ -8,11 +8,13 @@ from core.helper import encrypter
|
||||
|
||||
from .segments import (
|
||||
ArrayAnySegment,
|
||||
ArrayBooleanSegment,
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArraySegment,
|
||||
ArrayStringSegment,
|
||||
BooleanSegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
@ -96,10 +98,18 @@ class FileVariable(FileSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class BooleanVariable(BooleanSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
|
||||
pass
|
||||
|
||||
|
||||
class RAGPipelineVariable(BaseModel):
|
||||
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
|
||||
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
|
||||
@ -143,11 +153,13 @@ VariableUnion: TypeAlias = Annotated[
|
||||
| Annotated[IntegerVariable, Tag(SegmentType.INTEGER)]
|
||||
| Annotated[ObjectVariable, Tag(SegmentType.OBJECT)]
|
||||
| Annotated[FileVariable, Tag(SegmentType.FILE)]
|
||||
| Annotated[BooleanVariable, Tag(SegmentType.BOOLEAN)]
|
||||
| Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)]
|
||||
| Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)]
|
||||
| Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)]
|
||||
| Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
|
||||
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
|
||||
| Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
|
||||
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
|
||||
),
|
||||
Discriminator(get_segment_discriminator),
|
||||
|
||||
@ -8,6 +8,7 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
@ -119,6 +120,14 @@ class CodeNode(BaseNode):
|
||||
|
||||
return value.replace("\x00", "")
|
||||
|
||||
def _check_boolean(self, value: bool | None, variable: str) -> bool | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, bool):
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a boolean")
|
||||
|
||||
return value
|
||||
|
||||
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
|
||||
"""
|
||||
Check number
|
||||
@ -173,6 +182,8 @@ class CodeNode(BaseNode):
|
||||
prefix=f"{prefix}.{output_name}" if prefix else output_name,
|
||||
depth=depth + 1,
|
||||
)
|
||||
elif isinstance(output_value, bool):
|
||||
self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name)
|
||||
elif isinstance(output_value, int | float):
|
||||
self._check_number(
|
||||
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
|
||||
@ -232,7 +243,7 @@ class CodeNode(BaseNode):
|
||||
if output_name not in result:
|
||||
raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
|
||||
|
||||
if output_config.type == "object":
|
||||
if output_config.type == SegmentType.OBJECT:
|
||||
# check if output is object
|
||||
if not isinstance(result.get(output_name), dict):
|
||||
if result[output_name] is None:
|
||||
@ -249,18 +260,28 @@ class CodeNode(BaseNode):
|
||||
prefix=f"{prefix}.{output_name}",
|
||||
depth=depth + 1,
|
||||
)
|
||||
elif output_config.type == "number":
|
||||
elif output_config.type == SegmentType.NUMBER:
|
||||
# check if number available
|
||||
transformed_result[output_name] = self._check_number(
|
||||
value=result[output_name], variable=f"{prefix}{dot}{output_name}"
|
||||
)
|
||||
elif output_config.type == "string":
|
||||
checked = self._check_number(value=result[output_name], variable=f"{prefix}{dot}{output_name}")
|
||||
# If the output is a boolean and the output schema specifies a NUMBER type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use
|
||||
# `True` and `False` as values for NUMBER type outputs.
|
||||
transformed_result[output_name] = self._convert_boolean_to_int(checked)
|
||||
|
||||
elif output_config.type == SegmentType.STRING:
|
||||
# check if string available
|
||||
transformed_result[output_name] = self._check_string(
|
||||
value=result[output_name],
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == "array[number]":
|
||||
elif output_config.type == SegmentType.BOOLEAN:
|
||||
transformed_result[output_name] = self._check_boolean(
|
||||
value=result[output_name],
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == SegmentType.ARRAY_NUMBER:
|
||||
# check if array of number available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
@ -278,10 +299,17 @@ class CodeNode(BaseNode):
|
||||
)
|
||||
|
||||
transformed_result[output_name] = [
|
||||
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
# If the element is a boolean and the output schema specifies a `array[number]` type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use
|
||||
# `True` and `False` as values for NUMBER type outputs.
|
||||
self._convert_boolean_to_int(
|
||||
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]"),
|
||||
)
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == "array[string]":
|
||||
elif output_config.type == SegmentType.ARRAY_STRING:
|
||||
# check if array of string available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
@ -302,7 +330,7 @@ class CodeNode(BaseNode):
|
||||
self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == "array[object]":
|
||||
elif output_config.type == SegmentType.ARRAY_OBJECT:
|
||||
# check if array of object available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
@ -340,6 +368,22 @@ class CodeNode(BaseNode):
|
||||
)
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == SegmentType.ARRAY_BOOLEAN:
|
||||
# check if array of object available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
transformed_result[output_name] = [
|
||||
self._check_boolean(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
|
||||
else:
|
||||
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
|
||||
|
||||
@ -374,3 +418,16 @@ class CodeNode(BaseNode):
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
|
||||
@staticmethod
|
||||
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:
|
||||
"""This function convert boolean to integers when the output schema specifies a NUMBER type.
|
||||
|
||||
This ensures compatibility with existing workflows that may use
|
||||
`True` and `False` as values for NUMBER type outputs.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
return value
|
||||
|
||||
@ -1,11 +1,31 @@
|
||||
from typing import Literal, Optional
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
|
||||
[
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _validate_type(segment_type: SegmentType) -> SegmentType:
|
||||
if segment_type not in _ALLOWED_OUTPUT_FROM_CODE:
|
||||
raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}")
|
||||
return segment_type
|
||||
|
||||
|
||||
class CodeNodeData(BaseNodeData):
|
||||
"""
|
||||
@ -13,7 +33,7 @@ class CodeNodeData(BaseNodeData):
|
||||
"""
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: Optional[dict[str, "CodeNodeData.Output"]] = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
|
||||
@ -1,36 +1,43 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
_Condition = Literal[
|
||||
|
||||
class FilterOperator(StrEnum):
|
||||
# string conditions
|
||||
"contains",
|
||||
"start with",
|
||||
"end with",
|
||||
"is",
|
||||
"in",
|
||||
"empty",
|
||||
"not contains",
|
||||
"is not",
|
||||
"not in",
|
||||
"not empty",
|
||||
CONTAINS = "contains"
|
||||
START_WITH = "start with"
|
||||
END_WITH = "end with"
|
||||
IS = "is"
|
||||
IN = "in"
|
||||
EMPTY = "empty"
|
||||
NOT_CONTAINS = "not contains"
|
||||
IS_NOT = "is not"
|
||||
NOT_IN = "not in"
|
||||
NOT_EMPTY = "not empty"
|
||||
# number conditions
|
||||
"=",
|
||||
"≠",
|
||||
"<",
|
||||
">",
|
||||
"≥",
|
||||
"≤",
|
||||
]
|
||||
EQUAL = "="
|
||||
NOT_EQUAL = "≠"
|
||||
LESS_THAN = "<"
|
||||
GREATER_THAN = ">"
|
||||
GREATER_THAN_OR_EQUAL = "≥"
|
||||
LESS_THAN_OR_EQUAL = "≤"
|
||||
|
||||
|
||||
class Order(StrEnum):
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
|
||||
class FilterCondition(BaseModel):
|
||||
key: str = ""
|
||||
comparison_operator: _Condition = "contains"
|
||||
value: str | Sequence[str] = ""
|
||||
comparison_operator: FilterOperator = FilterOperator.CONTAINS
|
||||
# the value is bool if the filter operator is comparing with
|
||||
# a boolean constant.
|
||||
value: str | Sequence[str] | bool = ""
|
||||
|
||||
|
||||
class FilterBy(BaseModel):
|
||||
@ -38,10 +45,10 @@ class FilterBy(BaseModel):
|
||||
conditions: Sequence[FilterCondition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OrderBy(BaseModel):
|
||||
class OrderByConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
key: str = ""
|
||||
value: Literal["asc", "desc"] = "asc"
|
||||
value: Order = Order.ASC
|
||||
|
||||
|
||||
class Limit(BaseModel):
|
||||
@ -57,6 +64,6 @@ class ExtractConfig(BaseModel):
|
||||
class ListOperatorNodeData(BaseNodeData):
|
||||
variable: Sequence[str] = Field(default_factory=list)
|
||||
filter_by: FilterBy
|
||||
order_by: OrderBy
|
||||
order_by: OrderByConfig
|
||||
limit: Limit
|
||||
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)
|
||||
|
||||
@ -1,18 +1,40 @@
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Optional, TypeAlias, TypeVar
|
||||
|
||||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
|
||||
from .entities import ListOperatorNodeData
|
||||
from .entities import FilterOperator, ListOperatorNodeData, Order
|
||||
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
|
||||
|
||||
_SUPPORTED_TYPES_TUPLE = (
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayStringSegment,
|
||||
ArrayBooleanSegment,
|
||||
)
|
||||
_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
|
||||
"""Returns the negation of a given filter function. If the original filter
|
||||
returns `True` for a value, the negated filter will return `False`, and vice versa.
|
||||
"""
|
||||
|
||||
def wrapper(value: _T) -> bool:
|
||||
return not filter_(value)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ListOperatorNode(BaseNode):
|
||||
_node_type = NodeType.LIST_OPERATOR
|
||||
@ -69,11 +91,8 @@ class ListOperatorNode(BaseNode):
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
)
|
||||
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
|
||||
error_message = (
|
||||
f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
|
||||
"or ArrayStringSegment"
|
||||
)
|
||||
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
|
||||
error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}"
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
|
||||
)
|
||||
@ -122,9 +141,7 @@ class ListOperatorNode(BaseNode):
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
def _apply_filter(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
filter_func: Callable[[Any], bool]
|
||||
result: list[Any] = []
|
||||
for condition in self._node_data.filter_by.conditions:
|
||||
@ -154,33 +171,35 @@ class ListOperatorNode(BaseNode):
|
||||
)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayBooleanSegment):
|
||||
if not isinstance(condition.value, bool):
|
||||
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
|
||||
filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
raise AssertionError("this statment should be unreachable.")
|
||||
return variable
|
||||
|
||||
def _apply_order(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
if isinstance(variable, ArrayStringSegment):
|
||||
result = _order_string(order=self._node_data.order_by.value, array=variable.value)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayNumberSegment):
|
||||
result = _order_number(order=self._node_data.order_by.value, array=variable.value)
|
||||
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
|
||||
result = sorted(variable.value, reverse=self._node_data.order_by == Order.DESC)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
result = _order_file(
|
||||
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
|
||||
)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable")
|
||||
|
||||
return variable
|
||||
|
||||
def _apply_slice(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
result = variable.value[: self._node_data.limit.size]
|
||||
return variable.model_copy(update={"value": result})
|
||||
|
||||
def _extract_slice(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
|
||||
if value < 1:
|
||||
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
|
||||
@ -232,11 +251,11 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
|
||||
case "empty":
|
||||
return lambda x: x == ""
|
||||
case "not contains":
|
||||
return lambda x: not _contains(value)(x)
|
||||
return _negation(_contains(value))
|
||||
case "is not":
|
||||
return lambda x: not _is(value)(x)
|
||||
return _negation(_is(value))
|
||||
case "not in":
|
||||
return lambda x: not _in(value)(x)
|
||||
return _negation(_in(value))
|
||||
case "not empty":
|
||||
return lambda x: x != ""
|
||||
case _:
|
||||
@ -248,7 +267,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
|
||||
case "in":
|
||||
return _in(value)
|
||||
case "not in":
|
||||
return lambda x: not _in(value)(x)
|
||||
return _negation(_in(value))
|
||||
case _:
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
@ -271,6 +290,16 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]:
|
||||
match condition:
|
||||
case FilterOperator.IS:
|
||||
return _is(value)
|
||||
case FilterOperator.IS_NOT:
|
||||
return _negation(_is(value))
|
||||
case _:
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
|
||||
extract_func: Callable[[File], Any]
|
||||
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
|
||||
@ -298,7 +327,7 @@ def _endswith(value: str) -> Callable[[str], bool]:
|
||||
return lambda x: x.endswith(value)
|
||||
|
||||
|
||||
def _is(value: str) -> Callable[[str], bool]:
|
||||
def _is(value: _T) -> Callable[[_T], bool]:
|
||||
return lambda x: x == value
|
||||
|
||||
|
||||
@ -330,21 +359,13 @@ def _ge(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x >= value
|
||||
|
||||
|
||||
def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]):
|
||||
return sorted(array, key=lambda x: x, reverse=order == "desc")
|
||||
|
||||
|
||||
def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
|
||||
return sorted(array, key=lambda x: x, reverse=order == "desc")
|
||||
|
||||
|
||||
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
|
||||
def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]):
|
||||
extract_func: Callable[[File], Any]
|
||||
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
|
||||
extract_func = _get_file_extract_string_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
|
||||
elif order_by == "size":
|
||||
extract_func = _get_file_extract_number_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
|
||||
else:
|
||||
raise InvalidKeyError(f"Invalid order key: {order_by}")
|
||||
|
||||
@ -3,7 +3,7 @@ import io
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import FileType, file_manager
|
||||
@ -55,7 +55,6 @@ from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
@ -90,6 +89,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -161,7 +161,7 @@ class LLMNode(BaseNode):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
def _run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
node_inputs: Optional[dict[str, Any]] = None
|
||||
process_data = None
|
||||
result_text = ""
|
||||
@ -737,7 +737,7 @@ class LLMNode(BaseNode):
|
||||
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||
and isinstance(prompt_messages[-1].content, list)
|
||||
):
|
||||
prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
|
||||
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
|
||||
@ -12,9 +12,11 @@ _VALID_VAR_TYPE = frozenset(
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@ -404,11 +404,11 @@ class LoopNode(BaseNode):
|
||||
for node_id in loop_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
_outputs = {}
|
||||
_outputs: dict[str, Segment | int | None] = {}
|
||||
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
|
||||
_loop_variable_segment = variable_pool.get(loop_variable_selector)
|
||||
if _loop_variable_segment:
|
||||
_outputs[loop_variable_key] = _loop_variable_segment.value
|
||||
_outputs[loop_variable_key] = _loop_variable_segment
|
||||
else:
|
||||
_outputs[loop_variable_key] = None
|
||||
|
||||
@ -522,21 +522,30 @@ class LoopNode(BaseNode):
|
||||
return variable_mapping
|
||||
|
||||
@staticmethod
|
||||
def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
|
||||
"""Get the appropriate segment type for a constant value."""
|
||||
if var_type in ["array[string]", "array[number]", "array[object]"]:
|
||||
if value and isinstance(value, str):
|
||||
value = json.loads(value)
|
||||
if var_type in [
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
]:
|
||||
if original_value and isinstance(original_value, str):
|
||||
value = json.loads(original_value)
|
||||
else:
|
||||
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
|
||||
value = []
|
||||
elif var_type == SegmentType.ARRAY_BOOLEAN:
|
||||
value = original_value
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
try:
|
||||
return build_segment_with_type(var_type, value)
|
||||
return build_segment_with_type(var_type, value=value)
|
||||
except TypeMismatchError as type_exc:
|
||||
# Attempt to parse the value as a JSON-encoded string, if applicable.
|
||||
if not isinstance(value, str):
|
||||
if not isinstance(original_value, str):
|
||||
raise
|
||||
try:
|
||||
value = json.loads(value)
|
||||
value = json.loads(original_value)
|
||||
except ValueError:
|
||||
raise type_exc
|
||||
return build_segment_with_type(var_type, value)
|
||||
|
||||
@ -1,10 +1,46 @@
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Annotated, Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
Field,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm import ModelConfig, VisionConfig
|
||||
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
|
||||
_OLD_BOOL_TYPE_NAME = "bool"
|
||||
_OLD_SELECT_TYPE_NAME = "select"
|
||||
|
||||
_VALID_PARAMETER_TYPES = frozenset(
|
||||
[
|
||||
SegmentType.STRING, # "string",
|
||||
SegmentType.NUMBER, # "number",
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
_OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node
|
||||
_OLD_SELECT_TYPE_NAME, # string type with enumeration choices.
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _validate_type(parameter_type: str) -> SegmentType:
|
||||
if not isinstance(parameter_type, str):
|
||||
raise TypeError(f"type should be str, got {type(parameter_type)}, value={parameter_type}")
|
||||
if parameter_type not in _VALID_PARAMETER_TYPES:
|
||||
raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.")
|
||||
|
||||
if parameter_type == _OLD_BOOL_TYPE_NAME:
|
||||
return SegmentType.BOOLEAN
|
||||
elif parameter_type == _OLD_SELECT_TYPE_NAME:
|
||||
return SegmentType.STRING
|
||||
return SegmentType(parameter_type)
|
||||
|
||||
|
||||
class _ParameterConfigError(Exception):
|
||||
@ -17,7 +53,7 @@ class ParameterConfig(BaseModel):
|
||||
"""
|
||||
|
||||
name: str
|
||||
type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"]
|
||||
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
|
||||
options: Optional[list[str]] = None
|
||||
description: str
|
||||
required: bool
|
||||
@ -32,17 +68,20 @@ class ParameterConfig(BaseModel):
|
||||
return str(value)
|
||||
|
||||
def is_array_type(self) -> bool:
|
||||
return self.type in ("array[string]", "array[number]", "array[object]")
|
||||
return self.type.is_array_type()
|
||||
|
||||
def element_type(self) -> Literal["string", "number", "object"]:
|
||||
if self.type == "array[number]":
|
||||
return "number"
|
||||
elif self.type == "array[string]":
|
||||
return "string"
|
||||
elif self.type == "array[object]":
|
||||
return "object"
|
||||
else:
|
||||
raise _ParameterConfigError(f"{self.type} is not array type.")
|
||||
def element_type(self) -> SegmentType:
|
||||
"""Return the element type of the parameter.
|
||||
|
||||
Raises a ValueError if the parameter's type is not an array type.
|
||||
"""
|
||||
element_type = self.type.element_type()
|
||||
# At this point, self.type is guaranteed to be one of `ARRAY_STRING`,
|
||||
# `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`.
|
||||
#
|
||||
# See: _VALID_PARAMETER_TYPES for reference.
|
||||
assert element_type is not None, f"the element type should not be None, {self.type=}"
|
||||
return element_type
|
||||
|
||||
|
||||
class ParameterExtractorNodeData(BaseNodeData):
|
||||
@ -74,16 +113,18 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
for parameter in self.parameters:
|
||||
parameter_schema: dict[str, Any] = {"description": parameter.description}
|
||||
|
||||
if parameter.type in {"string", "select"}:
|
||||
if parameter.type == SegmentType.STRING:
|
||||
parameter_schema["type"] = "string"
|
||||
elif parameter.type.startswith("array"):
|
||||
elif parameter.type.is_array_type():
|
||||
parameter_schema["type"] = "array"
|
||||
nested_type = parameter.type[6:-1]
|
||||
parameter_schema["items"] = {"type": nested_type}
|
||||
element_type = parameter.type.element_type()
|
||||
if element_type is None:
|
||||
raise AssertionError("element type should not be None.")
|
||||
parameter_schema["items"] = {"type": element_type.value}
|
||||
else:
|
||||
parameter_schema["type"] = parameter.type
|
||||
|
||||
if parameter.type == "select":
|
||||
if parameter.options:
|
||||
parameter_schema["enum"] = parameter.options
|
||||
|
||||
parameters["properties"][parameter.name] = parameter_schema
|
||||
|
||||
@ -1,3 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from core.variables.types import SegmentType
|
||||
|
||||
|
||||
class ParameterExtractorNodeError(ValueError):
|
||||
"""Base error for ParameterExtractorNode."""
|
||||
|
||||
@ -48,3 +53,23 @@ class InvalidArrayValueError(ParameterExtractorNodeError):
|
||||
|
||||
class InvalidModelModeError(ParameterExtractorNodeError):
|
||||
"""Raised when the model mode is invalid."""
|
||||
|
||||
|
||||
class InvalidValueTypeError(ParameterExtractorNodeError):
|
||||
def __init__(
|
||||
self,
|
||||
/,
|
||||
parameter_name: str,
|
||||
expected_type: SegmentType,
|
||||
actual_type: SegmentType | None,
|
||||
value: Any,
|
||||
) -> None:
|
||||
message = (
|
||||
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
|
||||
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"
|
||||
)
|
||||
super().__init__(message)
|
||||
self.parameter_name = parameter_name
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
||||
self.value = value
|
||||
|
||||
@ -26,7 +26,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.variables.types import SegmentType
|
||||
from core.variables.types import ArrayValidation, SegmentType
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
@ -39,16 +39,13 @@ from factories.variable_factory import build_segment_with_type
|
||||
|
||||
from .entities import ParameterExtractorNodeData
|
||||
from .exc import (
|
||||
InvalidArrayValueError,
|
||||
InvalidBoolValueError,
|
||||
InvalidInvokeResultError,
|
||||
InvalidModelModeError,
|
||||
InvalidModelTypeError,
|
||||
InvalidNumberOfParametersError,
|
||||
InvalidNumberValueError,
|
||||
InvalidSelectValueError,
|
||||
InvalidStringValueError,
|
||||
InvalidTextContentTypeError,
|
||||
InvalidValueTypeError,
|
||||
ModelSchemaNotFoundError,
|
||||
ParameterExtractorNodeError,
|
||||
RequiredParameterMissingError,
|
||||
@ -549,9 +546,6 @@ class ParameterExtractorNode(BaseNode):
|
||||
return prompt_messages
|
||||
|
||||
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
"""
|
||||
Validate result.
|
||||
"""
|
||||
if len(data.parameters) != len(result):
|
||||
raise InvalidNumberOfParametersError("Invalid number of parameters")
|
||||
|
||||
@ -559,101 +553,106 @@ class ParameterExtractorNode(BaseNode):
|
||||
if parameter.required and parameter.name not in result:
|
||||
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
|
||||
|
||||
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
|
||||
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
|
||||
raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
|
||||
raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
|
||||
raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type.startswith("array"):
|
||||
parameters = result.get(parameter.name)
|
||||
if not isinstance(parameters, list):
|
||||
raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
|
||||
nested_type = parameter.type[6:-1]
|
||||
for item in parameters:
|
||||
if nested_type == "number" and not isinstance(item, int | float):
|
||||
raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
|
||||
if nested_type == "string" and not isinstance(item, str):
|
||||
raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
|
||||
if nested_type == "object" and not isinstance(item, dict):
|
||||
raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
|
||||
param_value = result.get(parameter.name)
|
||||
if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL):
|
||||
inferred_type = SegmentType.infer_segment_type(param_value)
|
||||
raise InvalidValueTypeError(
|
||||
parameter_name=parameter.name,
|
||||
expected_type=parameter.type,
|
||||
actual_type=inferred_type,
|
||||
value=param_value,
|
||||
)
|
||||
if parameter.type == SegmentType.STRING and parameter.options:
|
||||
if param_value not in parameter.options:
|
||||
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _transform_number(value: int | float | str | bool) -> int | float | None:
|
||||
"""
|
||||
Attempts to transform the input into an integer or float.
|
||||
|
||||
Returns:
|
||||
int or float: The transformed number if the conversion is successful.
|
||||
None: If the transformation fails.
|
||||
|
||||
Note:
|
||||
Boolean values `True` and `False` are converted to integers `1` and `0`, respectively.
|
||||
This behavior ensures compatibility with existing workflows that may use boolean types as integers.
|
||||
"""
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
elif isinstance(value, (int, float)):
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
return None
|
||||
if "." in value:
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
"""
|
||||
Transform result into standard format.
|
||||
"""
|
||||
transformed_result = {}
|
||||
transformed_result: dict[str, Any] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.name in result:
|
||||
param_value = result[parameter.name]
|
||||
# transform value
|
||||
if parameter.type == "number":
|
||||
if isinstance(result[parameter.name], int | float):
|
||||
transformed_result[parameter.name] = result[parameter.name]
|
||||
elif isinstance(result[parameter.name], str):
|
||||
try:
|
||||
if "." in result[parameter.name]:
|
||||
result[parameter.name] = float(result[parameter.name])
|
||||
else:
|
||||
result[parameter.name] = int(result[parameter.name])
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
# TODO: bool is not supported in the current version
|
||||
# elif parameter.type == 'bool':
|
||||
# if isinstance(result[parameter.name], bool):
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
# elif isinstance(result[parameter.name], str):
|
||||
# if result[parameter.name].lower() in ['true', 'false']:
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true')
|
||||
# elif isinstance(result[parameter.name], int):
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
elif parameter.type in {"string", "select"}:
|
||||
if isinstance(result[parameter.name], str):
|
||||
transformed_result[parameter.name] = result[parameter.name]
|
||||
if parameter.type == SegmentType.NUMBER:
|
||||
transformed = self._transform_number(param_value)
|
||||
if transformed is not None:
|
||||
transformed_result[parameter.name] = transformed
|
||||
elif parameter.type == SegmentType.BOOLEAN:
|
||||
if isinstance(result[parameter.name], (bool, int)):
|
||||
transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
# elif isinstance(result[parameter.name], str):
|
||||
# if result[parameter.name].lower() in ["true", "false"]:
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true")
|
||||
elif parameter.type == SegmentType.STRING:
|
||||
if isinstance(param_value, str):
|
||||
transformed_result[parameter.name] = param_value
|
||||
elif parameter.is_array_type():
|
||||
if isinstance(result[parameter.name], list):
|
||||
if isinstance(param_value, list):
|
||||
nested_type = parameter.element_type()
|
||||
assert nested_type is not None
|
||||
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
|
||||
transformed_result[parameter.name] = segment_value
|
||||
for item in result[parameter.name]:
|
||||
if nested_type == "number":
|
||||
if isinstance(item, int | float):
|
||||
segment_value.value.append(item)
|
||||
elif isinstance(item, str):
|
||||
try:
|
||||
if "." in item:
|
||||
segment_value.value.append(float(item))
|
||||
else:
|
||||
segment_value.value.append(int(item))
|
||||
except ValueError:
|
||||
pass
|
||||
elif nested_type == "string":
|
||||
for item in param_value:
|
||||
if nested_type == SegmentType.NUMBER:
|
||||
transformed = self._transform_number(item)
|
||||
if transformed is not None:
|
||||
segment_value.value.append(transformed)
|
||||
elif nested_type == SegmentType.STRING:
|
||||
if isinstance(item, str):
|
||||
segment_value.value.append(item)
|
||||
elif nested_type == "object":
|
||||
elif nested_type == SegmentType.OBJECT:
|
||||
if isinstance(item, dict):
|
||||
segment_value.value.append(item)
|
||||
elif nested_type == SegmentType.BOOLEAN:
|
||||
if isinstance(item, bool):
|
||||
segment_value.value.append(item)
|
||||
|
||||
if parameter.name not in transformed_result:
|
||||
if parameter.type == "number":
|
||||
transformed_result[parameter.name] = 0
|
||||
elif parameter.type == "bool":
|
||||
transformed_result[parameter.name] = False
|
||||
elif parameter.type in {"string", "select"}:
|
||||
transformed_result[parameter.name] = ""
|
||||
elif parameter.type.startswith("array"):
|
||||
if parameter.type.is_array_type():
|
||||
transformed_result[parameter.name] = build_segment_with_type(
|
||||
segment_type=SegmentType(parameter.type), value=[]
|
||||
)
|
||||
elif parameter.type in (SegmentType.STRING, SegmentType.SECRET):
|
||||
transformed_result[parameter.name] = ""
|
||||
elif parameter.type == SegmentType.NUMBER:
|
||||
transformed_result[parameter.name] = 0
|
||||
elif parameter.type == SegmentType.BOOLEAN:
|
||||
transformed_result[parameter.name] = False
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
return transformed_result
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables.segments import BooleanSegment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
@ -158,8 +159,8 @@ class VariableAssignerNode(BaseNode):
|
||||
def get_zero_value(t: SegmentType):
|
||||
# TODO(QuantumGhost): this should be a method of `SegmentType`.
|
||||
match t:
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
||||
return variable_factory.build_segment([])
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN:
|
||||
return variable_factory.build_segment_with_type(t, [])
|
||||
case SegmentType.OBJECT:
|
||||
return variable_factory.build_segment({})
|
||||
case SegmentType.STRING:
|
||||
@ -170,5 +171,7 @@ def get_zero_value(t: SegmentType):
|
||||
return variable_factory.build_segment(0.0)
|
||||
case SegmentType.NUMBER:
|
||||
return variable_factory.build_segment(0)
|
||||
case SegmentType.BOOLEAN:
|
||||
return BooleanSegment(value=False)
|
||||
case _:
|
||||
raise VariableOperatorNodeError(f"unsupported variable type: {t}")
|
||||
|
||||
@ -4,9 +4,11 @@ from core.variables import SegmentType
|
||||
EMPTY_VALUE_MAPPING = {
|
||||
SegmentType.STRING: "",
|
||||
SegmentType.NUMBER: 0,
|
||||
SegmentType.BOOLEAN: False,
|
||||
SegmentType.OBJECT: {},
|
||||
SegmentType.ARRAY_ANY: [],
|
||||
SegmentType.ARRAY_STRING: [],
|
||||
SegmentType.ARRAY_NUMBER: [],
|
||||
SegmentType.ARRAY_OBJECT: [],
|
||||
SegmentType.ARRAY_BOOLEAN: [],
|
||||
}
|
||||
|
||||
@ -16,28 +16,15 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.INTEGER,
|
||||
SegmentType.FLOAT,
|
||||
SegmentType.BOOLEAN,
|
||||
}
|
||||
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
|
||||
# Only number variable can be added, subtracted, multiplied or divided
|
||||
return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
|
||||
case Operation.APPEND | Operation.EXTEND:
|
||||
case Operation.APPEND | Operation.EXTEND | Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
|
||||
# Only array variable can be appended or extended
|
||||
return variable_type in {
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_FILE,
|
||||
}
|
||||
case Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
|
||||
# Only array variable can have elements removed
|
||||
return variable_type in {
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_FILE,
|
||||
}
|
||||
return variable_type.is_array_type()
|
||||
case _:
|
||||
return False
|
||||
|
||||
@ -50,7 +37,7 @@ def is_variable_input_supported(*, operation: Operation):
|
||||
|
||||
def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
match variable_type:
|
||||
case SegmentType.STRING | SegmentType.OBJECT:
|
||||
case SegmentType.STRING | SegmentType.OBJECT | SegmentType.BOOLEAN:
|
||||
return operation in {Operation.OVER_WRITE, Operation.SET}
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||
return operation in {
|
||||
@ -72,6 +59,9 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
||||
case SegmentType.STRING:
|
||||
return isinstance(value, str)
|
||||
|
||||
case SegmentType.BOOLEAN:
|
||||
return isinstance(value, bool)
|
||||
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||
if not isinstance(value, int | float):
|
||||
return False
|
||||
@ -91,6 +81,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
||||
return isinstance(value, int | float)
|
||||
case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND:
|
||||
return isinstance(value, dict)
|
||||
case SegmentType.ARRAY_BOOLEAN if operation == Operation.APPEND:
|
||||
return isinstance(value, bool)
|
||||
|
||||
# Array & Extend / Overwrite
|
||||
case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
@ -101,6 +93,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
||||
return isinstance(value, list) and all(isinstance(item, int | float) for item in value)
|
||||
case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, dict) for item in value)
|
||||
case SegmentType.ARRAY_BOOLEAN if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, bool) for item in value)
|
||||
|
||||
case _:
|
||||
return False
|
||||
|
||||
@ -45,5 +45,5 @@ class SubVariableCondition(BaseModel):
|
||||
class Condition(BaseModel):
|
||||
variable_selector: list[str]
|
||||
comparison_operator: SupportedComparisonOperator
|
||||
value: str | Sequence[str] | None = None
|
||||
value: str | Sequence[str] | bool | None = None
|
||||
sub_variable_condition: SubVariableCondition | None = None
|
||||
|
||||
@ -1,13 +1,27 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from core.file import FileAttribute, file_manager
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import ArrayBooleanSegment, BooleanSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
from .entities import Condition, SubCondition, SupportedComparisonOperator
|
||||
|
||||
|
||||
def _convert_to_bool(value: Any) -> bool:
|
||||
if isinstance(value, int):
|
||||
return bool(value)
|
||||
|
||||
if isinstance(value, str):
|
||||
loaded = json.loads(value)
|
||||
if isinstance(loaded, (int, bool)):
|
||||
return bool(loaded)
|
||||
|
||||
raise TypeError(f"unexpected value: type={type(value)}, value={value}")
|
||||
|
||||
|
||||
class ConditionProcessor:
|
||||
def process_conditions(
|
||||
self,
|
||||
@ -48,9 +62,16 @@ class ConditionProcessor:
|
||||
)
|
||||
else:
|
||||
actual_value = variable.value if variable else None
|
||||
expected_value = condition.value
|
||||
expected_value: str | Sequence[str] | bool | list[bool] | None = condition.value
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = variable_pool.convert_template(expected_value).text
|
||||
# Here we need to explicit convet the input string to boolean.
|
||||
if isinstance(variable, (BooleanSegment, ArrayBooleanSegment)) and expected_value is not None:
|
||||
# The following two lines is for compatibility with existing workflows.
|
||||
if isinstance(expected_value, list):
|
||||
expected_value = [_convert_to_bool(i) for i in expected_value]
|
||||
else:
|
||||
expected_value = _convert_to_bool(expected_value)
|
||||
input_conditions.append(
|
||||
{
|
||||
"actual_value": actual_value,
|
||||
@ -77,7 +98,7 @@ def _evaluate_condition(
|
||||
*,
|
||||
operator: SupportedComparisonOperator,
|
||||
value: Any,
|
||||
expected: str | Sequence[str] | None,
|
||||
expected: Union[str, Sequence[str], bool | Sequence[bool], None],
|
||||
) -> bool:
|
||||
match operator:
|
||||
case "contains":
|
||||
@ -130,7 +151,7 @@ def _assert_contains(*, value: Any, expected: Any) -> bool:
|
||||
if not value:
|
||||
return False
|
||||
|
||||
if not isinstance(value, str | list):
|
||||
if not isinstance(value, (str, list)):
|
||||
raise ValueError("Invalid actual value type: string or array")
|
||||
|
||||
if expected not in value:
|
||||
@ -142,7 +163,7 @@ def _assert_not_contains(*, value: Any, expected: Any) -> bool:
|
||||
if not value:
|
||||
return True
|
||||
|
||||
if not isinstance(value, str | list):
|
||||
if not isinstance(value, (str, list)):
|
||||
raise ValueError("Invalid actual value type: string or array")
|
||||
|
||||
if expected in value:
|
||||
@ -178,8 +199,8 @@ def _assert_is(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("Invalid actual value type: string")
|
||||
if not isinstance(value, (str, bool)):
|
||||
raise ValueError("Invalid actual value type: string or boolean")
|
||||
|
||||
if value != expected:
|
||||
return False
|
||||
@ -190,8 +211,8 @@ def _assert_is_not(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("Invalid actual value type: string")
|
||||
if not isinstance(value, (str, bool)):
|
||||
raise ValueError("Invalid actual value type: string or boolean")
|
||||
|
||||
if value == expected:
|
||||
return False
|
||||
@ -214,10 +235,13 @@ def _assert_equal(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
if not isinstance(value, (int, float, bool)):
|
||||
raise ValueError("Invalid actual value type: number or boolean")
|
||||
|
||||
if isinstance(value, int):
|
||||
# Handle boolean comparison
|
||||
if isinstance(value, bool):
|
||||
expected = bool(expected)
|
||||
elif isinstance(value, int):
|
||||
expected = int(expected)
|
||||
else:
|
||||
expected = float(expected)
|
||||
@ -231,10 +255,13 @@ def _assert_not_equal(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
if not isinstance(value, (int, float, bool)):
|
||||
raise ValueError("Invalid actual value type: number or boolean")
|
||||
|
||||
if isinstance(value, int):
|
||||
# Handle boolean comparison
|
||||
if isinstance(value, bool):
|
||||
expected = bool(expected)
|
||||
elif isinstance(value, int):
|
||||
expected = int(expected)
|
||||
else:
|
||||
expected = float(expected)
|
||||
@ -248,7 +275,7 @@ def _assert_greater_than(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
@ -265,7 +292,7 @@ def _assert_less_than(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
@ -282,7 +309,7 @@ def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
@ -299,7 +326,7 @@ def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
|
||||
Reference in New Issue
Block a user