refactor: move workflow package to dify_graph (#32844)

Author: -LAN-
Date: 2026-03-02 18:42:30 +08:00 (committed by GitHub)
Parent: 9c33923985
Commit: c917838f9c
613 changed files with 2008 additions and 2012 deletions

View File: dify_graph/nodes/parameter_extractor/__init__.py

@@ -0,0 +1,3 @@
from .parameter_extractor_node import ParameterExtractorNode
__all__ = ["ParameterExtractorNode"]

View File: dify_graph/nodes/parameter_extractor/entities.py

@@ -0,0 +1,129 @@
from typing import Annotated, Any, Literal
from pydantic import (
BaseModel,
BeforeValidator,
Field,
field_validator,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.nodes.base import BaseNodeData
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
from dify_graph.variables.types import SegmentType
_OLD_BOOL_TYPE_NAME = "bool"
_OLD_SELECT_TYPE_NAME = "select"
_VALID_PARAMETER_TYPES = frozenset(
[
SegmentType.STRING, # "string",
SegmentType.NUMBER, # "number",
SegmentType.BOOLEAN,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_BOOLEAN,
_OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node
_OLD_SELECT_TYPE_NAME, # string type with enumeration choices.
]
)
def _validate_type(parameter_type: str) -> SegmentType:
if parameter_type not in _VALID_PARAMETER_TYPES:
raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.")
if parameter_type == _OLD_BOOL_TYPE_NAME:
return SegmentType.BOOLEAN
elif parameter_type == _OLD_SELECT_TYPE_NAME:
return SegmentType.STRING
return SegmentType(parameter_type)
class ParameterConfig(BaseModel):
"""
Parameter Config.
"""
name: str
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
options: list[str] | None = None
description: str
required: bool
@field_validator("name", mode="before")
@classmethod
def validate_name(cls, value) -> str:
if not value:
raise ValueError("Parameter name is required")
if value in {"__reason", "__is_success"}:
raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
return str(value)
def is_array_type(self) -> bool:
return self.type.is_array_type()
def element_type(self) -> SegmentType:
"""Return the element type of the parameter.
Raises a ValueError if the parameter's type is not an array type.
"""
element_type = self.type.element_type()
# At this point, self.type is guaranteed to be one of `ARRAY_STRING`,
# `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`.
#
# See: _VALID_PARAMETER_TYPES for reference.
assert element_type is not None, f"the element type should not be None, {self.type=}"
return element_type
class ParameterExtractorNodeData(BaseNodeData):
"""
Parameter Extractor Node Data.
"""
model: ModelConfig
query: list[str]
parameters: list[ParameterConfig]
instruction: str | None = None
memory: MemoryConfig | None = None
reasoning_mode: Literal["function_call", "prompt"]
vision: VisionConfig = Field(default_factory=VisionConfig)
@field_validator("reasoning_mode", mode="before")
@classmethod
def set_reasoning_mode(cls, v) -> str:
return v or "function_call"
def get_parameter_json_schema(self):
"""
Get parameter json schema.
:return: parameter json schema
"""
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
for parameter in self.parameters:
parameter_schema: dict[str, Any] = {"description": parameter.description}
if parameter.type == SegmentType.STRING:
parameter_schema["type"] = "string"
elif parameter.type.is_array_type():
parameter_schema["type"] = "array"
element_type = parameter.type.element_type()
if element_type is None:
raise AssertionError("element type should not be None.")
parameter_schema["items"] = {"type": element_type.value}
else:
parameter_schema["type"] = parameter.type
if parameter.options:
parameter_schema["enum"] = parameter.options
parameters["properties"][parameter.name] = parameter_schema
if parameter.required:
parameters["required"].append(parameter.name)
return parameters
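
A minimal sketch of how ParameterConfig normalizes the legacy type names through _validate_type (illustrative values; behavior follows the mapping above):

legacy_bool = ParameterConfig(name="flag", type="bool", description="A flag", required=False)
assert legacy_bool.type == SegmentType.BOOLEAN  # legacy "bool" is upgraded to BOOLEAN

legacy_select = ParameterConfig(
    name="size",
    type="select",
    options=["S", "M", "L"],
    description="T-shirt size",
    required=True,
)
assert legacy_select.type == SegmentType.STRING  # "select" collapses to STRING plus options

# Reserved output keys are rejected by validate_name:
#   ParameterConfig(name="__reason", ...) raises ValueError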

View File: dify_graph/nodes/parameter_extractor/exc.py

@@ -0,0 +1,75 @@
from typing import Any
from dify_graph.variables.types import SegmentType
class ParameterExtractorNodeError(ValueError):
"""Base error for ParameterExtractorNode."""
class InvalidModelTypeError(ParameterExtractorNodeError):
"""Raised when the model is not a Large Language Model."""
class ModelSchemaNotFoundError(ParameterExtractorNodeError):
"""Raised when the model schema is not found."""
class InvalidInvokeResultError(ParameterExtractorNodeError):
"""Raised when the invoke result is invalid."""
class InvalidTextContentTypeError(ParameterExtractorNodeError):
"""Raised when the text content type is invalid."""
class InvalidNumberOfParametersError(ParameterExtractorNodeError):
"""Raised when the number of parameters is invalid."""
class RequiredParameterMissingError(ParameterExtractorNodeError):
"""Raised when a required parameter is missing."""
class InvalidSelectValueError(ParameterExtractorNodeError):
"""Raised when a select value is invalid."""
class InvalidNumberValueError(ParameterExtractorNodeError):
"""Raised when a number value is invalid."""
class InvalidBoolValueError(ParameterExtractorNodeError):
"""Raised when a bool value is invalid."""
class InvalidStringValueError(ParameterExtractorNodeError):
"""Raised when a string value is invalid."""
class InvalidArrayValueError(ParameterExtractorNodeError):
"""Raised when an array value is invalid."""
class InvalidModelModeError(ParameterExtractorNodeError):
"""Raised when the model mode is invalid."""
class InvalidValueTypeError(ParameterExtractorNodeError):
def __init__(
self,
/,
parameter_name: str,
expected_type: SegmentType,
actual_type: SegmentType | None,
value: Any,
):
message = (
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"
)
super().__init__(message)
self.parameter_name = parameter_name
self.expected_type = expected_type
self.actual_type = actual_type
self.value = value
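
Because InvalidValueTypeError keeps its arguments as attributes, callers can inspect a validation failure programmatically; an illustrative construction:

err = InvalidValueTypeError(
    parameter_name="age",
    expected_type=SegmentType.NUMBER,
    actual_type=SegmentType.STRING,
    value="forty-two",
)
assert err.parameter_name == "age"
# str(err) reads "Invalid value for parameter age, expected segment type: ..."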

View File: dify_graph/nodes/parameter_extractor/parameter_extractor_node.py

@@ -0,0 +1,854 @@
import contextlib
import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageRole,
PromptMessageTool,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base import variable_template_parser
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.llm import llm_utils
from dify_graph.runtime import VariablePool
from dify_graph.variables.types import ArrayValidation, SegmentType
from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData
from .exc import (
InvalidModelModeError,
InvalidModelTypeError,
InvalidNumberOfParametersError,
InvalidSelectValueError,
InvalidTextContentTypeError,
InvalidValueTypeError,
ModelSchemaNotFoundError,
ParameterExtractorNodeError,
RequiredParameterMissingError,
)
from .prompts import (
CHAT_EXAMPLE,
CHAT_GENERATE_JSON_PROMPT,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
COMPLETION_GENERATE_JSON_PROMPT,
FUNCTION_CALLING_EXTRACTOR_EXAMPLE,
FUNCTION_CALLING_EXTRACTOR_NAME,
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT,
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.runtime import GraphRuntimeState
def extract_json(text: str) -> str | None:
"""
Given text beginning with '{' or '[', extract the first complete JSON value.
"""
stack = []
for i, c in enumerate(text):
if c in {"{", "["}:
stack.append(c)
elif c in {"}", "]"}:
# check if stack is empty
if not stack:
return text[:i]
# check if the last element in stack is matching
if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["):
stack.pop()
if not stack:
return text[: i + 1]
else:
return text[:i]
return None
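# Illustrative behavior of extract_json, derived from the bracket matching above:
#   extract_json('{"a": [1, 2]} trailing text')  -> '{"a": [1, 2]}'
#   extract_json('[1, 2]] stray bracket')        -> '[1, 2]'
#   extract_json('{"a": 1')                      -> None (unbalanced input)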
class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"""
Parameter Extractor Node.
"""
node_type = NodeType.PARAMETER_EXTRACTOR
_model_instance: ModelInstance
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_memory: PromptMessageMemory | None
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
memory: PromptMessageMemory | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"model": {
"prompt_templates": {
"completion_model": {
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
"stop": ["Human:"],
}
}
}
}
@classmethod
def version(cls) -> str:
return "1"
def _run(self):
"""
Run the node.
"""
node_data = self.node_data
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""
variable_pool = self.graph_runtime_state.variable_pool
files = (
llm_utils.fetch_files(
variable_pool=variable_pool,
selector=node_data.vision.configs.variable_selector,
)
if node_data.vision.enabled
else []
)
model_instance = self._model_instance
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise InvalidModelTypeError("Model is not a Large Language Model")
try:
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
except ValueError as exc:
raise ModelSchemaNotFoundError("Model schema not found") from exc
memory = self._memory
if (
set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
and node_data.reasoning_mode == "function_call"
):
# use function call
prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
node_data=node_data,
query=query,
variable_pool=self.graph_runtime_state.variable_pool,
model_instance=model_instance,
memory=memory,
files=files,
vision_detail=node_data.vision.configs.detail,
)
else:
# use prompt engineering
prompt_messages = self._generate_prompt_engineering_prompt(
data=node_data,
query=query,
variable_pool=self.graph_runtime_state.variable_pool,
model_instance=model_instance,
memory=memory,
files=files,
vision_detail=node_data.vision.configs.detail,
)
prompt_message_tools = []
inputs = {
"query": query,
"files": [f.to_dict() for f in files],
"parameters": jsonable_encoder(node_data.parameters),
"instruction": jsonable_encoder(node_data.instruction),
}
process_data = {
"model_mode": node_data.model.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=node_data.model.mode, prompt_messages=prompt_messages
),
"usage": None,
"function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
"tool_call": None,
"model_provider": model_instance.provider,
"model_name": model_instance.model_name,
}
try:
text, usage, tool_call = self._invoke(
model_instance=model_instance,
prompt_messages=prompt_messages,
tools=prompt_message_tools,
stop=model_instance.stop,
)
process_data["usage"] = jsonable_encoder(usage)
process_data["tool_call"] = jsonable_encoder(tool_call)
process_data["llm_text"] = text
except ParameterExtractorNodeError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
process_data=process_data,
outputs={"__is_success": 0, "__reason": str(e)},
error=str(e),
metadata={},
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
process_data=process_data,
outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)},
error=str(e),
metadata={},
)
error = None
if tool_call:
result = self._extract_json_from_tool_call(tool_call)
else:
result = self._extract_complete_json_response(text)
if not result:
result = self._generate_default_result(node_data)
error = "Failed to extract result from function call or text response, using empty result."
try:
result = self._validate_result(data=node_data, result=result or {})
except ParameterExtractorNodeError as e:
error = str(e)
# transform result into standard format
result = self._transform_result(data=node_data, result=result or {})
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs={
"__is_success": 1 if not error else 0,
"__reason": error,
"__usage": jsonable_encoder(usage),
**result,
},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
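# Output contract of _run (observable above): on success, the outputs mapping holds
# "__is_success": 1, "__reason": None, "__usage": {...}, plus one key per extracted
# parameter; on failure, "__is_success" is 0 and "__reason" carries the error text.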
def _invoke(
self,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
stop: Sequence[str],
) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=dict(model_instance.parameters),
tools=tools,
stop=list(stop),
stream=False,
user=self.user_id,
)
# handle invoke result
text = invoke_result.message.get_text_content()
if not isinstance(text, str):
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
usage = invoke_result.usage
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
return text, usage, tool_call
def _generate_function_call_prompt(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
"""
Generate function call prompt.
"""
query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(
content=query, structure=json.dumps(node_data.get_parameter_json_schema())
)
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query,
variable_pool=variable_pool,
model_instance=model_instance,
context="",
)
prompt_template = self._get_function_calling_prompt_template(
node_data, query, variable_pool, memory, rest_token
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query="",
files=files,
context="",
memory_config=node_data.memory,
memory=None,
model_instance=model_instance,
image_detail_config=vision_detail,
)
# find last user message
last_user_message_idx = -1
for i, prompt_message in enumerate(prompt_messages):
if prompt_message.role == PromptMessageRole.USER:
last_user_message_idx = i
# add function call messages before last user message
example_messages = []
for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE:
id = uuid.uuid4().hex
example_messages.extend(
[
UserPromptMessage(content=example["user"]["query"]),
AssistantPromptMessage(
content=example["assistant"]["text"],
tool_calls=[
AssistantPromptMessage.ToolCall(
id=id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=example["assistant"]["function_call"]["name"],
arguments=json.dumps(example["assistant"]["function_call"]["parameters"]),
),
)
],
),
ToolPromptMessage(
content="Great! You have called the function with the correct parameters.", tool_call_id=id
),
AssistantPromptMessage(
content="I have extracted the parameters, let's move on.",
),
]
)
prompt_messages = (
prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:]
)
# generate tool
tool = PromptMessageTool(
name=FUNCTION_CALLING_EXTRACTOR_NAME,
description="Extract parameters from the natural language text",
parameters=node_data.get_parameter_json_schema(),
)
return prompt_messages, [tool]
def _generate_prompt_engineering_prompt(
self,
data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
Generate prompt engineering prompt.
"""
model_mode = ModelMode(data.model.mode)
if model_mode == ModelMode.COMPLETION:
return self._generate_prompt_engineering_completion_prompt(
node_data=data,
query=query,
variable_pool=variable_pool,
model_instance=model_instance,
memory=memory,
files=files,
vision_detail=vision_detail,
)
elif model_mode == ModelMode.CHAT:
return self._generate_prompt_engineering_chat_prompt(
node_data=data,
query=query,
variable_pool=variable_pool,
model_instance=model_instance,
memory=memory,
files=files,
vision_detail=vision_detail,
)
else:
raise InvalidModelModeError(f"Invalid model mode: {model_mode}")
def _generate_prompt_engineering_completion_prompt(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
Generate completion prompt.
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query,
variable_pool=variable_pool,
model_instance=model_instance,
context="",
)
prompt_template = self._get_prompt_engineering_prompt_template(
node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={"structure": json.dumps(node_data.get_parameter_json_schema())},
query="",
files=files,
context="",
memory_config=node_data.memory,
# AdvancedPromptTransform is still typed against TokenBufferMemory.
memory=cast(Any, memory),
model_instance=model_instance,
image_detail_config=vision_detail,
)
return prompt_messages
def _generate_prompt_engineering_chat_prompt(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
Generate chat prompt.
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query,
variable_pool=variable_pool,
model_instance=model_instance,
context="",
)
prompt_template = self._get_prompt_engineering_prompt_template(
node_data=node_data,
query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
structure=json.dumps(node_data.get_parameter_json_schema()), text=query
),
variable_pool=variable_pool,
memory=memory,
max_token_limit=rest_token,
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query="",
files=files,
context="",
memory_config=node_data.memory,
memory=None,
model_instance=model_instance,
image_detail_config=vision_detail,
)
# find last user message
last_user_message_idx = -1
for i, prompt_message in enumerate(prompt_messages):
if prompt_message.role == PromptMessageRole.USER:
last_user_message_idx = i
# add example messages before last user message
example_messages = []
for example in CHAT_EXAMPLE:
example_messages.extend(
[
UserPromptMessage(
content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
structure=json.dumps(example["user"]["json"]),
text=example["user"]["query"],
)
),
AssistantPromptMessage(
content=json.dumps(example["assistant"]["json"]),
),
]
)
prompt_messages = (
prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:]
)
return prompt_messages
def _validate_result(self, data: ParameterExtractorNodeData, result: dict):
if len(data.parameters) != len(result):
raise InvalidNumberOfParametersError("Invalid number of parameters")
for parameter in data.parameters:
if parameter.required and parameter.name not in result:
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
param_value = result.get(parameter.name)
if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL):
inferred_type = SegmentType.infer_segment_type(param_value)
raise InvalidValueTypeError(
parameter_name=parameter.name,
expected_type=parameter.type,
actual_type=inferred_type,
value=param_value,
)
if parameter.type == SegmentType.STRING and parameter.options:
if param_value not in parameter.options:
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
return result
@staticmethod
def _transform_number(value: int | float | str | bool) -> int | float | None:
"""
Attempts to transform the input into an integer or float.
Returns:
int or float: The transformed number if the conversion is successful.
None: If the transformation fails.
Note:
Boolean values `True` and `False` are converted to integers `1` and `0`, respectively.
This behavior ensures compatibility with existing workflows that may use boolean types as integers.
"""
if isinstance(value, bool):
return int(value)
elif isinstance(value, (int, float)):
return value
elif isinstance(value, str):
if "." in value:
try:
return float(value)
except ValueError:
return None
else:
try:
return int(value)
except ValueError:
return None
else:
return None
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
"""
Transform result into standard format.
"""
transformed_result: dict[str, Any] = {}
for parameter in data.parameters:
if parameter.name in result:
param_value = result[parameter.name]
# transform value
if parameter.type == SegmentType.NUMBER:
transformed = self._transform_number(param_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
elif parameter.type == SegmentType.BOOLEAN:
if isinstance(result[parameter.name], (bool, int)):
transformed_result[parameter.name] = bool(result[parameter.name])
# elif isinstance(result[parameter.name], str):
# if result[parameter.name].lower() in ["true", "false"]:
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true")
elif parameter.type == SegmentType.STRING:
if isinstance(param_value, str):
transformed_result[parameter.name] = param_value
elif parameter.is_array_type():
if isinstance(param_value, list):
nested_type = parameter.element_type()
assert nested_type is not None
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
transformed_result[parameter.name] = segment_value
for item in param_value:
if nested_type == SegmentType.NUMBER:
transformed = self._transform_number(item)
if transformed is not None:
segment_value.value.append(transformed)
elif nested_type == SegmentType.STRING:
if isinstance(item, str):
segment_value.value.append(item)
elif nested_type == SegmentType.OBJECT:
if isinstance(item, dict):
segment_value.value.append(item)
elif nested_type == SegmentType.BOOLEAN:
if isinstance(item, bool):
segment_value.value.append(item)
if parameter.name not in transformed_result:
if parameter.type.is_array_type():
transformed_result[parameter.name] = build_segment_with_type(
segment_type=SegmentType(parameter.type), value=[]
)
elif parameter.type in (SegmentType.STRING, SegmentType.SECRET):
transformed_result[parameter.name] = ""
elif parameter.type == SegmentType.NUMBER:
transformed_result[parameter.name] = 0
elif parameter.type == SegmentType.BOOLEAN:
transformed_result[parameter.name] = False
else:
raise AssertionError("this statement should be unreachable.")
return transformed_result
def _extract_complete_json_response(self, result: str) -> dict | None:
"""
Extract complete json response.
"""
# extract json from the text
for idx in range(len(result)):
if result[idx] == "{" or result[idx] == "[":
json_str = extract_json(result[idx:])
if json_str:
with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str))
logger.info("extra error: %s", result)
return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
"""
Extract json from tool call.
"""
if not tool_call or not tool_call.function.arguments:
return None
result = tool_call.function.arguments
# extract json from the arguments
for idx in range(len(result)):
if result[idx] == "{" or result[idx] == "[":
json_str = extract_json(result[idx:])
if json_str:
with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str))
logger.info("extra error: %s", result)
return None
def _generate_default_result(self, data: ParameterExtractorNodeData):
"""
Generate default result.
"""
result: dict[str, Any] = {}
for parameter in data.parameters:
if parameter.type == "number":
result[parameter.name] = 0
elif parameter.type == "boolean":
result[parameter.name] = False
elif parameter.type in {"string", "select"}:
result[parameter.name] = ""
return result
def _get_function_calling_prompt_template(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode(node_data.model.mode)
input_text = query
memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text
if memory and node_data.memory and node_data.memory.window:
memory_str = llm_utils.fetch_memory_text(
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
)
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]
else:
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
def _get_prompt_engineering_prompt_template(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
input_text = query
memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text
if memory and node_data.memory and node_data.memory.window:
memory_str = llm_utils.fetch_memory_text(
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
)
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]
elif model_mode == ModelMode.COMPLETION:
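# The example structure in the template is wrapped in "{γγγ"/"}γγγ" markers;
# strip them (braces included) after formatting.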
return CompletionModelPromptTemplate(
text=COMPLETION_GENERATE_JSON_PROMPT.format(
histories=memory_str, text=input_text, instruction=instruction
)
.replace("{γγγ", "")
.replace("}γγγ", "")
)
else:
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
def _calculate_rest_token(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
context: str | None,
) -> int:
try:
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
except ValueError as exc:
raise ModelSchemaNotFoundError("Model schema not found") from exc
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
else:
prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query="",
files=[],
context=context,
memory_config=node_data.memory,
memory=None,
model_instance=model_instance,
)
rest_tokens = 2000
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
curr_message_tokens = (
model_type_instance.get_num_tokens(
model_instance.model_name, model_instance.credentials, prompt_messages
)
+ 1000
) # reserve 1000 tokens of headroom for tool-call messages
max_tokens = 0
for parameter_rule in model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_instance.parameters.get(parameter_rule.name)
or model_instance.parameters.get(parameter_rule.use_template or "")
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = ParameterExtractorNodeData.model_validate(node_data)
variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
if typed_node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
for selector in selectors:
variable_mapping[selector.variable] = selector.value_selector
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
return variable_mapping
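
For reference, a sketch of the coercions _transform_number applies to loosely typed model output (illustrative calls; the method is a staticmethod, so no node instance is needed):

assert ParameterExtractorNode._transform_number(True) == 1      # bools become ints for legacy workflows
assert ParameterExtractorNode._transform_number("3.5") == 3.5   # strings containing "." parse as float
assert ParameterExtractorNode._transform_number("42") == 42     # other numeric strings parse as int
assert ParameterExtractorNode._transform_number("abc") is None  # unparsable values are dropped by the caller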

View File: dify_graph/nodes/parameter_extractor/prompts.py

@@ -0,0 +1,184 @@
from typing import Any
FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters"
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy.
### Task
Always call the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function with the correct parameters. Ensure that the information extraction is contextual and aligns with the provided criteria.
### Memory
Here is the chat history between the human and assistant, provided within <histories> tags:
<histories>
\x7bhistories\x7d
</histories>
### Instructions:
Some additional information is provided below. Always adhere to these instructions as closely as possible:
<instruction>
\x7binstruction\x7d
</instruction>
Steps:
1. Review the chat history provided within the <histories> tags.
2. Extract the relevant information based on the given criteria; output multiple values if multiple pieces of information in the text match the criteria.
3. Generate a well-formatted output using the defined functions and arguments.
4. Use the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function to create structured outputs with the appropriate parameters.
5. Do not include any XML tags in your output.
### Example
To illustrate, if the task involves extracting a user's name and their request, your function call should pass both as parameters. Ensure your output follows a structure similar to the examples.
### Final Output
Produce well-formatted function calls in json without XML tags, as shown in the example.
""" # noqa: E501
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from context inside <context></context> XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside <structure></structure> XML tags.
<context>
\x7bcontent\x7d
</context>
<structure>
\x7bstructure\x7d
</structure>
""" # noqa: E501
FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [
{
"user": {
"query": "What is the weather today in SF?",
"function": {
"name": FUNCTION_CALLING_EXTRACTOR_NAME,
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The location to get the weather information",
"required": True,
},
},
"required": ["location"],
},
},
},
"assistant": {
"text": "I need always call the function with the correct parameters."
" in this case, I need to call the function with the location parameter.",
"function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"location": "San Francisco"}},
},
},
{
"user": {
"query": "I want to eat some apple pie.",
"function": {
"name": FUNCTION_CALLING_EXTRACTOR_NAME,
"parameters": {
"type": "object",
"properties": {"food": {"type": "string", "description": "The food to eat", "required": True}},
"required": ["food"],
},
},
},
"assistant": {
"text": "I need always call the function with the correct parameters."
" in this case, I need to call the function with the food parameter.",
"function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"food": "apple pie"}},
},
},
]
COMPLETION_GENERATE_JSON_PROMPT = """### Instructions:
Some extra information is provided below; I should always follow the instructions as closely as I can.
<instructions>
{instruction}
</instructions>
### Extract parameter Workflow
I need to extract the following information from the input text. The <information to be extracted> tag specifies the 'type', 'description', and 'required' attributes of each piece of information to be extracted.
<information to be extracted>
{{ structure }}
</information to be extracted>
Step 1: Carefully read the input and understand the structure of the expected output.
Step 2: Extract the relevant parameters from the provided text based on each property's name and description.
Step 3: Structure the extracted parameters into a JSON object as specified in <information to be extracted>.
Step 4: Ensure that the JSON object is properly formatted and valid, and contains no XML tags. Output only the JSON object.
### Memory
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
### Structure
Here is the structure of the expected output; I should always follow it.
{{γγγ
'properties1': 'relevant text extracted from input',
'properties2': 'relevant text extracted from input',
}}γγγ
### Input Text
Inside <text></text> XML tags, there is a text from which I should extract parameters and convert them to a JSON object.
<text>
{text}
</text>
### Answer
I should always output a valid JSON object. Output nothing other than the JSON object.
```JSON
""" # noqa: E501
CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object.
You can find the structure of the JSON object in the instructions.
### Memory
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
### Instructions:
Some extra information is provided below; you should always follow the instructions as closely as you can.
<instructions>
{instructions}
</instructions>
"""
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE = """### Structure
Here is the structure of the JSON object; you should always follow it.
<structure>
{structure}
</structure>
### Text to be converted to JSON
Inside <text></text> XML tags, there is a text that you should convert to a JSON object.
<text>
{text}
</text>
"""
CHAT_EXAMPLE = [
{
"user": {
"query": "What is the weather today in SF?",
"json": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The location to get the weather information",
"required": True,
}
},
"required": ["location"],
},
},
"assistant": {"text": "I need to output a valid JSON object.", "json": {"location": "San Francisco"}},
},
{
"user": {
"query": "I want to eat some apple pie.",
"json": {
"type": "object",
"properties": {"food": {"type": "string", "description": "The food to eat", "required": True}},
"required": ["food"],
},
},
"assistant": {"text": "I need to output a valid JSON object.", "json": {"food": "apple pie"}},
},
]
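
Taken together, a minimal sketch of how the chat templates are filled at runtime (mirrors _generate_prompt_engineering_chat_prompt above; values are illustrative):

import json

system_text = CHAT_GENERATE_JSON_PROMPT.format(histories="", instructions="Extract the city only.")
user_text = CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
    structure=json.dumps(CHAT_EXAMPLE[0]["user"]["json"]),
    text="What is the weather today in SF?",
)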