feat: add backwards invoke node api

This commit is contained in:
Yeuoly
2024-09-24 18:03:48 +08:00
parent 592f85f7a9
commit 68c10a1672
5 changed files with 335 additions and 42 deletions

View File

@ -0,0 +1,114 @@
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.workflow.nodes.parameter_extractor.entities import (
ModelConfig as ParameterExtractorModelConfig,
)
from core.workflow.nodes.parameter_extractor.entities import (
ParameterConfig,
ParameterExtractorNodeData,
)
from core.workflow.nodes.question_classifier.entities import (
ClassConfig,
QuestionClassifierNodeData,
)
from core.workflow.nodes.question_classifier.entities import (
ModelConfig as QuestionClassifierModelConfig,
)
from services.workflow_service import WorkflowService
class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
@classmethod
def invoke_parameter_extractor(
cls,
tenant_id: str,
user_id: str,
parameters: list[ParameterConfig],
model_config: ParameterExtractorModelConfig,
instruction: str,
query: str,
) -> dict:
"""
Invoke parameter extractor node.
:param tenant_id: str
:param user_id: str
:param parameters: list[ParameterConfig]
:param model_config: ModelConfig
:param instruction: str
:param query: str
:return: dict with __reason, __is_success, and other parameters
"""
workflow_service = WorkflowService()
node_id = "1919810"
node_data = ParameterExtractorNodeData(
title="parameter_extractor",
desc="parameter_extractor",
parameters=parameters,
reasoning_mode="function_call",
query=[node_id, "query"],
model=model_config,
instruction=instruction, # instruct with variables are not supported
)
node_data_dict = node_data.model_dump()
execution = workflow_service.run_free_workflow_node(
node_data_dict,
tenant_id=tenant_id,
user_id=user_id,
node_id=node_id,
user_inputs={
f"{node_id}.query": query,
},
)
output = execution.outputs_dict
return output or {
"__reason": "No parameters extracted",
"__is_success": False,
}
@classmethod
def invoke_question_classifier(
cls,
tenant_id: str,
user_id: str,
model_config: QuestionClassifierModelConfig,
classes: list[ClassConfig],
instruction: str,
query: str,
) -> dict:
"""
Invoke question classifier node.
:param tenant_id: str
:param user_id: str
:param model_config: ModelConfig
:param classes: list[ClassConfig]
:param instruction: str
:param query: str
:return: dict with class_name
"""
workflow_service = WorkflowService()
node_id = "1919810"
node_data = QuestionClassifierNodeData(
title="question_classifier",
desc="question_classifier",
query_variable_selector=[node_id, "query"],
model=model_config,
classes=classes,
instruction=instruction, # instruct with variables are not supported
)
node_data_dict = node_data.model_dump()
execution = workflow_service.run_free_workflow_node(
node_data_dict,
tenant_id=tenant_id,
user_id=user_id,
node_id=node_id,
user_inputs={
f"{node_id}.query": query,
},
)
output = execution.outputs_dict
return output or {
"class_name": classes[0].name,
}

View File

@ -14,6 +14,16 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelType
from core.workflow.nodes.question_classifier.entities import (
ClassConfig,
ModelConfig as QuestionClassifierModelConfig,
)
from core.workflow.nodes.parameter_extractor.entities import (
ModelConfig as ParameterExtractorModelConfig,
)
from core.workflow.nodes.parameter_extractor.entities import (
ParameterConfig,
)
class RequestInvokeTool(BaseModel):
@ -92,11 +102,27 @@ class RequestInvokeModeration(BaseModel):
"""
class RequestInvokeNode(BaseModel):
class RequestInvokeParameterExtractorNode(BaseModel):
"""
Request to invoke node
Request to invoke parameter extractor node
"""
parameters: list[ParameterConfig]
model: ParameterExtractorModelConfig
instruction: str
query: str
class RequestInvokeQuestionClassifierNode(BaseModel):
"""
Request to invoke question classifier node
"""
query: str
model: QuestionClassifierModelConfig
classes: list[ClassConfig]
instruction: str
class RequestInvokeApp(BaseModel):
"""

View File

@ -205,6 +205,88 @@ class WorkflowEntry:
except Exception as e:
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
@classmethod
def run_free_node(
cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]:
"""
Run free node
NOTE: only parameter_extractor/question_classifier are supported
:param node_data: node data
:param user_id: user id
:param user_inputs: user inputs
:return:
"""
# generate a fake graph
node_config = {"id": node_id, "width": 114, "height": 514, "type": "custom", "data": node_data}
graph_dict = {
"nodes": [node_config],
}
node_type = NodeType.value_of(node_data.get("type", ""))
if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
raise ValueError(f"Node type {node_type} not supported")
node_cls = node_classes.get(node_type)
if not node_cls:
raise ValueError(f"Node class not found for node type {node_type}")
graph = Graph.init(graph_config=graph_dict)
# init variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
environment_variables=[],
)
node_cls = cast(type[BaseNode], node_cls)
# init workflow run state
node_instance: BaseNode = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
tenant_id=tenant_id,
app_id="",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="",
graph_config=graph_dict,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
)
try:
# variable selector to variable mapping
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_dict, config=node_config
)
except NotImplementedError:
variable_mapping = {}
cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=tenant_id,
node_type=node_type,
node_data=node_instance.node_data,
)
# run node
generator = node_instance.run()
return node_instance, generator
except Exception as e:
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
@classmethod
def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]:
"""