feat(workflow): enhance group node functionality with head and leaf node tracking

- Added headNodeIds and leafNodeIds to GroupNodeData to track nodes that receive input and send output outside the group.
- Updated useNodesInteractions hook to include headNodeIds in the group node data.
- Modified isValidConnection logic in useWorkflow to validate connections based on leaf node types for group nodes.
- Enhanced preprocessNodesAndEdges to rebuild temporary edges for group nodes, connecting them to external nodes for visual representation.
This commit is contained in:
zhsama
2026-01-04 20:45:42 +08:00
parent 39010fd153
commit 8834e6e531
4 changed files with 91 additions and 10 deletions

View File

@ -4,7 +4,6 @@ import type {
import type { IterationNodeType } from '../nodes/iteration/types'
import type { LoopNodeType } from '../nodes/loop/types'
import type {
BlockEnum,
Edge,
Node,
ValueSelector,
@ -28,14 +27,12 @@ import {
} from '../constants'
import { findUsedVarNodes, getNodeOutputVars, updateNodeVars } from '../nodes/_base/components/variable/utils'
import { CUSTOM_NOTE_NODE } from '../note-node/constants'
import {
useStore,
useWorkflowStore,
} from '../store'
import {
WorkflowRunningStatus,
} from '../types'
import { BlockEnum, WorkflowRunningStatus } from '../types'
import {
getWorkflowEntryNode,
isWorkflowEntryNode,
@ -381,7 +378,7 @@ export const useWorkflow = () => {
return startNodes
}, [nodesMap, getRootNodesById])
const isValidConnection = useCallback(({ source, sourceHandle: _sourceHandle, target }: Connection) => {
const isValidConnection = useCallback(({ source, sourceHandle, target }: Connection) => {
const {
edges,
getNodes,
@ -396,14 +393,27 @@ export const useWorkflow = () => {
if (sourceNode.parentId !== targetNode.parentId)
return false
// For Group nodes, use the leaf node's type for validation
// sourceHandle format: "${leafNodeId}-${originalSourceHandle}"
let actualSourceType = sourceNode.data.type
if (sourceNode.data.type === BlockEnum.Group && sourceHandle) {
const lastDashIndex = sourceHandle.lastIndexOf('-')
if (lastDashIndex > 0) {
const leafNodeId = sourceHandle.substring(0, lastDashIndex)
const leafNode = nodes.find(node => node.id === leafNodeId)
if (leafNode)
actualSourceType = leafNode.data.type
}
}
if (sourceNode && targetNode) {
const sourceNodeAvailableNextNodes = getAvailableBlocks(sourceNode.data.type, !!sourceNode.parentId).availableNextBlocks
const sourceNodeAvailableNextNodes = getAvailableBlocks(actualSourceType, !!sourceNode.parentId).availableNextBlocks
const targetNodeAvailablePrevNodes = getAvailableBlocks(targetNode.data.type, !!targetNode.parentId).availablePrevBlocks
if (!sourceNodeAvailableNextNodes.includes(targetNode.data.type))
return false
if (!targetNodeAvailablePrevNodes.includes(sourceNode.data.type))
if (!targetNodeAvailablePrevNodes.includes(actualSourceType))
return false
}