Simplify Hunyuan 3D 2.1 swap_cfg_halves gate to a shape check

The previous gate (len(cond_or_uncond) == 2 and
set(cond_or_uncond) == {0, 1}) was intended to skip the cond/uncond
swap when only one half was present under MultiGPU CFG Split, but it
was too restrictive: it also skipped the batch_size > 1 with CFG case
(cond_or_uncond like [0, 0, 1, 1] or [0, 0, 0, 0, 1, 1, 1, 1]), where
chunk(2) still splits the batch cleanly into a cond half and an uncond
half and the swap is still required.

Switch to context.shape[0] >= 2, matching the parallel fix landed on
master in #13699. When the two halves don't form a CFG pair, the swap
is effectively a no-op, since the swap_cfg_halves block on the output
immediately undoes the permutation. The only thing the gate actually
needs to do is guard against chunk(2) on a batch of one.
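
A rough sketch of both points, using plain Python lists instead of
tensors. chunk2 and swap_halves are hypothetical stand-ins written for
this note (chunk2 mimics how chunk(2) splits along dim 0), not code
from the model, and the round trip is only shown for an even-sized
batch:

```python
def chunk2(rows):
    """Hypothetical list-based stand-in for tensor.chunk(2, dim=0)."""
    size = -(-len(rows) // 2)  # ceil(len(rows) / 2)
    return [rows[i:i + size] for i in range(0, len(rows), size)]


def swap_halves(rows):
    # Unpacking into two names fails when chunk2 yields a single piece.
    first_half, second_half = chunk2(rows)
    return second_half + first_half


batch = ["c0", "c1", "u0", "u1"]
swapped = swap_halves(batch)      # swap applied on the way in
restored = swap_halves(swapped)   # swap applied again on the way out

try:
    swap_halves(["c0"])           # batch of one: only one chunk comes back
    single_row_ok = True
except ValueError:
    single_row_ok = False

print(swapped, restored, single_row_ok)
```

Swapping twice restores the original row order, while a single row
cannot be unpacked into two halves, which is the case the gate guards.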

Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082
Co-authored-by: Amp <amp@ampcode.com>
Kosinkadink
2026-05-21 12:14:02 -07:00
parent 822a3ecf73
commit 019261ed96

@@ -608,8 +608,7 @@ class HunYuanDiTPlain(nn.Module):
 x = x.movedim(-1, -2)
-cond_or_uncond = transformer_options.get("cond_or_uncond", [])
-swap_cfg_halves = len(cond_or_uncond) == 2 and set(cond_or_uncond) == {0, 1}
+swap_cfg_halves = context.shape[0] >= 2
 if swap_cfg_halves:
     first_half, second_half = context.chunk(2, dim = 0)