mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
feat: add assemble variable builder api
This commit is contained in:
@ -1,8 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol, cast
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
import json_repair
|
||||
|
||||
@ -398,6 +398,488 @@ class LLMGenerator:
|
||||
logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
|
||||
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@classmethod
def generate_with_context(
    cls,
    tenant_id: str,
    workflow_id: str,
    node_id: str,
    parameter_name: str,
    language: str,
    prompt_messages: list[PromptMessage],
    model_config: dict,
) -> dict:
    """
    Generate extractor code node based on conversation context.

    Args:
        tenant_id: Tenant/workspace ID
        workflow_id: ID used to look up the owning app (matched against App.id)
        node_id: Current tool/llm node ID
        parameter_name: Parameter name to generate code for
        language: Code language (python3/javascript)
        prompt_messages: Multi-turn conversation history (last message is instruction)
        model_config: Model configuration (provider, name, completion_params)

    Returns:
        dict with CodeNodeData format:
        - variables: Input variable selectors
        - code_language: Code language
        - code: Generated code
        - outputs: Output definitions
        - message: Explanation
        - error: Error message if any

        Lookup failures (app, draft workflow, node, model schema) and LLM
        invocation failures are reported through the ``error`` field rather
        than raised. Errors from resolving the model instance propagate.
    """
    # Imported locally, as in the original — presumably to avoid import cycles; confirm before hoisting.
    from sqlalchemy import select
    from sqlalchemy.orm import Session

    from services.workflow_service import WorkflowService

    # Get workflow. The lookup is scoped to the caller's tenant so that an ID
    # belonging to another workspace is reported as "not found" instead of
    # exposing that workspace's graph to the prompt.
    # NOTE(review): relies on App exposing a tenant_id column — confirm against the model.
    with Session(db.engine) as session:
        stmt = select(App).where(App.id == workflow_id, App.tenant_id == tenant_id)
        app = session.scalar(stmt)
        if not app:
            return cls._error_response(f"App {workflow_id} not found")

        workflow = WorkflowService().get_draft_workflow(app_model=app)
        if not workflow:
            return cls._error_response(f"Workflow for app {workflow_id} not found")

    # Get upstream nodes via edge backtracking
    upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)

    # Get current node info
    current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
    if not current_node:
        return cls._error_response(f"Node {node_id} not found")

    # Get parameter info
    parameter_info = cls._get_parameter_info(
        tenant_id=tenant_id,
        node_data=current_node.get("data", {}),
        parameter_name=parameter_name,
    )

    # Build system prompt
    system_prompt = cls._build_extractor_system_prompt(
        upstream_nodes=upstream_nodes,
        current_node=current_node,
        parameter_info=parameter_info,
        language=language,
    )

    # Construct complete prompt_messages with system prompt
    complete_messages: list[PromptMessage] = [
        SystemPromptMessage(content=system_prompt),
        *prompt_messages,
    ]

    from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output

    # Get model instance and schema
    provider = model_config.get("provider", "")
    model_name = model_config.get("name", "")
    model_instance = ModelManager().get_model_instance(
        tenant_id=tenant_id,
        model_type=ModelType.LLM,
        provider=provider,
        model=model_name,
    )

    model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
    if not model_schema:
        return cls._error_response(f"Model schema not found for {model_name}")

    # `or {}` also covers an explicit null in the request payload.
    model_parameters = model_config.get("completion_params") or {}
    json_schema = cls._get_code_node_json_schema()

    try:
        response = invoke_llm_with_structured_output(
            provider=provider,
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=complete_messages,
            json_schema=json_schema,
            model_parameters=model_parameters,
            stream=False,
            tenant_id=tenant_id,
        )

        return cls._parse_code_node_output(
            response.structured_output, language, parameter_info.get("type", "string")
        )

    except InvokeError as e:
        return cls._error_response(str(e))
    except Exception as e:
        logger.exception("Failed to generate with context, model: %s", model_config.get("name"))
        return cls._error_response(f"An unexpected error occurred: {str(e)}")
|
||||
|
||||
@classmethod
def _error_response(cls, error: str) -> dict:
    """Build an empty CodeNodeData-shaped payload that carries only ``error``.

    All other fields are blank placeholders; ``code_language`` is always
    ``"python3"``.
    """
    payload: dict = dict(
        variables=[],
        code_language="python3",
        code="",
        outputs={},
        message="",
    )
    payload["error"] = error
    return payload
|
||||
|
||||
@classmethod
def generate_suggested_questions(
    cls,
    tenant_id: str,
    workflow_id: str,
    node_id: str,
    parameter_name: str,
    language: str,
    model_config: dict | None = None,
) -> dict:
    """
    Generate suggested questions for context generation.

    Args:
        tenant_id: Tenant/workspace ID
        workflow_id: ID used to look up the owning app (matched against App.id)
        node_id: Node whose parameter the questions are about
        parameter_name: Target parameter name
        language: Passed to the prompt builder, which asks for questions "in {language}"
        model_config: Optional model configuration (provider, name,
            completion_params); the tenant's default LLM is used when omitted

    Returns dict with questions array and error field. Lookup failures and
    LLM invocation failures are reported through ``error`` with an empty
    ``questions`` list. Errors from resolving the model instance propagate.
    """
    # Imported locally, as in the original — presumably to avoid import cycles; confirm before hoisting.
    from sqlalchemy import select
    from sqlalchemy.orm import Session

    from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
    from services.workflow_service import WorkflowService

    # Get workflow context. The lookup is scoped to the caller's tenant so an
    # ID from another workspace is reported as "not found".
    # NOTE(review): relies on App exposing a tenant_id column — confirm against the model.
    with Session(db.engine) as session:
        stmt = select(App).where(App.id == workflow_id, App.tenant_id == tenant_id)
        app = session.scalar(stmt)
        if not app:
            return {"questions": [], "error": f"App {workflow_id} not found"}

        workflow = WorkflowService().get_draft_workflow(app_model=app)
        if not workflow:
            return {"questions": [], "error": f"Workflow for app {workflow_id} not found"}

    upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)
    current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
    if not current_node:
        return {"questions": [], "error": f"Node {node_id} not found"}

    parameter_info = cls._get_parameter_info(
        tenant_id=tenant_id,
        node_data=current_node.get("data", {}),
        parameter_name=parameter_name,
    )

    # Build prompt
    system_prompt = cls._build_suggested_questions_prompt(
        upstream_nodes=upstream_nodes,
        current_node=current_node,
        parameter_info=parameter_info,
        language=language,
    )

    prompt_messages: list[PromptMessage] = [
        SystemPromptMessage(content=system_prompt),
    ]

    # Get model instance - use default if model_config not provided
    model_manager = ModelManager()
    if model_config:
        model_name = model_config.get("name", "")
        model_instance = model_manager.get_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
            provider=model_config.get("provider", ""),
            model=model_name,
        )
    else:
        model_instance = model_manager.get_default_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
        )
        model_name = model_instance.model

    model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
    if not model_schema:
        return {"questions": [], "error": f"Model schema not found for {model_name}"}

    # `or {}` covers both a missing model_config and an explicit null
    # completion_params, either of which would otherwise break the unpacking.
    completion_params = (model_config or {}).get("completion_params") or {}
    # max_tokens is always forced to 256, overriding any caller-supplied value.
    model_parameters = {**completion_params, "max_tokens": 256}
    json_schema = cls._get_suggested_questions_json_schema()

    try:
        response = invoke_llm_with_structured_output(
            provider=model_instance.provider,
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            json_schema=json_schema,
            model_parameters=model_parameters,
            stream=False,
            tenant_id=tenant_id,
        )

        questions = response.structured_output.get("questions", []) if response.structured_output else []
        return {"questions": questions, "error": ""}

    except InvokeError as e:
        return {"questions": [], "error": str(e)}
    except Exception as e:
        logger.exception("Failed to generate suggested questions, model: %s", model_name)
        return {"questions": [], "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@classmethod
def _build_suggested_questions_prompt(
    cls,
    upstream_nodes: list[dict],
    current_node: dict,
    parameter_info: dict,
    language: str = "English",
) -> str:
    """Compose the short system prompt asking the model for three questions.

    Only the first five upstream nodes are listed, each as
    ``title(output1,output2)``, and the parameter description is cut to 100
    characters to keep the prompt small. ``current_node`` is accepted but
    not used.
    """
    source_labels: list[str] = []
    for upstream in upstream_nodes[:5]:
        output_names = ",".join(upstream.get("outputs", {}))
        source_labels.append(f"{upstream['title']}({output_names})")

    description = parameter_info.get("description", "")
    param_desc = description[:100]
    param_type = parameter_info.get("type", "string")

    return f"""Suggest 3 code generation questions for extracting data.
Sources: {", ".join(source_labels)}
Target: {parameter_info.get("name")}({param_type}) - {param_desc}
Output 3 short, practical questions in {language}."""
|
||||
|
||||
@classmethod
def _get_suggested_questions_json_schema(cls) -> dict:
    """Return the JSON Schema for the suggested-questions structured output.

    The schema is an object with a single required ``questions`` property:
    an array of exactly three strings.
    """
    questions_schema = {
        "type": "array",
        "items": {"type": "string"},
        "minItems": 3,
        "maxItems": 3,
        "description": "3 suggested questions",
    }
    schema: dict = {"type": "object"}
    schema["properties"] = {"questions": questions_schema}
    schema["required"] = ["questions"]
    return schema
|
||||
|
||||
@classmethod
def _get_code_node_json_schema(cls) -> dict:
    """Return the JSON Schema the model must follow when emitting a code node.

    Required top-level properties: ``variables`` (list of
    ``{variable, value_selector}`` objects), ``code``, ``outputs`` (object
    keyed by output name) and ``explanation``.
    """
    value_selector_schema = {
        "type": "array",
        "items": {"type": "string"},
        "description": "Path like [node_id, output_name]",
    }
    variable_item_schema = {
        "type": "object",
        "properties": {
            "variable": {"type": "string", "description": "Variable name in code"},
            "value_selector": value_selector_schema,
        },
        "required": ["variable", "value_selector"],
    }
    outputs_schema = {
        "type": "object",
        "additionalProperties": {
            "type": "object",
            "properties": {"type": {"type": "string"}},
        },
        "description": "Output definitions, key is output name",
    }
    properties = {
        "variables": {"type": "array", "items": variable_item_schema},
        "code": {"type": "string", "description": "Generated code with main function"},
        "outputs": outputs_schema,
        "explanation": {"type": "string", "description": "Brief explanation of the code"},
    }
    return {
        "type": "object",
        "properties": properties,
        "required": ["variables", "code", "outputs", "explanation"],
    }
|
||||
|
||||
@classmethod
def _get_upstream_nodes(cls, graph_dict: Mapping[str, Any], node_id: str) -> list[dict]:
    """
    Get all upstream nodes via edge backtracking.

    Traverses the graph backwards (breadth-first) from node_id and collects
    every node reachable through incoming edges, nearest first.

    Args:
        graph_dict: Graph with a "nodes" list (each with an "id") and an
            "edges" list (each with "source" and "target").
        node_id: Node to start backtracking from. It is never included in
            the result, even when the graph contains a cycle through it.

    Returns:
        One summary dict per upstream node, as produced by
        ``_extract_node_info``. Edge sources without a matching node entry
        are traversed but not reported.
    """
    from collections import defaultdict, deque

    nodes = {n["id"]: n for n in graph_dict.get("nodes", [])}

    # Build reverse adjacency list (target -> sources)
    reverse_adj: dict[str, list[str]] = defaultdict(list)
    for edge in graph_dict.get("edges", []):
        source, target = edge.get("source"), edge.get("target")
        # Skip malformed edges instead of failing the whole request.
        if source is None or target is None:
            continue
        reverse_adj[target].append(source)

    # Seed with the start node so a cycle cannot report it as its own upstream.
    visited: set[str] = {node_id}
    queue: deque[str] = deque([node_id])
    upstream: list[dict] = []

    while queue:
        current = queue.popleft()
        for source in reverse_adj.get(current, ()):
            if source in visited:
                continue
            visited.add(source)
            queue.append(source)
            if source in nodes:
                upstream.append(cls._extract_node_info(nodes[source]))

    return upstream
|
||||
|
||||
@classmethod
|
||||
def _get_node_by_id(cls, graph_dict: Mapping[str, Any], node_id: str) -> dict | None:
|
||||
"""Get node by ID from graph."""
|
||||
for node in graph_dict.get("nodes", []):
|
||||
if node["id"] == node_id:
|
||||
return node
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_node_info(cls, node: dict) -> dict:
|
||||
"""Extract minimal node info with outputs based on node type."""
|
||||
node_type = node["data"]["type"]
|
||||
node_data = node.get("data", {})
|
||||
|
||||
# Build outputs based on node type (only type, no description to reduce tokens)
|
||||
outputs: dict[str, str] = {}
|
||||
match node_type:
|
||||
case "start":
|
||||
for var in node_data.get("variables", []):
|
||||
name = var.get("variable", var.get("name", ""))
|
||||
outputs[name] = var.get("type", "string")
|
||||
case "llm":
|
||||
outputs["text"] = "string"
|
||||
case "code":
|
||||
for name, output in node_data.get("outputs", {}).items():
|
||||
outputs[name] = output.get("type", "string")
|
||||
case "http-request":
|
||||
outputs = {"body": "string", "status_code": "number", "headers": "object"}
|
||||
case "knowledge-retrieval":
|
||||
outputs["result"] = "array[object]"
|
||||
case "tool":
|
||||
outputs = {"text": "string", "json": "object"}
|
||||
case _:
|
||||
outputs["output"] = "string"
|
||||
|
||||
info: dict = {
|
||||
"id": node["id"],
|
||||
"title": node_data.get("title", node["id"]),
|
||||
"outputs": outputs,
|
||||
}
|
||||
# Only include description if not empty
|
||||
desc = node_data.get("desc", "")
|
||||
if desc:
|
||||
info["desc"] = desc
|
||||
|
||||
return info
|
||||
|
||||
@classmethod
def _get_parameter_info(cls, tenant_id: str, node_data: dict, parameter_name: str) -> dict:
    """Describe ``parameter_name`` using the tool schema resolved via ToolManager.

    For non-tool nodes, for parameters the tool does not declare, and when
    the lookup raises, a generic string-typed description is returned
    instead. The lookup is best-effort: failures are only logged at debug
    level.

    Returns:
        Dict with "name", "type" and "description"; a parameter found in the
        tool schema additionally carries "required".
    """
    fallback = {"name": parameter_name, "type": "string", "description": ""}

    is_tool_node = node_data.get("type") == "tool"
    if not is_tool_node:
        return fallback

    try:
        from core.app.entities.app_invoke_entities import InvokeFrom
        from core.tools.entities.tool_entities import ToolProviderType
        from core.tools.tool_manager import ToolManager

        raw_provider_type = node_data.get("provider_type", "")
        if raw_provider_type:
            provider_type = ToolProviderType(raw_provider_type)
        else:
            provider_type = ToolProviderType.BUILT_IN

        runtime = ToolManager.get_tool_runtime(
            provider_type=provider_type,
            provider_id=node_data.get("provider_id", ""),
            tool_name=node_data.get("tool_name", ""),
            tenant_id=tenant_id,
            invoke_from=InvokeFrom.DEBUGGER,
        )

        for candidate in runtime.get_merged_runtime_parameters():
            if candidate.name != parameter_name:
                continue

            found: dict = {"name": candidate.name}
            # Enum-like types expose .value; anything else is stringified.
            if hasattr(candidate.type, "value"):
                found["type"] = candidate.type.value
            else:
                found["type"] = str(candidate.type)
            # Prefer the LLM-facing description, then the English human one.
            found["description"] = candidate.llm_description or (
                candidate.human_description.en_US if candidate.human_description else ""
            )
            found["required"] = candidate.required
            return found
    except Exception as e:
        logger.debug("Failed to get parameter info from ToolManager: %s", e)

    return fallback
|
||||
|
||||
@classmethod
def _build_extractor_system_prompt(
    cls,
    upstream_nodes: list[dict],
    current_node: dict,
    parameter_info: dict,
    language: str,
) -> str:
    """Build system prompt for extractor code generation.

    Args:
        upstream_nodes: Upstream node summaries; embedded verbatim as
            indented JSON (non-ASCII characters are kept as-is).
        current_node: Target node dict; its data title is shown, falling
            back to the node id when the title or the whole "data" entry is
            absent.
        parameter_info: Dict with optional "name", "type" (defaults to
            "string") and "description".
        language: Code language named in the instruction line.

    Returns:
        The prompt text, ending with a trailing newline.
    """
    upstream_json = json.dumps(upstream_nodes, indent=2, ensure_ascii=False)
    param_type = parameter_info.get("type", "string")
    # Callers read the node with .get("data", {}); do the same here so a node
    # without "data" does not raise while formatting the prompt.
    node_title = (current_node.get("data") or {}).get("title", current_node["id"])
    return f"""You are a code generator for workflow automation.

Generate {language} code to extract/transform upstream node outputs for the target parameter.

## Upstream Nodes
{upstream_json}

## Target
Node: {node_title}
Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("description", "")}

## Requirements
- Write a main function that returns type: {param_type}
- Use value_selector format: ["node_id", "output_name"]
"""
|
||||
|
||||
@classmethod
|
||||
def _parse_code_node_output(cls, content: Mapping[str, Any] | None, language: str, parameter_type: str) -> dict:
|
||||
"""
|
||||
Parse structured output to CodeNodeData format.
|
||||
|
||||
Args:
|
||||
content: Structured output dict from invoke_llm_with_structured_output
|
||||
language: Code language
|
||||
parameter_type: Expected parameter type
|
||||
|
||||
Returns dict with variables, code_language, code, outputs, message, error.
|
||||
"""
|
||||
if content is None:
|
||||
return cls._error_response("Empty or invalid response from LLM")
|
||||
|
||||
# Validate and normalize variables
|
||||
variables = [
|
||||
{"variable": v.get("variable", ""), "value_selector": v.get("value_selector", [])}
|
||||
for v in content.get("variables", [])
|
||||
if isinstance(v, dict)
|
||||
]
|
||||
|
||||
outputs = content.get("outputs", {"result": {"type": parameter_type}})
|
||||
|
||||
return {
|
||||
"variables": variables,
|
||||
"code_language": language,
|
||||
"code": content.get("code", ""),
|
||||
"outputs": outputs,
|
||||
"message": content.get("explanation", ""),
|
||||
"error": "",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def instruction_modify_legacy(
|
||||
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
|
||||
|
||||
45
api/core/llm_generator/utils.py
Normal file
45
api/core/llm_generator/utils.py
Normal file
@ -0,0 +1,45 @@
|
||||
"""Utility functions for LLM generator."""
|
||||
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
|
||||
def deserialize_prompt_messages(messages: list[dict]) -> list[PromptMessage]:
    """
    Turn plain role/content dicts into the matching PromptMessage objects.

    Expected format:
    [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."},
    ]

    "content" defaults to an empty string. Tool messages also read an
    optional "tool_call_id" (default ""). A missing "role" raises KeyError.
    """
    deserialized: list[PromptMessage] = []
    for raw in messages:
        role = PromptMessageRole.value_of(raw["role"])
        content = raw.get("content", "")

        if role == PromptMessageRole.USER:
            deserialized.append(UserPromptMessage(content=content))
        elif role == PromptMessageRole.ASSISTANT:
            deserialized.append(AssistantPromptMessage(content=content))
        elif role == PromptMessageRole.SYSTEM:
            deserialized.append(SystemPromptMessage(content=content))
        elif role == PromptMessageRole.TOOL:
            tool_call_id = raw.get("tool_call_id", "")
            deserialized.append(ToolPromptMessage(content=content, tool_call_id=tool_call_id))

    return deserialized
|
||||
|
||||
|
||||
def serialize_prompt_messages(messages: list[PromptMessage]) -> list[dict]:
    """
    Serialize list[PromptMessage] to list of dicts.

    Each dict carries "role" (the role's string value) and "content". Tool
    messages additionally carry "tool_call_id", the field
    deserialize_prompt_messages reads back, so that a tool message survives
    a serialize/deserialize round trip.
    """
    serialized: list[dict] = []
    for msg in messages:
        item: dict = {"role": msg.role.value, "content": msg.content}
        if isinstance(msg, ToolPromptMessage):
            item["tool_call_id"] = msg.tool_call_id
        serialized.append(item)
    return serialized
|
||||
Reference in New Issue
Block a user