Merge main

This commit is contained in:
Yeuoly
2024-09-14 02:47:01 +08:00
959 changed files with 25695 additions and 24057 deletions

View File

@ -10,45 +10,46 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
class ToolEntity(BaseModel):
provider_id: str
provider_type: ToolProviderType
provider_name: str # redundancy
provider_name: str # redundancy
tool_name: str
tool_label: str # redundancy
tool_label: str # redundancy
tool_configurations: dict[str, Any]
@field_validator('tool_configurations', mode='before')
@field_validator("tool_configurations", mode="before")
@classmethod
def validate_tool_configurations(cls, value, values: ValidationInfo):
if not isinstance(value, dict):
raise ValueError('tool_configurations must be a dictionary')
for key in values.data.get('tool_configurations', {}).keys():
value = values.data.get('tool_configurations', {}).get(key)
raise ValueError("tool_configurations must be a dictionary")
for key in values.data.get("tool_configurations", {}):
value = values.data.get("tool_configurations", {}).get(key)
if not isinstance(value, str | int | float | bool):
raise ValueError(f'{key} must be a string')
raise ValueError(f"{key} must be a string")
return value
class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Literal['mixed', 'variable', 'constant']
type: Literal["mixed", "variable", "constant"]
@field_validator('type', mode='before')
@field_validator("type", mode="before")
@classmethod
def check_type(cls, value, validation_info: ValidationInfo):
typ = value
value = validation_info.data.get('value')
if typ == 'mixed' and not isinstance(value, str):
raise ValueError('value must be a string')
elif typ == 'variable':
value = validation_info.data.get("value")
if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "variable":
if not isinstance(value, list):
raise ValueError('value must be a list')
raise ValueError("value must be a list")
for val in value:
if not isinstance(val, str):
raise ValueError('value must be a list of strings')
elif typ == 'constant' and not isinstance(value, str | int | float | bool):
raise ValueError('value must be a string, int, float, or bool')
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, str | int | float | bool):
raise ValueError("value must be a string, int, float, or bool")
return typ
"""

View File

@ -35,10 +35,7 @@ class ToolNode(BaseNode):
node_data = cast(ToolNodeData, self.node_data)
# fetch tool icon
tool_info = {
'provider_type': node_data.provider_type.value,
'provider_id': node_data.provider_id
}
tool_info = {"provider_type": node_data.provider_type.value, "provider_id": node_data.provider_id}
# get tool runtime
try:
@ -50,10 +47,8 @@ class ToolNode(BaseNode):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
error=f'Failed to get tool runtime: {str(e)}'
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to get tool runtime: {str(e)}",
)
)
return
@ -61,15 +56,13 @@ class ToolNode(BaseNode):
# get parameters
tool_parameters = tool_runtime.get_runtime_parameters() or []
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data
tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data,
for_log=True
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data,
for_log=True,
)
try:
@ -86,10 +79,8 @@ class ToolNode(BaseNode):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
error=f'Failed to invoke tool: {str(e)}',
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool: {str(e)}",
)
)
return
@ -126,12 +117,10 @@ class ToolNode(BaseNode):
result[parameter_name] = None
continue
if parameter.type == ToolParameter.ToolParameterType.FILE:
result[parameter_name] = [
v.to_dict() for v in self._fetch_files(variable_pool)
]
result[parameter_name] = [v.to_dict() for v in self._fetch_files(variable_pool)]
else:
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == 'variable':
if tool_input.type == "variable":
parameter_value_segment = variable_pool.get(tool_input.value)
if not parameter_value_segment:
raise Exception("input variable dose not exists")
@ -147,14 +136,16 @@ class ToolNode(BaseNode):
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
variable = variable_pool.get(['sys', SystemVariableKey.FILES.value])
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _transform_message(self,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any]) -> Generator[RunEvent, None, None]:
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
) -> Generator[RunEvent, None, None]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
@ -169,66 +160,65 @@ class ToolNode(BaseNode):
files: list[FileVar] = []
text = ""
json: list[dict] = []
variables: dict[str, Any] = {}
for message in message_stream:
if message.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
message.type == ToolInvokeMessage.MessageType.IMAGE:
if message.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
url = message.message.text
ext = path.splitext(url)[1]
mimetype = message.meta.get('mime_type', 'image/jpeg')
filename = message.save_as or url.split('/')[-1]
transfer_method = message.meta.get('transfer_method', FileTransferMethod.TOOL_FILE)
mimetype = message.meta.get("mime_type", "image/jpeg")
filename = message.save_as or url.split("/")[-1]
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
# get tool file id
tool_file_id = url.split('/')[-1].split('.')[0]
files.append(FileVar(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=transfer_method,
url=url,
related_id=tool_file_id,
filename=filename,
extension=ext,
mime_type=mimetype,
))
tool_file_id = url.split("/")[-1].split(".")[0]
files.append(
FileVar(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=transfer_method,
url=url,
related_id=tool_file_id,
filename=filename,
extension=ext,
mime_type=mimetype,
)
)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split('/')[-1].split('.')[0]
files.append(FileVar(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file_id,
filename=message.save_as,
extension=path.splitext(message.save_as)[1],
mime_type=message.meta.get('mime_type', 'application/octet-stream'),
))
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
files.append(
FileVar(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file_id,
filename=message.save_as,
extension=path.splitext(message.save_as)[1],
mime_type=message.meta.get("mime_type", "application/octet-stream"),
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text + '\n'
text += message.message.text + "\n"
yield RunStreamChunkEvent(
chunk_content=message.message.text,
from_variable_selector=[self.node_id, 'text']
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message, ToolInvokeMessage.JsonMessage)
json.append(message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f'Link: {message.message.text}\n'
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(
chunk_content=stream_text,
from_variable_selector=[self.node_id, 'text']
)
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
@ -241,8 +231,7 @@ class ToolNode(BaseNode):
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value,
from_variable_selector=[self.node_id, variable_name]
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
)
else:
variables[variable_name] = variable_value
@ -250,25 +239,15 @@ class ToolNode(BaseNode):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
'text': text,
'files': files,
'json': json,
**variables
},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
inputs=parameters_for_log
outputs={"text": text, "files": files, "json": json, **variables},
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
inputs=parameters_for_log,
)
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ToolNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: ToolNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -280,18 +259,16 @@ class ToolNode(BaseNode):
result = {}
for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name]
if input.type == 'mixed':
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == 'variable':
elif input.type == "variable":
result[parameter_name] = input.value
elif input.type == 'constant':
elif input.type == "constant":
pass
result = {
node_id + '.' + key: value for key, value in result.items()
}
result = {node_id + "." + key: value for key, value in result.items()}
return result