mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-31 21:25:59 +08:00
Compare commits
3 Commits
feat/point
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| cd45f42a83 | |||
| 81aa5a38b2 | |||
| ea73d3b2ea |
@ -14,6 +14,7 @@ from torchvision import transforms
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.quant_ops
|
||||
|
||||
|
||||
# ---------------------- Feed Forward Network -----------------------
|
||||
|
||||
@ -5,6 +5,7 @@ import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
import comfy.quant_ops
|
||||
|
||||
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0
|
||||
@ -19,15 +20,6 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
out = torch.stack([torch.cos(out), torch.sin(out)], dim=0)
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
rot_dim = freqs_cis.shape[-1]
|
||||
x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
|
||||
cos_ = freqs_cis[0]
|
||||
sin_ = freqs_cis[1]
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
x_rotated = torch.cat((-x2, x1), dim=-1)
|
||||
return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
|
||||
|
||||
class ErnieImageEmbedND3(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: tuple):
|
||||
super().__init__()
|
||||
@ -37,8 +29,16 @@ class ErnieImageEmbedND3(nn.Module):
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
|
||||
emb = emb.unsqueeze(3) # [2, B, S, 1, head_dim//2]
|
||||
return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim]
|
||||
cos_ = emb[0]
|
||||
sin_ = emb[1]
|
||||
N = cos_.shape[-1]
|
||||
half = N // 2
|
||||
cos_top = cos_[..., :half].repeat_interleave(2, dim=-1)
|
||||
sin_top = sin_[..., :half].repeat_interleave(2, dim=-1)
|
||||
cos_bot = cos_[..., half:].repeat_interleave(2, dim=-1)
|
||||
sin_bot = sin_[..., half:].repeat_interleave(2, dim=-1)
|
||||
rot = torch.stack([cos_top, -sin_top, sin_bot, cos_bot], dim=-1)
|
||||
return rot.reshape(*rot.shape[:-1], 2, 2).unsqueeze(2)
|
||||
|
||||
class ErnieImagePatchEmbedDynamic(nn.Module):
|
||||
def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None):
|
||||
@ -115,8 +115,7 @@ class ErnieImageAttention(nn.Module):
|
||||
key = self.norm_k(key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
query, key = comfy.quant_ops.ck.apply_rope_split_half(query, key, image_rotary_emb)
|
||||
|
||||
q_flat = query.reshape(B, S, -1)
|
||||
k_flat = key.reshape(B, S, -1)
|
||||
@ -274,7 +273,7 @@ class ErnieImageModel(nn.Module):
|
||||
|
||||
image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1)
|
||||
|
||||
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype)
|
||||
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))
|
||||
del image_ids, text_ids
|
||||
|
||||
sample = self.time_proj(timesteps).to(dtype)
|
||||
|
||||
@ -1716,6 +1716,13 @@ def is_device_xpu(device):
|
||||
def is_device_cuda(device):
|
||||
return is_device_type(device, 'cuda')
|
||||
|
||||
def set_torch_device(device):
|
||||
"""Set the current device for the given torch device. Supports CUDA and XPU."""
|
||||
if is_device_cuda(device):
|
||||
torch.cuda.set_device(device)
|
||||
elif is_device_xpu(device):
|
||||
torch.xpu.set_device(device)
|
||||
|
||||
def is_directml_enabled():
|
||||
global directml_enabled
|
||||
if directml_enabled:
|
||||
|
||||
@ -17,7 +17,7 @@ class MultiGPUThreadPool:
|
||||
"""Persistent thread pool for multi-GPU work distribution.
|
||||
|
||||
Maintains one worker thread per extra GPU device. Each thread calls
|
||||
torch.cuda.set_device() once at startup so that compiled kernel caches
|
||||
set_torch_device() once at startup so that compiled kernel caches
|
||||
(inductor/triton) stay warm across diffusion steps.
|
||||
"""
|
||||
|
||||
@ -37,7 +37,7 @@ class MultiGPUThreadPool:
|
||||
|
||||
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
|
||||
try:
|
||||
torch.cuda.set_device(device)
|
||||
comfy.model_management.set_torch_device(device)
|
||||
except Exception as e:
|
||||
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
|
||||
while True:
|
||||
|
||||
@ -464,10 +464,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
|
||||
|
||||
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
|
||||
try:
|
||||
# TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
|
||||
# we extend multigpu QA beyond CUDA. Unconditional call crashes on
|
||||
# XPU/NPU/MPS/CPU/DirectML backends.
|
||||
torch.cuda.set_device(device)
|
||||
comfy.model_management.set_torch_device(device)
|
||||
model_current: BaseModel = model_options["multigpu_clones"][device].model
|
||||
# run every hooked_to_run separately
|
||||
with torch.no_grad():
|
||||
|
||||
@ -182,61 +182,6 @@ class Preview3DAdvanced(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
class PreviewPointCloudGaussianSplat(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="PreviewPointCloudGaussianSplat",
|
||||
display_name="Preview Point Cloud & Gaussian Splat",
|
||||
category="3d",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
search_aliases=[
|
||||
"view 3d",
|
||||
"preview 3d",
|
||||
"3d viewer",
|
||||
"view point cloud",
|
||||
"view pointcloud",
|
||||
"view splat",
|
||||
"view gaussian",
|
||||
"view gaussian splat",
|
||||
"preview gaussian",
|
||||
"preview gaussian splat",
|
||||
"preview point cloud",
|
||||
"preview pointcloud",
|
||||
"view 3dgs",
|
||||
"preview 3dgs",
|
||||
"preview ply",
|
||||
"preview spz",
|
||||
"preview ksplat",
|
||||
],
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
"model_file",
|
||||
types=[
|
||||
IO.File3DPLY,
|
||||
IO.File3DSPLAT,
|
||||
IO.File3DSPZ,
|
||||
IO.File3DKSPLAT,
|
||||
],
|
||||
tooltip="Point cloud or 3DGS file (.ply / .spz / .splat / .ksplat)",
|
||||
),
|
||||
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.File3DAny.Output(display_name="model_file"),
|
||||
IO.Load3DCamera.Output(display_name="camera_info"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_file: Types.File3D, **kwargs) -> IO.NodeOutput:
|
||||
filename = f"preview3d_{uuid.uuid4().hex}.{model_file.format}"
|
||||
model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename))
|
||||
camera_info = kwargs.get("camera_info", None)
|
||||
return IO.NodeOutput(model_file, camera_info, ui=UI.PreviewUI3D(filename, camera_info))
|
||||
|
||||
|
||||
class Load3DExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -244,7 +189,6 @@ class Load3DExtension(ComfyExtension):
|
||||
Load3D,
|
||||
Preview3D,
|
||||
Preview3DAdvanced,
|
||||
PreviewPointCloudGaussianSplat,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.44.19
|
||||
comfyui-workflow-templates==0.9.91
|
||||
comfyui-embedded-docs==0.5.1
|
||||
comfyui-embedded-docs==0.5.2
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
|
||||
Reference in New Issue
Block a user