Compare commits

..

3 Commits

7 changed files with 25 additions and 77 deletions

View File

@ -14,6 +14,7 @@ from torchvision import transforms
import comfy.patcher_extension
from comfy.ldm.modules.attention import optimized_attention
import comfy.ldm.common_dit
import comfy.quant_ops
# ---------------------- Feed Forward Network -----------------------

View File

@ -5,6 +5,7 @@ import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
import comfy.quant_ops
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0
@ -19,15 +20,6 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
out = torch.stack([torch.cos(out), torch.sin(out)], dim=0)
return out.to(dtype=torch.float32, device=pos.device)
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
rot_dim = freqs_cis.shape[-1]
x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
cos_ = freqs_cis[0]
sin_ = freqs_cis[1]
x1, x2 = x.chunk(2, dim=-1)
x_rotated = torch.cat((-x2, x1), dim=-1)
return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
class ErnieImageEmbedND3(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: tuple):
super().__init__()
@ -37,8 +29,16 @@ class ErnieImageEmbedND3(nn.Module):
def forward(self, ids: torch.Tensor) -> torch.Tensor:
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
emb = emb.unsqueeze(3) # [2, B, S, 1, head_dim//2]
return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim]
cos_ = emb[0]
sin_ = emb[1]
N = cos_.shape[-1]
half = N // 2
cos_top = cos_[..., :half].repeat_interleave(2, dim=-1)
sin_top = sin_[..., :half].repeat_interleave(2, dim=-1)
cos_bot = cos_[..., half:].repeat_interleave(2, dim=-1)
sin_bot = sin_[..., half:].repeat_interleave(2, dim=-1)
rot = torch.stack([cos_top, -sin_top, sin_bot, cos_bot], dim=-1)
return rot.reshape(*rot.shape[:-1], 2, 2).unsqueeze(2)
class ErnieImagePatchEmbedDynamic(nn.Module):
def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None):
@ -115,8 +115,7 @@ class ErnieImageAttention(nn.Module):
key = self.norm_k(key)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query, key = comfy.quant_ops.ck.apply_rope_split_half(query, key, image_rotary_emb)
q_flat = query.reshape(B, S, -1)
k_flat = key.reshape(B, S, -1)
@ -274,7 +273,7 @@ class ErnieImageModel(nn.Module):
image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1)
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype)
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))
del image_ids, text_ids
sample = self.time_proj(timesteps).to(dtype)

View File

@ -1716,6 +1716,13 @@ def is_device_xpu(device):
def is_device_cuda(device):
return is_device_type(device, 'cuda')
def set_torch_device(device):
"""Set the current device for the given torch device. Supports CUDA and XPU."""
if is_device_cuda(device):
torch.cuda.set_device(device)
elif is_device_xpu(device):
torch.xpu.set_device(device)
def is_directml_enabled():
global directml_enabled
if directml_enabled:

View File

@ -17,7 +17,7 @@ class MultiGPUThreadPool:
"""Persistent thread pool for multi-GPU work distribution.
Maintains one worker thread per extra GPU device. Each thread calls
torch.cuda.set_device() once at startup so that compiled kernel caches
set_torch_device() once at startup so that compiled kernel caches
(inductor/triton) stay warm across diffusion steps.
"""
@ -37,7 +37,7 @@ class MultiGPUThreadPool:
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
try:
torch.cuda.set_device(device)
comfy.model_management.set_torch_device(device)
except Exception as e:
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
while True:

View File

@ -464,10 +464,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
try:
# TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
# we extend multigpu QA beyond CUDA. Unconditional call crashes on
# XPU/NPU/MPS/CPU/DirectML backends.
torch.cuda.set_device(device)
comfy.model_management.set_torch_device(device)
model_current: BaseModel = model_options["multigpu_clones"][device].model
# run every hooked_to_run separately
with torch.no_grad():

View File

@ -182,61 +182,6 @@ class Preview3DAdvanced(IO.ComfyNode):
)
class PreviewPointCloudGaussianSplat(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="PreviewPointCloudGaussianSplat",
display_name="Preview Point Cloud & Gaussian Splat",
category="3d",
is_experimental=True,
is_output_node=True,
search_aliases=[
"view 3d",
"preview 3d",
"3d viewer",
"view point cloud",
"view pointcloud",
"view splat",
"view gaussian",
"view gaussian splat",
"preview gaussian",
"preview gaussian splat",
"preview point cloud",
"preview pointcloud",
"view 3dgs",
"preview 3dgs",
"preview ply",
"preview spz",
"preview ksplat",
],
inputs=[
IO.MultiType.Input(
"model_file",
types=[
IO.File3DPLY,
IO.File3DSPLAT,
IO.File3DSPZ,
IO.File3DKSPLAT,
],
tooltip="Point cloud or 3DGS file (.ply / .spz / .splat / .ksplat)",
),
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
],
outputs=[
IO.File3DAny.Output(display_name="model_file"),
IO.Load3DCamera.Output(display_name="camera_info"),
],
)
@classmethod
def execute(cls, model_file: Types.File3D, **kwargs) -> IO.NodeOutput:
filename = f"preview3d_{uuid.uuid4().hex}.{model_file.format}"
model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename))
camera_info = kwargs.get("camera_info", None)
return IO.NodeOutput(model_file, camera_info, ui=UI.PreviewUI3D(filename, camera_info))
class Load3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -244,7 +189,6 @@ class Load3DExtension(ComfyExtension):
Load3D,
Preview3D,
Preview3DAdvanced,
PreviewPointCloudGaussianSplat,
]

View File

@ -1,6 +1,6 @@
comfyui-frontend-package==1.44.19
comfyui-workflow-templates==0.9.91
comfyui-embedded-docs==0.5.1
comfyui-embedded-docs==0.5.2
torch
torchsde
torchvision