mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-21 00:46:41 +08:00
Compare commits
15 Commits
v0.21.1
...
hiring-lin
| Author | SHA1 | Date | |
|---|---|---|---|
| 2288116447 | |||
| 5d5a4554e1 | |||
| 33ce449c8b | |||
| 04856acc69 | |||
| 77e2ed5e01 | |||
| b2000029c8 | |||
| b112f68681 | |||
| ed78da062c | |||
| 616cab4f97 | |||
| 4f6018982d | |||
| 7a063e83a7 | |||
| 3f9bdc70ee | |||
| 3d870ff51f | |||
| 1f28908d6e | |||
| fb51a988b6 |
@ -429,6 +429,8 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
|
||||
|
||||
See also: [https://www.comfy.org/](https://www.comfy.org/)
|
||||
|
||||
> _psst — we're hiring!_ Help build ComfyUI: [comfy.org/careers](https://www.comfy.org/careers)
|
||||
|
||||
## Frontend Development
|
||||
|
||||
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
|
||||
|
||||
44
SECURITY.md
Normal file
44
SECURITY.md
Normal file
@ -0,0 +1,44 @@
|
||||
# Security Policy
|
||||
|
||||
## Scope
|
||||
|
||||
ComfyUI is designed to run locally. By default, the server binds to `127.0.0.1`, meaning only the user's own machine can reach it. Our threat model assumes:
|
||||
|
||||
- The user installed ComfyUI through a supported channel: the desktop application, the portable build, or a manual install following the README.
|
||||
- The user has not installed untrusted custom nodes. Custom nodes are arbitrary Python code and are trusted as much as any other software the user chooses to install.
|
||||
- Anyone with access to the ComfyUI URL is trusted (a direct consequence of the localhost-only default).
|
||||
- PyTorch and other dependencies are at the versions we ship or recommend in the README.
|
||||
|
||||
A report is in scope only if it affects a user operating within this threat model.
|
||||
|
||||
## What We Consider a Vulnerability
|
||||
|
||||
We want to hear about issues where a **reasonable user** — someone who does not install random untrusted nodes and who reads UI prompts and warnings before clicking through them — can be harmed by ComfyUI itself.
|
||||
|
||||
The clearest example: a workflow file that such a user might plausibly load and run, using only built-in nodes, that results in **untrusted code execution, arbitrary file read/write outside expected directories, or credential/data exfiltration**.
|
||||
|
||||
When submitting a report, please include a clear description of *why this is a problem for a typical local ComfyUI user*. Reports without this context are difficult to act on.
|
||||
|
||||
## What We Do Not Consider a Security Vulnerability
|
||||
|
||||
Please report the following through our regular [GitHub issues](https://github.com/comfyanonymous/ComfyUI/issues) instead. Filing them as security reports will likely cause them to be deprioritized or closed.
|
||||
|
||||
- **Issues requiring `--listen` or any non-default network exposure.** ComfyUI binds to localhost by default. If a remote attacker needs to reach the server for the attack to work, the user has chosen to expose it and is responsible for securing that deployment (firewall, reverse proxy, authentication, etc.). These are bugs, not vulnerabilities.
|
||||
- **`torch.load` and related deserialization issues in old PyTorch versions.** These are upstream PyTorch issues. Our distributions ship with — and our documentation recommends — recent PyTorch versions where these are addressed.
|
||||
- **Vulnerabilities that depend on outdated library versions** that we neither ship nor recommend (e.g., an issue that can only be reproduced on PyTorch 2.6 or older).
|
||||
- **Issues that require a specific custom node to be installed.** Custom nodes are third-party code. Report these to the maintainer of that node.
|
||||
- **Crashes, hangs, or resource exhaustion from a loaded workflow.** Annoying, but not a security issue in our model. File a regular bug.
|
||||
- **Social-engineering scenarios** where the user is expected to ignore an explicit UI warning or prompt.
|
||||
|
||||
## Reporting
|
||||
|
||||
If you believe you have found an issue that falls within the scope above, please report it privately via GitHub's [Report a vulnerability](https://github.com/comfyanonymous/ComfyUI/security/advisories/new) feature rather than opening a public issue.
|
||||
|
||||
Please include:
|
||||
|
||||
1. A description of the vulnerability and the affected component.
|
||||
2. Reproduction steps, ideally with a minimal workflow file or proof-of-concept.
|
||||
3. The ComfyUI version, install method (desktop / portable / manual), and OS.
|
||||
4. An explanation of how this affects a typical local user as described in the threat model.
|
||||
|
||||
We will acknowledge valid reports and coordinate a fix and disclosure timeline with you.
|
||||
@ -38,40 +38,54 @@ def is_valid_version(version: str) -> bool:
|
||||
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
|
||||
return bool(re.match(pattern, version))
|
||||
|
||||
def get_installed_frontend_version():
    """Return the version string of the installed comfyui-frontend-package."""
    return version("comfyui-frontend-package")
|
||||
|
||||
|
||||
def get_required_frontend_version():
    """Return the required comfyui-frontend-package version, or None when it is not listed.

    A falsy result from get_required_packages_versions() is treated as "no
    requirements known" (the same guard get_comfy_package_versions applies),
    so this returns None instead of failing on the lookup.
    """
    required_versions = get_required_packages_versions() or {}
    return required_versions.get("comfyui-frontend-package")
|
||||
|
||||
|
||||
def check_frontend_version():
|
||||
"""Check if the frontend version is up to date."""
|
||||
# Module-level cache of the records built by get_comfy_package_versions(); filled on first use.
COMFY_PACKAGE_VERSIONS = []


def get_comfy_package_versions():
    """List installed/required versions for every comfy* package in requirements.txt.

    Returns:
        A list of ``{"name", "installed", "required"}`` dicts, one per required
        package whose name starts with "comfy". ``installed`` is None when the
        installed version could not be determined. The list and its dicts are
        fresh copies, so callers may modify them without affecting the cache.
    """
    if not COMFY_PACKAGE_VERSIONS:
        records = []
        for name, required in (get_required_packages_versions() or {}).items():
            if not name.startswith("comfy"):
                continue
            try:
                installed = version(name)
            except Exception:
                # Best effort: a package whose version lookup fails is reported as not installed.
                installed = None
            records.append({"name": name, "installed": installed, "required": required})
        # Publish in a single step (in place, so existing references to the
        # list stay valid) to avoid ever exposing a half-built cache.
        COMFY_PACKAGE_VERSIONS.extend(records)
    # Copy each record, not just the list: a shallow list copy would let a
    # caller's mutation of a record leak into the cached data.
    return [dict(record) for record in COMFY_PACKAGE_VERSIONS]
|
||||
|
||||
try:
|
||||
frontend_version_str = get_installed_frontend_version()
|
||||
frontend_version = parse_version(frontend_version_str)
|
||||
required_frontend_str = get_required_frontend_version()
|
||||
required_frontend = parse_version(required_frontend_str)
|
||||
if frontend_version < required_frontend:
|
||||
|
||||
def check_comfy_packages_versions():
|
||||
"""Warn for every comfy* package whose installed version is below requirements.txt."""
|
||||
from packaging.version import InvalidVersion, parse as parse_pep440
|
||||
for pkg in get_comfy_package_versions():
|
||||
installed_str = pkg["installed"]
|
||||
required_str = pkg["required"]
|
||||
if not installed_str or not required_str:
|
||||
continue
|
||||
try:
|
||||
outdated = parse_pep440(installed_str) < parse_pep440(required_str)
|
||||
except InvalidVersion as e:
|
||||
logging.error(f"Failed to check {pkg['name']} version: {e}")
|
||||
continue
|
||||
if outdated:
|
||||
app.logger.log_startup_warning(
|
||||
f"""
|
||||
________________________________________________________________________
|
||||
WARNING WARNING WARNING WARNING WARNING
|
||||
|
||||
Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}.
|
||||
Installed {pkg["name"]} version {installed_str} is lower than the recommended version {required_str}.
|
||||
|
||||
{frontend_install_warning_message()}
|
||||
{get_missing_requirements_message()}
|
||||
________________________________________________________________________
|
||||
""".strip()
|
||||
)
|
||||
else:
|
||||
logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to check frontend version: {e}")
|
||||
logging.info("{} version: {}".format(pkg["name"], installed_str))
|
||||
|
||||
|
||||
REQUEST_TIMEOUT = 10 # seconds
|
||||
@ -201,6 +215,11 @@ class FrontendManager:
|
||||
def get_required_templates_version(cls) -> str:
|
||||
return get_required_packages_versions().get("comfyui-workflow-templates", None)
|
||||
|
||||
@classmethod
|
||||
def get_comfy_package_versions(cls):
|
||||
"""List installed/required versions for every comfy* package in requirements.txt."""
|
||||
return get_comfy_package_versions()
|
||||
|
||||
@classmethod
|
||||
def default_frontend_path(cls) -> str:
|
||||
try:
|
||||
@ -341,7 +360,7 @@ comfyui-workflow-templates is not installed.
|
||||
main error source might be request timeout or invalid URL.
|
||||
"""
|
||||
if version_string == DEFAULT_VERSION_STRING:
|
||||
check_frontend_version()
|
||||
check_comfy_packages_versions()
|
||||
return cls.default_frontend_path()
|
||||
|
||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||
@ -403,7 +422,7 @@ comfyui-workflow-templates is not installed.
|
||||
except Exception as e:
|
||||
logging.error("Failed to initialize frontend: %s", e)
|
||||
logging.info("Falling back to the default frontend.")
|
||||
check_frontend_version()
|
||||
check_comfy_packages_versions()
|
||||
return cls.default_frontend_path()
|
||||
@classmethod
|
||||
def template_asset_handler(cls):
|
||||
|
||||
@ -141,8 +141,7 @@ manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", he
|
||||
vram_group = parser.add_mutually_exclusive_group()
|
||||
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
||||
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
||||
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
|
||||
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
|
||||
vram_group.add_argument("--lowvram", action="store_true", help="Doesn't do anything if dynamic vram is enabled. If dynamic vram isn't being used this option makes the text encoders run on the CPU.")
|
||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||
|
||||
|
||||
@ -106,6 +106,7 @@ class Dino2Encoder(torch.nn.Module):
|
||||
class Dino2PatchEmbeddings(torch.nn.Module):
|
||||
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.projection = operations.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=dim,
|
||||
@ -125,17 +126,37 @@ class Dino2Embeddings(torch.nn.Module):
|
||||
super().__init__()
|
||||
patch_size = 14
|
||||
image_size = 518
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
|
||||
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
|
||||
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
|
||||
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) # mask_token is a pre-training param, kept only so strict loading accepts the key.
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
|
||||
def interpolate_pos_encoding(self, x, h_pixels, w_pixels):
    """Resample the position embeddings to the patch grid of an h_pixels x w_pixels input.

    The class-token embedding (index 0) is kept as is; the remaining patch
    embeddings are reshaped to a square grid, bicubically interpolated in
    float32 and flattened back. The result is cast to x's dtype.
    """
    embeddings = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32)
    cls_embedding, grid_embedding = embeddings[:, 0:1], embeddings[:, 1:]

    # The stored patch embeddings are reshaped below as a side x side grid.
    side = int(grid_embedding.shape[1] ** 0.5)
    rows = h_pixels // self.patch_size
    cols = w_pixels // self.patch_size
    # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (rows, cols).
    scale = ((rows + 0.1) / side, (cols + 0.1) / side)

    grid = grid_embedding.reshape(1, side, side, -1).permute(0, 3, 1, 2)
    grid = torch.nn.functional.interpolate(grid, scale_factor=scale, mode="bicubic", antialias=False)
    tokens = grid.permute(0, 2, 3, 1).flatten(1, 2)
    return torch.cat((cls_embedding, tokens), dim=1).to(x.dtype)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
x = self.patch_embeddings(pixel_values)
|
||||
# TODO: mask_token?
|
||||
x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
|
||||
if x.shape[1] - 1 == self.position_embeddings.shape[1] - 1:
|
||||
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
|
||||
else:
|
||||
h, w = pixel_values.shape[-2:]
|
||||
x = x + self.interpolate_pos_encoding(x, h, w)
|
||||
return x
|
||||
|
||||
|
||||
@ -158,3 +179,21 @@ class Dinov2Model(torch.nn.Module):
|
||||
x = self.layernorm(x)
|
||||
pooled_output = x[:, 0, :]
|
||||
return x, i, pooled_output, None
|
||||
|
||||
def get_intermediate_layers(self, pixel_values, indices, apply_norm=True):
    """Run the encoder and collect the outputs of the requested layers.

    Args:
        pixel_values: input passed to ``self.embeddings``.
        indices: layer indices to collect; negative values count back from
            the last encoder layer.
        apply_norm: when True, ``self.layernorm`` is applied to each
            collected layer output.

    Returns:
        One ``(tokens_without_first, first_token)`` tuple per entry of
        ``indices``, in the same order. An empty ``indices`` yields ``[]``.

    Raises:
        IndexError: if an index falls outside the encoder's layer range.
    """
    indices = list(indices)
    if not indices:
        # Nothing requested: skip the forward pass (max() below cannot take an empty list).
        return []

    n_layers = len(self.encoder.layer)
    resolved = [(i if i >= 0 else n_layers + i) for i in indices]
    # Validate up front so a bad index fails before any compute is spent,
    # rather than as a KeyError once the encoder has already run.
    for requested, idx in zip(indices, resolved):
        if not 0 <= idx < n_layers:
            raise IndexError(f"layer index {requested} is out of range for an encoder with {n_layers} layers")

    x = self.embeddings(pixel_values)
    optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
    target = set(resolved)
    max_idx = max(resolved)
    n_skip = 1  # skip cls token
    cache = {}
    for i, layer in enumerate(self.encoder.layer):
        x = layer(x, optimized_attention)
        if i in target:
            normed = self.layernorm(x) if apply_norm else x
            cache[i] = (normed[:, n_skip:], normed[:, 0])
        if i >= max_idx:
            # No later layer was requested, so the remaining layers are not run.
            break
    return [cache[i] for i in resolved]
|
||||
|
||||
@ -22,26 +22,25 @@ class CompressedTimestep:
|
||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int, per_frame: bool = False):
|
||||
"""
|
||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
|
||||
tensor: [batch, num_tokens, feature_dim] (per-token, default) or
|
||||
[batch, num_frames, feature_dim] (per_frame=True, already compressed).
|
||||
patches_per_frame: spatial patches per frame; pass None to disable compression.
|
||||
"""
|
||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
||||
|
||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
||||
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
self.batch_size, n, self.feature_dim = tensor.shape
|
||||
if per_frame:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = num_tokens // patches_per_frame
|
||||
|
||||
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
|
||||
# All patches in a frame are identical, so we only keep the first one
|
||||
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
|
||||
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
|
||||
self.num_frames = n
|
||||
self.data = tensor
|
||||
elif patches_per_frame is not None and n >= patches_per_frame and n % patches_per_frame == 0:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = n // patches_per_frame
|
||||
# All patches in a frame are identical — keep only the first.
|
||||
self.data = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)[:, :, 0, :].contiguous()
|
||||
else:
|
||||
# Not divisible or too small - store directly without compression
|
||||
self.patches_per_frame = 1
|
||||
self.num_frames = num_tokens
|
||||
self.num_frames = n
|
||||
self.data = tensor
|
||||
|
||||
def expand(self):
|
||||
@ -716,32 +715,35 @@ class LTXAVModel(LTXVModel):
|
||||
|
||||
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||
"""Prepare timestep embeddings."""
|
||||
# TODO: some code reuse is needed here.
|
||||
grid_mask = kwargs.get("grid_mask", None)
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||
|
||||
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||
timestep_scaled.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
||||
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
||||
orig_shape = kwargs.get("orig_shape")
|
||||
has_spatial_mask = kwargs.get("has_spatial_mask", None)
|
||||
v_patches_per_frame = None
|
||||
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
|
||||
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||
|
||||
# Reshape to [batch_size, num_tokens, dim] and compress for storage
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
||||
# Used by compute_prompt_timestep and the audio cross-attention paths.
|
||||
timestep_scaled = (timestep[:, grid_mask] if grid_mask is not None else timestep) * self.timestep_scale_multiplier
|
||||
|
||||
# When patches in a frame share a timestep (no spatial mask), project one row per frame instead of one per token
|
||||
per_frame_path = v_patches_per_frame is not None and (timestep.numel() // batch_size) % v_patches_per_frame == 0
|
||||
if per_frame_path:
|
||||
per_frame = timestep.reshape(batch_size, -1, v_patches_per_frame)[:, :, 0]
|
||||
if grid_mask is not None:
|
||||
# All-or-nothing per frame when has_spatial_mask=False.
|
||||
per_frame = per_frame[:, grid_mask[::v_patches_per_frame]]
|
||||
ts_input = per_frame * self.timestep_scale_multiplier
|
||||
else:
|
||||
ts_input = timestep_scaled
|
||||
|
||||
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||
ts_input.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
|
||||
|
||||
v_prompt_timestep = compute_prompt_timestep(
|
||||
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||
|
||||
@ -358,6 +358,61 @@ def apply_split_rotary_emb(input_tensor, cos, sin):
|
||||
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
|
||||
|
||||
|
||||
class GuideAttentionMask:
    """Pair of additive log-space masks used for LTXV guide self-attention.

    _attention_with_guide_mask handles noisy queries and tracked-guide queries
    separately, so only a (1, 1, 1, T) mask and a (1, 1, tracked_count, T)
    mask are stored.
    """
    __slots__ = ("guide_start", "tracked_count", "noisy_mask", "tracked_mask")

    def __init__(self, total_tokens, guide_start, tracked_count, tracked_weights):
        limits = torch.finfo(tracked_weights.dtype)

        # log(weight) for positive weights; the dtype's most negative value otherwise.
        log_bias = torch.where(
            tracked_weights > 0,
            torch.log(tracked_weights.clamp(min=limits.tiny)),
            torch.full_like(tracked_weights, limits.min),
        )

        self.guide_start = guide_start
        self.tracked_count = tracked_count

        # Noisy queries: one row, biased on the tracked-guide key columns.
        noisy = tracked_weights.new_zeros((1, 1, 1, total_tokens))
        noisy[..., guide_start:guide_start + tracked_count] = log_bias.reshape(1, 1, 1, -1)
        self.noisy_mask = noisy

        # Tracked-guide queries: one row per guide, biased on the key columns before guide_start.
        tracked = tracked_weights.new_zeros((1, 1, tracked_count, total_tokens))
        tracked[..., :guide_start] = log_bias.reshape(1, 1, -1, 1)
        self.tracked_mask = tracked
|
||||
|
||||
|
||||
def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
    """Run guide-masked attention one query group at a time.

    Queries are split along dim 1 into [noisy | tracked guides | remainder];
    each group attends to the full k/v with only its own sub-mask, so the
    dense (1,1,T,T) mask is never materialized.
    """
    attend = comfy.ldm.modules.attention.optimized_attention
    shared = dict(attn_precision=attn_precision, transformer_options=transformer_options)

    noisy_end = guide_mask.guide_start
    tracked_end = noisy_end + guide_mask.tracked_count

    out = torch.empty_like(q)

    # In practice currently guides are always after noise, guard for safety if this changes.
    if noisy_end > 0:
        out[:, :noisy_end, :] = attend(
            q[:, :noisy_end, :], k, v, heads,
            mask=guide_mask.noisy_mask,
            low_precision_attention=False,  # sageattn mask support is unreliable
            **shared,
        )

    out[:, noisy_end:tracked_end, :] = attend(
        q[:, noisy_end:tracked_end, :], k, v, heads,
        mask=guide_mask.tracked_mask,
        low_precision_attention=False,
        **shared,
    )

    # Every guide token is tracked, and nothing comes after them, guard for safety if this changes.
    if tracked_end < q.shape[1]:
        out[:, tracked_end:, :] = attend(q[:, tracked_end:, :], k, v, heads, **shared)

    return out
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -412,8 +467,10 @@ class CrossAttention(nn.Module):
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
elif isinstance(mask, GuideAttentionMask):
|
||||
out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
|
||||
# Apply per-head gating if enabled
|
||||
if self.to_gate_logits is not None:
|
||||
@ -1063,7 +1120,9 @@ class LTXVModel(LTXBaseModel):
|
||||
additional_args["resolved_guide_entries"] = resolved_entries
|
||||
|
||||
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||
|
||||
if keyframe_idxs.shape[2] > 0: # Guard for the case of no keyframes surviving
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||
|
||||
# Total surviving guide tokens (all guides)
|
||||
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
|
||||
@ -1099,12 +1158,12 @@ class LTXVModel(LTXBaseModel):
|
||||
if not resolved_entries:
|
||||
return None
|
||||
|
||||
# Check if any attenuation is actually needed
|
||||
needs_attenuation = any(
|
||||
e["strength"] < 1.0 or e.get("pixel_mask") is not None
|
||||
# strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
|
||||
needs_mask = any(
|
||||
e["strength"] != 1.0 or e.get("pixel_mask") is not None
|
||||
for e in resolved_entries
|
||||
)
|
||||
if not needs_attenuation:
|
||||
if not needs_mask:
|
||||
return None
|
||||
|
||||
# Build per-guide-token weights for all tracked guide tokens.
|
||||
@ -1159,16 +1218,11 @@ class LTXVModel(LTXBaseModel):
|
||||
# Concatenate per-token weights for all tracked guides
|
||||
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
|
||||
|
||||
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
|
||||
if (tracked_weights >= 1.0).all():
|
||||
# Skip when every weight is exactly 1.0 (additive bias would be 0).
|
||||
if (tracked_weights == 1.0).all():
|
||||
return None
|
||||
|
||||
# Build the mask: guide tokens are at the end of the sequence.
|
||||
# Tracked guides come first (in order), untracked follow.
|
||||
return self._build_self_attention_mask(
|
||||
total_tokens, num_guide_tokens, total_tracked,
|
||||
tracked_weights, guide_start, device, dtype,
|
||||
)
|
||||
return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights)
|
||||
|
||||
@staticmethod
|
||||
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
|
||||
@ -1234,45 +1288,6 @@ class LTXVModel(LTXBaseModel):
|
||||
|
||||
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
|
||||
|
||||
@staticmethod
|
||||
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
|
||||
tracked_weights, guide_start, device, dtype):
|
||||
"""Build a log-space additive self-attention bias mask.
|
||||
|
||||
Attenuates attention between noisy tokens and tracked guide tokens.
|
||||
Untracked guide tokens (at the end of the guide portion) keep full attention.
|
||||
|
||||
Args:
|
||||
total_tokens: Total sequence length.
|
||||
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
|
||||
tracked_count: Number of tracked guide tokens (first in the guide portion).
|
||||
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
|
||||
guide_start: Index where guide tokens begin in the sequence.
|
||||
device: Target device.
|
||||
dtype: Target dtype.
|
||||
|
||||
Returns:
|
||||
(1, 1, total_tokens, total_tokens) additive bias mask.
|
||||
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
|
||||
"""
|
||||
finfo = torch.finfo(dtype)
|
||||
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
|
||||
tracked_end = guide_start + tracked_count
|
||||
|
||||
# Convert weights to log-space bias
|
||||
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
|
||||
log_w = torch.full_like(w, finfo.min)
|
||||
positive_mask = w > 0
|
||||
if positive_mask.any():
|
||||
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
|
||||
|
||||
# noisy → tracked guides: each noisy row gets the same per-guide weight
|
||||
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
|
||||
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
|
||||
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
|
||||
|
||||
return mask
|
||||
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
|
||||
"""Process transformer blocks for LTXV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
189
comfy/ldm/moge/geometry.py
Normal file
189
comfy/ldm/moge/geometry.py
Normal file
@ -0,0 +1,189 @@
|
||||
"""Pure-torch + scipy geometry helpers for MoGe inference and mesh export."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from scipy.optimize import least_squares
|
||||
|
||||
def normalized_view_plane_uv(width: int, height: int, aspect_ratio: Optional[float] = None,
                             dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> torch.Tensor:
    """Build a (height, width, 2) grid of normalized view-plane UV coordinates.

    Corners sit at +/-(W, H)/diagonal; the outermost samples are pulled in by
    a factor of (n - 1) / n along each axis. ``aspect_ratio`` defaults to
    width / height.
    """
    ratio = width / height if aspect_ratio is None else aspect_ratio
    diagonal = (1 + ratio ** 2) ** 0.5
    u_extent = ratio / diagonal * (width - 1) / width
    v_extent = 1.0 / diagonal * (height - 1) / height

    u_axis = torch.linspace(-u_extent, u_extent, width, dtype=dtype, device=device)
    v_axis = torch.linspace(-v_extent, v_extent, height, dtype=dtype, device=device)

    # u varies along columns and v along rows (the "xy" meshgrid layout).
    u_grid = u_axis.unsqueeze(0).expand(height, width)
    v_grid = v_axis.unsqueeze(1).expand(height, width)
    return torch.stack((u_grid, v_grid), dim=-1)
|
||||
|
||||
|
||||
def intrinsics_from_focal_center(fx: torch.Tensor, fy: torch.Tensor, cx: torch.Tensor, cy: torch.Tensor) -> torch.Tensor:
    """Assemble (..., 3, 3) intrinsics from broadcastable fx, fy, cx, cy.

    The matrix layout is [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]; non-tensor
    inputs are converted with torch.as_tensor first.
    """
    fx, fy, cx, cy = torch.broadcast_tensors(*(torch.as_tensor(t) for t in (fx, fy, cx, cy)))
    zero, one = torch.zeros_like(fx), torch.ones_like(fx)
    # Lay the nine entries out row-major on a trailing axis, then fold it into 3x3.
    entries = torch.stack((fx, zero, cx,
                           zero, fy, cy,
                           zero, zero, one), dim=-1)
    return entries.reshape(*fx.shape, 3, 3)
|
||||
|
||||
|
||||
def depth_map_to_point_map(depth: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
    """Back-project a (..., H, W) depth map through K^-1 to (..., H, W, 3) camera-space points.

    Intrinsics use normalized image coords (x in [0, 1] left->right, y in [0, 1] top->bottom);
    each pixel is sampled at its center, i.e. (index + 0.5) / size.
    """
    height, width = depth.shape[-2:]
    like_depth = dict(dtype=depth.dtype, device=depth.device)

    cols = (torch.arange(width, **like_depth) + 0.5) / width
    rows = (torch.arange(height, **like_depth) + 0.5) / height
    grid_u, grid_v = torch.meshgrid(cols, rows, indexing="xy")  # both (H, W)

    # Homogeneous pixel coordinates (u, v, 1) for every pixel.
    homogeneous = torch.stack((grid_u, grid_v, torch.ones_like(grid_u)), dim=-1)
    rays = torch.einsum("...ij,hwj->...hwi", torch.linalg.inv(intrinsics), homogeneous)
    return rays * depth.unsqueeze(-1)
|
||||
|
||||
|
||||
def _solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray,
|
||||
focal: Optional[float] = None) -> Tuple[float, float]:
|
||||
"""LM-solve for z-shift; when focal is None, also recovers the optimal focal."""
|
||||
uv = uv.reshape(-1, 2)
|
||||
xy = xyz[..., :2].reshape(-1, 2)
|
||||
z = xyz[..., 2].reshape(-1)
|
||||
|
||||
def fn(shift):
|
||||
xy_proj = xy / (z + shift)[:, None]
|
||||
f = focal if focal is not None else (xy_proj * uv).sum() / np.square(xy_proj).sum()
|
||||
return (f * xy_proj - uv).ravel()
|
||||
|
||||
sol = least_squares(fn, x0=0.0, ftol=1e-3, method="lm")
|
||||
shift = float(np.asarray(sol["x"]).squeeze())
|
||||
if focal is None:
|
||||
xy_proj = xy / (z + shift)[:, None]
|
||||
focal = float((xy_proj * uv).sum() / np.square(xy_proj).sum())
|
||||
return shift, focal
|
||||
|
||||
|
||||
def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = None,
                        focal: Optional[torch.Tensor] = None, downsample_size: Tuple[int, int] = (64, 64)
                        ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Recover the focal length and z-shift that turn points into a metric point map.

    Optical center is at the image center; returned focal is relative to half the image diagonal.
    Returns (focal, shift) on the same device/dtype as points.

    Args:
        points: (..., H, W, 3) point map.
        mask: optional (..., H, W) validity mask; only selected pixels enter the solve.
        focal: optional per-item focal. When given, only the shift is solved for and
            the focal is returned as passed (reshaped to the batch shape).
        downsample_size: nearest-neighbour grid size the solve runs on.

    Pixels with non-finite coordinates are always excluded from the solve. Items left
    with fewer than two usable pixels get the fallback (focal=1.0, shift=0.0).
    """
    shape = points.shape
    H, W = shape[-3], shape[-2]
    points_b = points.reshape(-1, H, W, 3)
    mask_b = None if mask is None else mask.reshape(-1, H, W)
    focal_b = None if focal is None else focal.reshape(-1)

    uv = normalized_view_plane_uv(W, H, dtype=points.dtype, device=points.device)

    # Solve on a coarse nearest-neighbour subsample of the points, UVs and mask.
    points_lr = F.interpolate(points_b.permute(0, 3, 1, 2), downsample_size, mode="nearest").permute(0, 2, 3, 1)
    uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode="nearest").squeeze(0).permute(1, 2, 0)
    mask_lr = None
    if mask_b is not None:
        mask_lr = F.interpolate(mask_b.to(torch.float32).unsqueeze(1), downsample_size, mode="nearest").squeeze(1) > 0

    # The solver runs through numpy/scipy on CPU.
    uv_np = uv_lr.detach().cpu().numpy().reshape(-1, 2)
    points_np = points_lr.detach().cpu().numpy()
    mask_np = None if mask_lr is None else mask_lr.detach().cpu().numpy()
    focal_np = None if focal_b is None else focal_b.detach().cpu().numpy()

    out_focal: list = []
    out_shift: list = []
    for i in range(points_b.shape[0]):
        xyz_i = points_np[i].reshape(-1, 3)
        # Non-finite coordinates would make the residuals non-finite, so drop them.
        keep = np.isfinite(xyz_i).all(axis=-1)
        if mask_np is not None:
            keep &= mask_np[i].reshape(-1)
        if keep.sum() < 2:
            # Not enough constraints for this item: use the neutral fallback.
            out_focal.append(1.0)
            out_shift.append(0.0)
            continue
        xyz_i = xyz_i[keep]
        uv_i = uv_np[keep]
        if focal_np is None:
            shift_i, focal_i = _solve_optimal_shift(uv_i, xyz_i)
            out_focal.append(focal_i)
        else:
            shift_i, _ = _solve_optimal_shift(uv_i, xyz_i, focal=float(focal_np[i]))
        out_shift.append(shift_i)

    shift_t = torch.tensor(out_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])
    if focal is None:
        focal_t = torch.tensor(out_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3])
    else:
        focal_t = focal.reshape(shape[:-3])
    return focal_t, shift_t
|
||||
|
||||
|
||||
def depth_map_edge(depth: torch.Tensor, atol: Optional[float] = None, rtol: Optional[float] = None, kernel_size: int = 3) -> torch.Tensor:
    """Per-pixel boolean: True where the local depth window's max-min span exceeds atol or rtol*depth.

    Args:
        depth: (..., H, W) depth map.
        atol: absolute span threshold; skipped when None.
        rtol: span threshold relative to the pixel's own depth (clamped to >= 1e-6);
            skipped when None.
        kernel_size: side of the square window; must be a positive odd integer so the
            pooled maps keep the input's spatial size.

    Returns:
        Boolean tensor with the same shape as ``depth`` (all False when both thresholds
        are None).

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
    shape = depth.shape
    d = depth.reshape(-1, 1, *shape[-2:])
    pad = kernel_size // 2
    # window max - window min, with the min obtained as -max(-d)
    diff = F.max_pool2d(d, kernel_size, stride=1, padding=pad) + F.max_pool2d(-d, kernel_size, stride=1, padding=pad)
    edge = torch.zeros_like(d, dtype=torch.bool)
    if atol is not None:
        edge |= diff > atol
    if rtol is not None:
        edge |= (diff / d.clamp_min(1e-6)).nan_to_num_() > rtol
    return edge.reshape(*shape)
|
||||
|
||||
|
||||
def triangulate_grid_mesh(points: torch.Tensor, mask: Optional[torch.Tensor] = None, decimation: int = 1, discontinuity_threshold: float = 0.04,
                          depth: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Triangulate a (H, W, 3) point map into (vertices, faces, uvs) on CPU.

    A pixel becomes a vertex when all three coordinates are finite, it passes the
    optional mask, and (for discontinuity_threshold > 0) it is not flagged as a
    relative depth edge. Each grid quad whose four corners are vertices yields two
    triangles. ``depth`` overrides the scalar used for the rtol edge check; pass radial
    depth for panoramas (the default points[..., 2] goes negative below the equator).
    ``decimation`` > 1 keeps every n-th row/column after the edge check. UVs are the
    pixel's column/row index normalised to [0, 1] on the (decimated) grid.
    """
    pts = points.detach().cpu()
    is_finite = torch.isfinite(pts).all(dim=-1)
    valid = is_finite if mask is None else mask.detach().cpu().to(torch.bool) & is_finite

    if discontinuity_threshold > 0:
        scalar = pts[..., 2] if depth is None else depth.detach().cpu()
        # Replace inf with 0 so max-pool doesn't poison neighbourhoods (those pixels are already invalid).
        scalar = torch.where(is_finite, scalar, torch.zeros_like(scalar))
        valid = valid & ~depth_map_edge(scalar, rtol=discontinuity_threshold)

    if decimation > 1:
        pts = pts[::decimation, ::decimation].contiguous()
        valid = valid[::decimation, ::decimation].contiguous()
    rows, cols = pts.shape[:2]

    # Map every grid cell to its compact vertex index, or -1 when dropped.
    keep = valid.reshape(-1)
    vertex_id = torch.full((rows * cols,), -1, dtype=torch.long)
    vertex_id[keep] = torch.arange(int(keep.sum().item()), dtype=torch.long)
    vertex_id = vertex_id.reshape(rows, cols)

    vertices = pts.reshape(-1, 3)[keep].contiguous()

    grid_y, grid_x = torch.meshgrid(torch.arange(rows), torch.arange(cols), indexing="ij")
    tex_u = grid_x.float() / max(cols - 1, 1)
    tex_v = grid_y.float() / max(rows - 1, 1)
    uvs = torch.stack([tex_u, tex_v], dim=-1).reshape(-1, 2)[keep].contiguous()

    # Quad corners: top-left, top-right, bottom-right, bottom-left.
    tl, tr = vertex_id[:-1, :-1], vertex_id[:-1, 1:]
    br, bl = vertex_id[1:, 1:], vertex_id[1:, :-1]
    complete = (tl >= 0) & (tr >= 0) & (br >= 0) & (bl >= 0)
    tl, tr, br, bl = tl[complete], tr[complete], br[complete], bl[complete]
    upper = torch.stack([tl, tr, br], dim=-1)
    lower = torch.stack([tl, br, bl], dim=-1)
    faces = torch.cat([upper, lower], dim=0).contiguous()
    return vertices, faces, uvs
|
||||
347
comfy/ldm/moge/model.py
Normal file
347
comfy/ldm/moge/model.py
Normal file
@ -0,0 +1,347 @@
|
||||
"""MoGe v1 / v2 inference modules and a state-dict-driven builder.
|
||||
|
||||
V1: DINOv2 backbone + multi-output head (points, mask).
|
||||
V2: DINOv2 encoder + neck + per-output heads (points, mask, normal, optional metric-scale MLP).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
|
||||
from comfy.image_encoders.dino2 import Dinov2Model
|
||||
|
||||
from .geometry import depth_map_to_point_map, intrinsics_from_focal_center, recover_focal_shift
|
||||
from .modules import ConvStack, DINOv2Encoder, HeadV1, MLP, _view_plane_uv_grid
|
||||
|
||||
|
||||
def _remap_points(points: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply the exp remap: z -> exp(z), xy stays linear and gets scaled by the new z."""
|
||||
xy, z = points.split([2, 1], dim=-1)
|
||||
z = torch.exp(z)
|
||||
return torch.cat([xy * z, z], dim=-1)
|
||||
|
||||
|
||||
def _detect_dinov2(sd: dict, prefix: str) -> Dict[str, Any]:
|
||||
# All shipped MoGe checkpoints use plain DINOv2
|
||||
hidden = sd[prefix + "embeddings.cls_token"].shape[-1]
|
||||
layer_prefix = prefix + "encoder.layer."
|
||||
depth = 1 + max(int(k[len(layer_prefix):].split(".")[0]) for k in sd if k.startswith(layer_prefix))
|
||||
return {
|
||||
"hidden_size": hidden,
|
||||
"num_attention_heads": hidden // 64,
|
||||
"num_hidden_layers": depth,
|
||||
"layer_norm_eps": 1e-6,
|
||||
"use_swiglu_ffn": False,
|
||||
}
|
||||
|
||||
|
||||
class MoGeModelV1(nn.Module):
    """MoGe v1: DINOv2 backbone + HeadV1 (points, mask)."""

    # Normalisation buffers, registered in __init__.
    image_mean: torch.Tensor
    image_std: torch.Tensor

    # Number of trailing backbone layers whose features are handed to the head.
    intermediate_layers = 4
    num_tokens_range: Tuple[Number, Number] = (1200, 2500)
    mask_threshold = 0.5

    def __init__(self, backbone: Dict[str, Any], dim_upsample: List[int] = (256, 128, 128),
                 num_res_blocks: int = 1, dim_times_res_block_hidden: int = 1,
                 dtype=None, device=None, operations=comfy.ops.manual_cast):
        """Build the backbone from its config dict and a head sized to the backbone's hidden width."""
        super().__init__()
        self.backbone = Dinov2Model(backbone, dtype, device, operations)
        self.head = HeadV1(
            dim_in=backbone["hidden_size"],
            dim_upsample=list(dim_upsample),
            num_res_blocks=num_res_blocks,
            dim_times_res_block_hidden=dim_times_res_block_hidden,
            dtype=dtype, device=device, operations=operations,
        )
        mean = torch.tensor([0.485, 0.456, 0.406])
        std = torch.tensor([0.229, 0.224, 0.225])
        self.register_buffer("image_mean", mean.view(1, 3, 1, 1))
        self.register_buffer("image_std", std.view(1, 3, 1, 1))

    def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
        """Run backbone + head and return {"points", "mask"} resized to the input resolution.

        The image is first rescaled so that it covers roughly ``num_tokens`` 14x14
        patches, normalised, and then snapped to a multiple of 14 for the backbone.
        """
        full_h, full_w = image.shape[-2:]
        scale = ((num_tokens * 14 ** 2) / (full_h * full_w)) ** 0.5
        work_h, work_w = int(full_h * scale), int(full_w * scale)
        resized = F.interpolate(image, (work_h, work_w), mode="bicubic", align_corners=False, antialias=True)
        resized = (resized - self.image_mean) / self.image_std
        patch_aligned = F.interpolate(resized, (work_h // 14 * 14, work_w // 14 * 14),
                                      mode="bilinear", align_corners=False, antialias=True)

        total_layers = len(self.backbone.encoder.layer)
        wanted = list(range(total_layers - self.intermediate_layers, total_layers))
        feats = self.backbone.get_intermediate_layers(patch_aligned, wanted, apply_norm=True)

        raw_points, raw_mask = self.head(feats, resized)
        points = F.interpolate(raw_points.float(), (full_h, full_w), mode="bilinear", align_corners=False)
        mask = F.interpolate(raw_mask.float(), (full_h, full_w), mode="bilinear", align_corners=False)
        return {"points": _remap_points(points.permute(0, 2, 3, 1)), "mask": mask.squeeze(1)}

    @classmethod
    def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
        """Detect the v1 head config from sd, build a model, and load weights."""
        stage_ids = {int(k.split(".")[2]) for k in sd if k.startswith("head.upsample_blocks.")}
        stage_count = 1 + max(stage_ids)
        dim_upsample = [sd[f"head.upsample_blocks.{i}.0.0.weight"].shape[1] for i in range(stage_count)]
        # Each upsample stage is Sequential[upsampler, *res_blocks], so the highest child
        # index at stage 0 equals the number of res blocks.
        num_res_blocks = max({int(k.split(".")[3]) for k in sd if k.startswith("head.upsample_blocks.0.")})
        hidden_width = sd["head.upsample_blocks.0.1.layers.2.weight"].shape[0]
        hidden_multiplier = max(hidden_width // dim_upsample[0], 1)
        model = cls(
            backbone=_detect_dinov2(sd, prefix="backbone."),
            dim_upsample=dim_upsample,
            num_res_blocks=num_res_blocks,
            dim_times_res_block_hidden=hidden_multiplier,
            dtype=dtype, device=device, operations=operations,
        )
        model.load_state_dict(sd, strict=True)
        return model
|
||||
|
||||
|
||||
class MoGeModelV2(nn.Module):
    """MoGe v2: DINOv2 encoder + neck + per-output heads (points/mask/normal/metric-scale).

    The normal head and the metric-scale MLP are optional: each is built only when its
    config is supplied, and the corresponding output key ("normal", "metric_scale") is
    only present in the forward result when the head exists.
    """

    intermediate_layers = 4
    num_tokens_range: Tuple[Number, Number] = (1200, 3600)

    def __init__(self,
                 encoder: Dict[str, Any],
                 neck: Dict[str, Any],
                 points_head: Dict[str, Any],
                 mask_head: Dict[str, Any],
                 scale_head: Optional[Dict[str, Any]] = None,
                 normal_head: Optional[Dict[str, Any]] = None,
                 dtype=None, device=None, operations=comfy.ops.manual_cast):
        """Build every sub-module from its keyword config dict."""
        super().__init__()
        self.encoder = DINOv2Encoder(**encoder, dtype=dtype, device=device, operations=operations)
        self.neck = ConvStack(**neck, dtype=dtype, device=device, operations=operations)
        self.points_head = ConvStack(**points_head, dtype=dtype, device=device, operations=operations)
        self.mask_head = ConvStack(**mask_head, dtype=dtype, device=device, operations=operations)
        if scale_head is not None:
            self.scale_head = MLP(**scale_head, dtype=dtype, device=device, operations=operations)
        if normal_head is not None:
            self.normal_head = ConvStack(**normal_head, dtype=dtype, device=device, operations=operations)

    def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
        """Run the network on a (B, C, H, W) image using roughly ``num_tokens`` backbone tokens.

        Returns a dict with "points" (B, H, W, 3) and "mask" (B, H, W), plus
        "metric_scale" and "normal" when the respective heads exist.
        """
        B, _, H, W = image.shape
        device, dtype = image.device, image.dtype
        aspect_ratio = W / H
        # Token grid with ~num_tokens cells and the image's aspect ratio.
        base_h = round((num_tokens / aspect_ratio) ** 0.5)
        base_w = round((num_tokens * aspect_ratio) ** 0.5)

        feat_top, cls_token = self.encoder(image, base_h, base_w, return_class_token=True)

        # 5-level pyramid: feat at level 0 concatenated with UV, other levels UV-only.
        levels = [_view_plane_uv_grid(B, base_h * (2 ** level), base_w * (2 ** level), aspect_ratio, dtype, device)
                  for level in range(5)]
        levels[0] = torch.cat([feat_top, levels[0]], dim=1)

        feats = self.neck(levels)

        def _resize(v):
            return F.interpolate(v, (H, W), mode="bilinear", align_corners=False)

        # Each head returns one map per pyramid level; only the finest ([-1]) is used.
        points = _remap_points(_resize(self.points_head(feats)[-1]).permute(0, 2, 3, 1))
        mask = _resize(self.mask_head(feats)[-1]).squeeze(1).sigmoid()

        result = {"points": points, "mask": mask}
        if hasattr(self, "scale_head"):
            result["metric_scale"] = self.scale_head(cls_token).squeeze(1).exp()
        if hasattr(self, "normal_head"):
            normal = _resize(self.normal_head(feats)[-1])
            result["normal"] = F.normalize(normal.permute(0, 2, 3, 1), dim=-1)
        return result

    @classmethod
    def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
        """Detect the v2 encoder/neck/heads config from sd, build a model, and load weights."""
        backbone = _detect_dinov2(sd, prefix="encoder.backbone.")
        depth = backbone["num_hidden_layers"]
        n = cls.intermediate_layers
        encoder = {
            "backbone": backbone,
            # n evenly spaced layers, ending on the last one when depth is divisible by n.
            "intermediate_layers": [(depth // n) * (i + 1) - 1 for i in range(n)],
            "dim_out": sd["encoder.output_projections.0.weight"].shape[0],
        }
        cfg: Dict[str, Any] = {
            "encoder": encoder,
            "neck": cls._detect_convstack(sd, "neck."),
            "points_head": cls._detect_convstack(sd, "points_head."),
            "mask_head": cls._detect_convstack(sd, "mask_head."),
        }
        # scale_head is an MLP: Sequential of [Linear, ReLU, ..., Linear]; Linear weight is (out, in).
        scale_idxs = sorted({int(k.split(".")[1]) for k in sd if k.startswith("scale_head.")})
        if scale_idxs:
            scale_first = sd[f"scale_head.{scale_idxs[0]}.weight"]
            cfg["scale_head"] = {
                "dims": [scale_first.shape[1]] + [sd[f"scale_head.{i}.weight"].shape[0] for i in scale_idxs],
            }
        if any(k.startswith("normal_head.") for k in sd):
            cfg["normal_head"] = cls._detect_convstack(sd, "normal_head.")
        model = cls(**cfg, dtype=dtype, device=device, operations=operations)
        model.load_state_dict(sd, strict=True)
        return model

    @staticmethod
    def _detect_convstack(sd: dict, prefix: str) -> Dict[str, Any]:
        """Reconstruct a ConvStack config from the keys under prefix."""
        in_prefix = f"{prefix}input_blocks."
        in_keys = [k for k in sd if k.startswith(in_prefix) and k.endswith(".weight")]
        n = 1 + max(int(k[len(in_prefix):].split(".")[0]) for k in in_keys)

        # Conv weights are (out, in, kH, kW): column 1 is the input width, column 0 the res-block width.
        in_shapes = [sd[f"{in_prefix}{i}.weight"].shape for i in range(n)]
        # The first layer of a res block only has a weight when it is a norm layer.
        has_norm = f"{prefix}res_blocks.0.0.layers.0.weight" in sd

        def out_dim_at(i):
            key = f"{prefix}output_blocks.{i}.weight"
            return sd[key].shape[0] if key in sd else None

        def num_res_at(i):
            rb_prefix = f"{prefix}res_blocks.{i}."
            return len({int(k[len(rb_prefix):].split(".")[0]) for k in sd if k.startswith(rb_prefix)})

        return {
            "dim_in": [s[1] for s in in_shapes],
            "dim_res_blocks": [s[0] for s in in_shapes],
            "dim_out": [out_dim_at(i) for i in range(n)],
            "num_res_blocks": [num_res_at(i) for i in range(n)],
            "resamplers": ["conv_transpose" if f"{prefix}resamplers.{i}.0.weight" in sd else "bilinear"
                           for i in range(n - 1)],
            "res_block_in_norm": "layer_norm" if has_norm else "none",
            "res_block_hidden_norm": "group_norm" if has_norm else "none",
        }
|
||||
|
||||
|
||||
# Translate the Meta-style DINOv2 keys MoGe ships to the naming ComfyUI DINOv2 port expects,
|
||||
# and split each fused qkv tensor into Q/K/V.
|
||||
_DINOV2_TOPLEVEL_RENAMES = {
|
||||
"patch_embed.proj.weight": "embeddings.patch_embeddings.projection.weight",
|
||||
"patch_embed.proj.bias": "embeddings.patch_embeddings.projection.bias",
|
||||
"cls_token": "embeddings.cls_token",
|
||||
"pos_embed": "embeddings.position_embeddings",
|
||||
"register_tokens": "embeddings.register_tokens",
|
||||
"mask_token": "embeddings.mask_token",
|
||||
"norm.weight": "layernorm.weight",
|
||||
"norm.bias": "layernorm.bias",
|
||||
}
|
||||
_DINOV2_BLOCK_RENAMES = [
|
||||
("ls1.gamma", "layer_scale1.lambda1"),
|
||||
("ls2.gamma", "layer_scale2.lambda1"),
|
||||
("attn.proj.", "attention.output.dense."),
|
||||
("mlp.w12.", "mlp.weights_in."),
|
||||
("mlp.w3.", "mlp.weights_out."),
|
||||
]
|
||||
|
||||
|
||||
def _remap_state_dict(sd: dict) -> dict:
|
||||
if "model" in sd and "model_config" in sd:
|
||||
sd = sd["model"]
|
||||
prefix = "encoder.backbone." if any(k.startswith("encoder.backbone.") for k in sd) else "backbone."
|
||||
out: dict = {}
|
||||
for k, v in sd.items():
|
||||
if not k.startswith(prefix):
|
||||
out[k] = v
|
||||
continue
|
||||
rel = k[len(prefix):]
|
||||
if rel in _DINOV2_TOPLEVEL_RENAMES:
|
||||
out[prefix + _DINOV2_TOPLEVEL_RENAMES[rel]] = v
|
||||
continue
|
||||
if not rel.startswith("blocks."):
|
||||
out[k] = v
|
||||
continue
|
||||
_, idx, sub = rel.split(".", 2)
|
||||
if sub in ("attn.qkv.weight", "attn.qkv.bias"):
|
||||
tail = sub.rsplit(".", 1)[1]
|
||||
q, kw, vw = v.chunk(3, dim=0)
|
||||
base = f"{prefix}encoder.layer.{idx}.attention.attention"
|
||||
out[f"{base}.query.{tail}"] = q
|
||||
out[f"{base}.key.{tail}"] = kw
|
||||
out[f"{base}.value.{tail}"] = vw
|
||||
continue
|
||||
for old, new in _DINOV2_BLOCK_RENAMES:
|
||||
sub = sub.replace(old, new)
|
||||
out[f"{prefix}encoder.layer.{idx}.{sub}"] = v
|
||||
return out
|
||||
|
||||
|
||||
def build_from_state_dict(sd: dict, dtype=None, device=None, operations=comfy.ops.manual_cast) -> nn.Module:
    """Remap the checkpoint keys, then build v2 if an "encoder.backbone." key exists, else v1."""
    remapped = _remap_state_dict(sd)
    is_v2 = any(key.startswith("encoder.backbone.") for key in remapped)
    model_cls = MoGeModelV2 if is_v2 else MoGeModelV1
    return model_cls.from_state_dict(remapped, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
|
||||
class MoGeModel:
    """Loaded MoGe model + ComfyUI memory management."""

    def __init__(self, state_dict: dict):
        """Build the network on the offload device and wrap it in a model patcher."""
        # Device/dtype policy is borrowed from the text encoder (closest match).
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)

        network = build_from_state_dict(state_dict, dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
        self.model = network.eval()
        self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        # Only the v2 network has an `encoder` attribute.
        self.version = "v2" if hasattr(self.model, "encoder") else "v1"
        self.mask_threshold = float(getattr(self.model, "mask_threshold", 0.5))
        fallback_range = (1200, 2500 if self.version == "v1" else 3600)
        token_range = getattr(self.model, "num_tokens_range", fallback_range)
        self.num_tokens_range = (int(token_range[0]), int(token_range[1]))

    def infer(self, image: torch.Tensor, num_tokens: Optional[int] = None,
              resolution_level: int = 9, fov_x: Optional[Union[Number, torch.Tensor]] = None,
              force_projection: bool = True, apply_mask: bool = True,
              apply_metric_scale: bool = True
              ) -> Dict[str, torch.Tensor]:
        """Run a single MoGe forward + post-process pass. image is (B, 3, H, W) in [0, 1].

        Args:
            num_tokens: backbone token budget; when None it is interpolated from
                ``num_tokens_range`` using ``resolution_level / 9``.
            fov_x: horizontal field of view in degrees; when None the focal is recovered
                from the predicted points.
            force_projection: rebuild the point map from depth and intrinsics.
            apply_mask: write inf into masked-out points/depth and zeros into masked-out normals.
            apply_metric_scale: multiply points/depth by the model's metric scale, if it emits one.

        Returns:
            Dict with "points", "depth", "intrinsics", "mask", and "normal" when available.
        """
        comfy.model_management.load_model_gpu(self.patcher)
        image = image.to(device=self.load_device, dtype=self.dtype)
        height, width = image.shape[-2:]
        aspect_ratio = width / height

        if num_tokens is None:
            lo, hi = self.num_tokens_range
            num_tokens = int(lo + (resolution_level / 9) * (hi - lo))

        out = self.model.forward(image, num_tokens=num_tokens)
        points = out["points"].float()  # recover_focal_shift goes through scipy on CPU; needs fp32.
        mask_binary = out["mask"] > self.mask_threshold
        normal = out.get("normal")
        metric_scale = out.get("metric_scale")

        diag = (1 + aspect_ratio ** 2) ** 0.5

        def focal_from_fov_deg(deg):
            fov = torch.as_tensor(deg, device=points.device, dtype=points.dtype)
            return aspect_ratio / diag / torch.tan(torch.deg2rad(fov / 2))

        if fov_x is not None:
            focal = focal_from_fov_deg(fov_x).expand(points.shape[0])
            _, shift = recover_focal_shift(points, mask_binary, focal=focal)
        else:
            focal, shift = recover_focal_shift(points, mask_binary)
            # Fall back to 60 deg FoV when the least-squares solver flips the focal sign.
            bad = ~torch.isfinite(focal) | (focal <= 0)
            if bool(bad.any()):
                focal = torch.where(bad, focal_from_fov_deg(60.0), focal)
                _, shift = recover_focal_shift(points, mask_binary, focal=focal)

        f_diag = focal / 2 * diag
        half = torch.tensor(0.5, device=points.device, dtype=points.dtype)
        intrinsics = intrinsics_from_focal_center(f_diag / aspect_ratio, f_diag, half, half)
        points[..., 2] = points[..., 2] + shift[..., None, None]
        # v2 only: filter mask by depth>0 to drop metric-scale negative-depth artifacts.
        if self.version == "v2":
            mask_binary = mask_binary & (points[..., 2] > 0)
        depth = points[..., 2].clone()

        if force_projection:
            points = depth_map_to_point_map(depth, intrinsics=intrinsics)

        if apply_metric_scale and metric_scale is not None:
            points = points * metric_scale[:, None, None, None]
            depth = depth * metric_scale[:, None, None]

        if apply_mask:
            keep = mask_binary[..., None]
            points = torch.where(keep, points, torch.full_like(points, float("inf")))
            depth = torch.where(mask_binary, depth, torch.full_like(depth, float("inf")))
            if normal is not None:
                normal = torch.where(keep, normal, torch.zeros_like(normal))

        result = {"points": points, "depth": depth, "intrinsics": intrinsics, "mask": mask_binary}
        if normal is not None:
            result["normal"] = normal
        return result
|
||||
204
comfy/ldm/moge/modules.py
Normal file
204
comfy/ldm/moge/modules.py
Normal file
@ -0,0 +1,204 @@
|
||||
"""Building blocks for MoGe: residual conv stack, resamplers, MLP, DINOv2 encoder, v1 head."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
from comfy.image_encoders.dino2 import Dinov2Model
|
||||
|
||||
from .geometry import normalized_view_plane_uv
|
||||
|
||||
|
||||
def _conv2d(operations, c_in: int, c_out: int, k: int = 3, *, dtype=None, device=None):
|
||||
return operations.Conv2d(c_in, c_out, kernel_size=k, padding=k // 2, padding_mode="replicate", dtype=dtype, device=device)
|
||||
|
||||
|
||||
def _view_plane_uv_grid(batch: int, height: int, width: int, aspect_ratio: float, dtype, device) -> torch.Tensor:
    """Batched normalized view-plane UV grid as a (B, 2, H, W) tensor.

    The batch dimension is an ``expand`` of a single grid, i.e. a broadcast view.
    """
    grid = normalized_view_plane_uv(width, height, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
    channels_first = grid.permute(2, 0, 1)
    return channels_first.unsqueeze(0).expand(batch, -1, -1, -1)
|
||||
|
||||
|
||||
def _concat_view_plane_uv(x: torch.Tensor, aspect_ratio: float) -> torch.Tensor:
    """Append a 2-channel normalized view-plane UV grid to x along the channel dim."""
    batch, rows, cols = x.shape[0], x.shape[-2], x.shape[-1]
    uv = _view_plane_uv_grid(batch, rows, cols, aspect_ratio, x.dtype, x.device)
    return torch.cat([x, uv], dim=1)
|
||||
|
||||
|
||||
class ResidualConvBlock(nn.Module):
    """Residual block: (norm, ReLU, 3x3 conv) twice, with the input added back to the output.

    ``in_norm == "layer_norm"`` uses a single-group GroupNorm and
    ``hidden_norm == "group_norm"`` uses GroupNorm with max(hidden // 32, 1) groups;
    any other value disables that norm (Identity).
    """

    def __init__(self, channels: int, hidden_channels: Optional[int] = None, in_norm: str = "layer_norm", hidden_norm: str = "group_norm",
                 dtype=None, device=None, operations=comfy.ops.manual_cast):
        super().__init__()
        if hidden_channels is None:
            hidden_channels = channels

        if in_norm == "layer_norm":
            first_norm = operations.GroupNorm(1, channels, dtype=dtype, device=device)
        else:
            first_norm = nn.Identity()

        if hidden_norm == "group_norm":
            groups = max(hidden_channels // 32, 1)
            second_norm = operations.GroupNorm(groups, hidden_channels, dtype=dtype, device=device)
        else:
            second_norm = nn.Identity()

        # Layer positions are fixed: checkpoints address them as layers.<index>.
        self.layers = nn.Sequential(
            first_norm,
            nn.ReLU(),
            _conv2d(operations, channels, hidden_channels, dtype=dtype, device=device),
            second_norm,
            nn.ReLU(),
            _conv2d(operations, hidden_channels, channels, dtype=dtype, device=device),
        )

    def forward(self, x):
        residual = self.layers(x)
        return residual + x
|
||||
|
||||
|
||||
class Resampler(nn.Sequential):
    """2x upsampler: ConvTranspose2d(2x2) or bilinear upsample, followed by a 3x3 conv."""

    def __init__(self, in_channels: int, out_channels: int, type_: str, dtype=None, device=None, operations=comfy.ops.manual_cast):
        if type_ != "conv_transpose":
            # Every other value is treated as "bilinear": the upsample keeps the channel
            # count, so the following conv performs the in -> out change.
            upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
            conv_inputs = in_channels
        else:
            upsample = operations.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, dtype=dtype, device=device)
            conv_inputs = out_channels
        smooth = _conv2d(operations, conv_inputs, out_channels, dtype=dtype, device=device)
        super().__init__(upsample, smooth)
|
||||
|
||||
|
||||
class MLP(nn.Sequential):
    """Linear layers of widths ``dims`` with an in-place ReLU after every layer but the last."""

    def __init__(self, dims: Sequence[int], dtype=None, device=None, operations=comfy.ops.manual_cast):
        stack = []
        for width_in, width_out in zip(dims[:-2], dims[1:-1]):
            stack += [
                operations.Linear(width_in, width_out, dtype=dtype, device=device),
                nn.ReLU(inplace=True),
            ]
        # Final projection has no activation.
        stack.append(operations.Linear(dims[-2], dims[-1], dtype=dtype, device=device))
        super().__init__(*stack)
|
||||
|
||||
|
||||
class ConvStack(nn.Module):
    """Multi-level conv pyramid: per-level input projection, res blocks, output projection, and 2x resamplers between levels.

    A ``None`` entry in ``dim_in`` / ``dim_out`` makes that level's input / output
    projection an Identity.
    """

    def __init__(self, dim_in: List[Optional[int]], dim_res_blocks: List[int], dim_out: List[Optional[int]], resamplers: List[str],
                 num_res_blocks: List[int], dim_times_res_block_hidden: int = 1, res_block_in_norm: str = "layer_norm", res_block_hidden_norm: str = "group_norm",
                 dtype=None, device=None, operations=comfy.ops.manual_cast):
        super().__init__()

        def project(c_in, c_out):
            # 1x1 conv, or a pass-through when the source/target width is None.
            if c_in is None or c_out is None:
                return nn.Identity()
            return _conv2d(operations, c_in, c_out, k=1, dtype=dtype, device=device)

        def res_block(width):
            return ResidualConvBlock(width, dim_times_res_block_hidden * width,
                                     in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm,
                                     dtype=dtype, device=device, operations=operations)

        self.input_blocks = nn.ModuleList([
            project(d_in, d_res) for d_in, d_res in zip(dim_in, dim_res_blocks)
        ])

        self.resamplers = nn.ModuleList([
            Resampler(prev, succ, type_=kind, dtype=dtype, device=device, operations=operations)
            for prev, succ, kind in zip(dim_res_blocks[:-1], dim_res_blocks[1:], resamplers)
        ])

        self.res_blocks = nn.ModuleList([
            nn.Sequential(*[res_block(d_res) for _ in range(num_res_blocks[level])])
            for level, d_res in enumerate(dim_res_blocks)
        ])

        self.output_blocks = nn.ModuleList([
            project(d_res, d_out) for d_out, d_res in zip(dim_out, dim_res_blocks)
        ])

    def forward(self, in_features: List[Optional[torch.Tensor]]):
        """Return one output per level; each level's input (if not None) is added to the upsampled running feature."""
        outputs = []
        last_level = len(self.res_blocks) - 1
        x = None
        for level, blocks in enumerate(self.res_blocks):
            skip = in_features[level]
            if skip is not None:
                skip = self.input_blocks[level](skip)
            if level == 0:
                x = skip
            elif skip is not None:
                x = x + skip
            x = blocks(x)
            outputs.append(self.output_blocks[level](x))
            if level < last_level:
                x = self.resamplers[level](x)
        return outputs
|
||||
|
||||
|
||||
class DINOv2Encoder(nn.Module):
    """Comfy DINOv2 backbone with per-layer 1x1 projection heads."""

    def __init__(self, backbone: dict, intermediate_layers: List[int], dim_out: int, dtype=None, device=None, operations=comfy.ops.manual_cast):
        super().__init__()
        self.intermediate_layers = list(intermediate_layers)
        hidden = backbone["hidden_size"]
        self.backbone = Dinov2Model(backbone, dtype, device, operations)
        # One projection per tapped backbone layer.
        self.output_projections = nn.ModuleList([
            _conv2d(operations, hidden, dim_out, k=1, dtype=dtype, device=device)
            for _ in self.intermediate_layers
        ])
        mean = torch.tensor([0.485, 0.456, 0.406])
        std = torch.tensor([0.229, 0.224, 0.225])
        self.register_buffer("image_mean", mean.view(1, 3, 1, 1))
        self.register_buffer("image_std", std.view(1, 3, 1, 1))

    def forward(self, image: torch.Tensor, token_rows: int, token_cols: int,
                return_class_token: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Resize to a (token_rows*14, token_cols*14) image, normalise, and sum the projected intermediate features.

        With ``return_class_token`` the class token of the last tapped layer is returned as well.
        """
        target = (token_rows * 14, token_cols * 14)
        resized = F.interpolate(image, target, mode="bilinear", align_corners=False, antialias=True)
        normalized = (resized - self.image_mean) / self.image_std
        taps = self.backbone.get_intermediate_layers(normalized, self.intermediate_layers, apply_norm=True)
        projected = []
        for proj, (tokens, _cls) in zip(self.output_projections, taps):
            # (B, N, C) tokens -> (B, C, rows, cols) feature map
            fmap = tokens.permute(0, 2, 1).unflatten(2, (token_rows, token_cols)).contiguous()
            projected.append(proj(fmap))
        x = torch.stack(projected, dim=1).sum(dim=1)
        if not return_class_token:
            return x
        return x, taps[-1][1]
|
||||
|
||||
|
||||
class HeadV1(nn.Module):
    """v1 head: 4 backbone-feature projections -> shared upsample stack -> per-target output convs (points, mask)."""

    NUM_FEATURES = 4
    DIM_PROJ = 512
    DIM_OUT = (3, 1)  # 3 channels for points, 1 for mask
    LAST_CONV_CHANNELS = 32

    def __init__(self, dim_in: int, dim_upsample: List[int] = (256, 128, 128), num_res_blocks: int = 1, dim_times_res_block_hidden: int = 1,
                 dtype=None, device=None, operations=comfy.ops.manual_cast):
        super().__init__()
        self.projects = nn.ModuleList([
            _conv2d(operations, dim_in, self.DIM_PROJ, k=1, dtype=dtype, device=device)
            for _ in range(self.NUM_FEATURES)
        ])

        def make_upsampler(c_in, c_out):
            return nn.Sequential(
                operations.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2, dtype=dtype, device=device),
                _conv2d(operations, c_out, c_out, dtype=dtype, device=device),
            )

        def make_stage(c_in, c_out):
            # "+ 2": a UV grid is concatenated to the features before every stage.
            res = [
                ResidualConvBlock(c_out, dim_times_res_block_hidden * c_out, dtype=dtype, device=device, operations=operations)
                for _ in range(num_res_blocks)
            ]
            return nn.Sequential(make_upsampler(c_in + 2, c_out), *res)

        stage_inputs = [self.DIM_PROJ] + list(dim_upsample[:-1])
        self.upsample_blocks = nn.ModuleList([
            make_stage(c_in, c_out) for c_in, c_out in zip(stage_inputs, dim_upsample)
        ])
        self.output_block = nn.ModuleList([
            nn.Sequential(
                _conv2d(operations, dim_upsample[-1] + 2, self.LAST_CONV_CHANNELS, dtype=dtype, device=device),
                nn.ReLU(inplace=True),
                _conv2d(operations, self.LAST_CONV_CHANNELS, channels, k=1, dtype=dtype, device=device),
            )
            for channels in self.DIM_OUT
        ])

    def forward(self, hidden_states, image: torch.Tensor):
        """Fuse the backbone features, upsample them to the image size, and return one map per entry of DIM_OUT."""
        img_h, img_w = image.shape[-2:]
        grid = (img_h // 14, img_w // 14)
        aspect = img_w / img_h

        projected = []
        for proj, (tokens, _cls) in zip(self.projects, hidden_states):
            # (B, N, C) tokens -> (B, C, patch_h, patch_w) feature map
            fmap = tokens.permute(0, 2, 1).unflatten(2, grid).contiguous()
            projected.append(proj(fmap))
        x = torch.stack(projected, dim=1).sum(dim=1)

        for stage in self.upsample_blocks:
            x = stage(_concat_view_plane_uv(x, aspect))

        x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
        x = _concat_view_plane_uv(x, aspect)
        return [out_conv(x) for out_conv in self.output_block]
|
||||
313
comfy/ldm/moge/panorama.py
Normal file
313
comfy/ldm/moge/panorama.py
Normal file
@ -0,0 +1,313 @@
|
||||
"""Panorama (equirectangular) inference helpers for MoGe.
|
||||
|
||||
Splits an equirect into 12 perspective views via an icosahedron camera rig, runs
|
||||
the model per view, and stitches per-view distance maps back into a single
|
||||
equirect distance map via a multi-scale Poisson + gradient sparse solve.
|
||||
Image sampling uses F.grid_sample (GPU); the sparse solve uses lsmr (CPU).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from scipy.ndimage import convolve, map_coordinates
|
||||
from scipy.sparse import vstack, csr_array
|
||||
from scipy.sparse.linalg import lsmr
|
||||
|
||||
|
||||
def _icosahedron_directions() -> np.ndarray:
|
||||
"""12 icosahedron-vertex directions (non-normalised, matching upstream's vertex order)."""
|
||||
A = (1.0 + np.sqrt(5.0)) / 2.0
|
||||
return np.array([
|
||||
[0, 1, A], [0, -1, A], [0, 1, -A], [0, -1, -A],
|
||||
[1, A, 0], [-1, A, 0], [1, -A, 0], [-1, -A, 0],
|
||||
[A, 0, 1], [A, 0, -1], [-A, 0, 1], [-A, 0, -1],
|
||||
], dtype=np.float32)
|
||||
|
||||
|
||||
def _intrinsics_from_fov(fov_x_rad: float, fov_y_rad: float) -> np.ndarray:
|
||||
"""Normalised-image (unit-square) K matrix."""
|
||||
fx = 0.5 / np.tan(fov_x_rad / 2)
|
||||
fy = 0.5 / np.tan(fov_y_rad / 2)
|
||||
return np.array([[fx, 0, 0.5], [0, fy, 0.5], [0, 0, 1]], dtype=np.float32)
|
||||
|
||||
|
||||
def _extrinsics_look_at(eye: np.ndarray, target: np.ndarray, up: np.ndarray) -> np.ndarray:
|
||||
"""OpenCV-convention world->camera extrinsics for an array of look-at targets (N, 4, 4)."""
|
||||
eye = np.asarray(eye, dtype=np.float32)
|
||||
target = np.asarray(target, dtype=np.float32)
|
||||
up = np.asarray(up, dtype=np.float32)
|
||||
if target.ndim == 1:
|
||||
target = target[None]
|
||||
|
||||
fwd = target - eye
|
||||
fwd = fwd / np.linalg.norm(fwd, axis=-1, keepdims=True).clip(1e-12)
|
||||
right = np.cross(fwd, up)
|
||||
right_norm = np.linalg.norm(right, axis=-1, keepdims=True)
|
||||
# Fall back to an arbitrary perpendicular if forward is parallel to up.
|
||||
parallel = right_norm.squeeze(-1) < 1e-6
|
||||
if parallel.any():
|
||||
alt_up = np.array([1, 0, 0], dtype=np.float32)
|
||||
right = np.where(parallel[:, None], np.cross(fwd, alt_up), right)
|
||||
right_norm = np.linalg.norm(right, axis=-1, keepdims=True)
|
||||
right = right / right_norm.clip(1e-12)
|
||||
new_up = np.cross(fwd, right)
|
||||
|
||||
R = np.stack([right, new_up, fwd], axis=-2)
|
||||
t = -np.einsum("nij,j->ni", R, eye)
|
||||
E = np.zeros((R.shape[0], 4, 4), dtype=np.float32)
|
||||
E[:, :3, :3] = R
|
||||
E[:, :3, 3] = t
|
||||
E[:, 3, 3] = 1.0
|
||||
return E
|
||||
|
||||
|
||||
def get_panorama_cameras() -> Tuple[np.ndarray, List[np.ndarray]]:
    """Return the icosahedron camera rig used to split a panorama.

    Returns:
        ``(extrinsics, intrinsics)``: a ``(12, 4, 4)`` array of world->camera
        matrices for cameras at the origin looking along each icosahedron
        vertex with +Z as up, and a list of 12 references to one shared
        90-degree-FoV intrinsics matrix.
    """
    view_dirs = _icosahedron_directions()
    origin = np.zeros(3, dtype=np.float32)
    z_up = np.array([0, 0, 1], dtype=np.float32)
    fov = np.deg2rad(90.0)
    shared_K = _intrinsics_from_fov(fov, fov)
    return _extrinsics_look_at(origin, view_dirs, z_up), [shared_K] * len(view_dirs)
|
||||
|
||||
|
||||
def spherical_uv_to_directions(uv: np.ndarray) -> np.ndarray:
    """Convert equirect UV in [0, 1] to float32 unit directions (Z up).

    ``u`` maps to the azimuth ``(1 - u) * 2*pi`` around +Z and ``v`` to the
    polar angle ``v * pi`` measured from +Z, so ``v = 0`` is straight up.
    """
    azimuth = (1 - uv[..., 0]) * (2 * np.pi)
    polar = uv[..., 1] * np.pi
    ring = np.sin(polar)  # radius of the horizontal circle at this polar angle
    xyz = (ring * np.cos(azimuth), ring * np.sin(azimuth), np.cos(polar))
    return np.stack(xyz, axis=-1).astype(np.float32)
|
||||
|
||||
|
||||
def directions_to_spherical_uv(directions: np.ndarray) -> np.ndarray:
    """Convert 3D directions (any length) to float32 equirect UV.

    ``u`` is one minus the azimuth expressed as a fraction of a full turn, so
    it lies in (0, 1] (azimuth 0 gives exactly 1.0); ``v`` is the polar angle
    from +Z divided by pi.
    """
    length = np.linalg.norm(directions, axis=-1, keepdims=True).clip(1e-12)
    unit = directions / length
    x, y, z = unit[..., 0], unit[..., 1], unit[..., 2]
    # Fraction of a full turn, wrapped into [0, 1) before being mirrored.
    turns = (np.arctan2(y, x) / (2 * np.pi)) % 1.0
    u = 1 - turns
    v = np.arccos(z.clip(-1, 1)) / np.pi
    return np.stack([u, v], axis=-1).astype(np.float32)
|
||||
|
||||
|
||||
def _uv_grid(H: int, W: int) -> np.ndarray:
|
||||
"""Pixel-center UV grid in [0, 1]; (H, W, 2)."""
|
||||
u = (np.arange(W, dtype=np.float32) + 0.5) / W
|
||||
v = (np.arange(H, dtype=np.float32) + 0.5) / H
|
||||
return np.stack(np.meshgrid(u, v, indexing="xy"), axis=-1)
|
||||
|
||||
|
||||
def _unproject_cv(uv: np.ndarray, depth: np.ndarray,
|
||||
extrinsics: np.ndarray, intrinsics: np.ndarray) -> np.ndarray:
|
||||
"""Back-project pixels into world coords (OpenCV convention)."""
|
||||
pix = np.concatenate([uv, np.ones_like(uv[..., :1])], axis=-1)
|
||||
K_inv = np.linalg.inv(intrinsics)
|
||||
cam = pix @ K_inv.T * depth[..., None]
|
||||
cam_h = np.concatenate([cam, np.ones_like(cam[..., :1])], axis=-1)
|
||||
E_inv = np.linalg.inv(extrinsics)
|
||||
return (cam_h @ E_inv.T)[..., :3]
|
||||
|
||||
|
||||
def _project_cv(points: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""World coords -> (uv, depth) in the camera (OpenCV convention)."""
|
||||
pts_h = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1)
|
||||
cam = pts_h @ extrinsics.T
|
||||
cam_xyz = cam[..., :3]
|
||||
depth = cam_xyz[..., 2]
|
||||
proj = cam_xyz @ intrinsics.T
|
||||
uv = proj[..., :2] / proj[..., 2:3].clip(1e-12)
|
||||
return uv.astype(np.float32), depth.astype(np.float32)
|
||||
|
||||
|
||||
def _grid_sample_uv(img_bchw: torch.Tensor, uv: torch.Tensor, mode: str = "bilinear") -> torch.Tensor:
|
||||
"""Sample img_bchw at UV-in-[0,1] coords uv of shape (B, H, W, 2); replicate-border."""
|
||||
grid = uv * 2.0 - 1.0
|
||||
return F.grid_sample(img_bchw, grid, mode=mode, padding_mode="border", align_corners=False)
|
||||
|
||||
|
||||
def split_panorama_image(image: torch.Tensor, extrinsics: np.ndarray, intrinsics: List[np.ndarray], resolution: int) -> torch.Tensor:
    """Render perspective crops of an equirect panorama, one per camera.

    Args:
        image: ``(C, Hp, Wp)`` equirect image on any device.
        extrinsics: per-view world->camera matrices, indexable by view.
        intrinsics: per-view 3x3 matrices for a unit-square image.
        resolution: side length ``R`` of each square crop.

    Returns:
        ``(N, C, R, R)`` contiguous tensor on the same device and dtype as ``image``.
    """
    device = image.device
    n_views = len(extrinsics)
    uv = _uv_grid(resolution, resolution)
    unit_depth = np.ones(uv.shape[:-1], dtype=np.float32)
    # For every crop pixel: unproject at depth 1 to a world ray, then look up
    # where that ray lands on the equirect image.
    sample_uvs = np.stack([
        directions_to_spherical_uv(_unproject_cv(uv, unit_depth, extrinsics[i], intrinsics[i]))
        for i in range(n_views)
    ], axis=0)

    grid = torch.from_numpy(sample_uvs).to(device=device, dtype=image.dtype)
    # Every view samples the same panorama, so stack the N view grids along the
    # output-height axis and sample once with batch size 1. This avoids
    # materialising N full copies of the panorama on the device; each output
    # pixel is computed exactly as it would be with a per-view batch.
    grid = grid.reshape(1, n_views * resolution, resolution, 2)
    crops = _grid_sample_uv(image.unsqueeze(0), grid, mode="bilinear")
    channels = crops.shape[1]
    return crops.reshape(channels, n_views, resolution, resolution).movedim(0, 1).contiguous()
|
||||
|
||||
|
||||
def _poisson_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
|
||||
"""Sparse Laplacian operator over the H x W grid."""
|
||||
grid_index = np.arange(H * W).reshape(H, W)
|
||||
grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode="wrap" if wrap_x else "edge")
|
||||
grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode="wrap" if wrap_y else "edge")
|
||||
|
||||
data = np.array([[-4, 1, 1, 1, 1]], dtype=np.float32).repeat(H * W, axis=0).reshape(-1)
|
||||
indices = np.stack([
|
||||
grid_index[1:-1, 1:-1],
|
||||
grid_index[:-2, 1:-1], grid_index[2:, 1:-1],
|
||||
grid_index[1:-1, :-2], grid_index[1:-1, 2:],
|
||||
], axis=-1).reshape(-1)
|
||||
indptr = np.arange(0, H * W * 5 + 1, 5)
|
||||
return csr_array((data, indices, indptr), shape=(H * W, H * W))
|
||||
|
||||
|
||||
def _grad_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
|
||||
"""Sparse forward-difference operator over the H x W grid."""
|
||||
grid_index = np.arange(W * H).reshape(H, W)
|
||||
if wrap_x:
|
||||
grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode="wrap")
|
||||
if wrap_y:
|
||||
grid_index = np.pad(grid_index, ((0, 1), (0, 0)), mode="wrap")
|
||||
|
||||
data = np.concatenate([
|
||||
np.concatenate([
|
||||
np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1),
|
||||
-np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1),
|
||||
], axis=1).reshape(-1),
|
||||
np.concatenate([
|
||||
np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1),
|
||||
-np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1),
|
||||
], axis=1).reshape(-1),
|
||||
])
|
||||
indices = np.concatenate([
|
||||
np.concatenate([grid_index[:, :-1].reshape(-1, 1), grid_index[:, 1:].reshape(-1, 1)], axis=1).reshape(-1),
|
||||
np.concatenate([grid_index[:-1, :].reshape(-1, 1), grid_index[1:, :].reshape(-1, 1)], axis=1).reshape(-1),
|
||||
])
|
||||
nx = grid_index.shape[0] * (grid_index.shape[1] - 1)
|
||||
ny = (grid_index.shape[0] - 1) * grid_index.shape[1]
|
||||
indptr = np.arange(0, nx * 2 + ny * 2 + 1, 2)
|
||||
return csr_array((data, indices, indptr), shape=(nx + ny, H * W))
|
||||
|
||||
|
||||
def _scipy_remap_bilinear(img: np.ndarray, sample_pixels: np.ndarray, mode: str = "bilinear") -> np.ndarray:
|
||||
"""Bilinear/nearest sampling at fractional pixel coords; out-of-range clamps to nearest border."""
|
||||
H, W = img.shape[:2]
|
||||
yy = np.clip(sample_pixels[..., 1], 0, H - 1)
|
||||
xx = np.clip(sample_pixels[..., 0], 0, W - 1)
|
||||
order = 1 if mode == "bilinear" else 0
|
||||
if img.ndim == 2:
|
||||
return map_coordinates(img, [yy, xx], order=order, mode="nearest").astype(img.dtype)
|
||||
out = np.stack([
|
||||
map_coordinates(img[..., c], [yy, xx], order=order, mode="nearest")
|
||||
for c in range(img.shape[-1])
|
||||
], axis=-1)
|
||||
return out.astype(img.dtype)
|
||||
|
||||
|
||||
def merge_panorama_depth(width: int, height: int,
                         distance_maps: List[np.ndarray], pred_masks: List[np.ndarray],
                         extrinsics: List[np.ndarray], intrinsics: List[np.ndarray],
                         on_view: Optional[Callable[[], None]] = None,
                         on_solve_start: Optional[Callable[[int, int], None]] = None,
                         on_solve_end: Optional[Callable[[int, int], None]] = None,
                         ) -> Tuple[np.ndarray, np.ndarray]:
    """Stitch per-view distance maps into a single equirect distance map.

    Recursive multi-scale solve: solves at half resolution first and uses that as the lsmr init
    for the full-resolution solve. Optional callbacks fire per view processed and around each
    lsmr solve so callers can drive a progress bar.

    Args:
        width: output equirect width in pixels.
        height: output equirect height in pixels.
        distance_maps: per-view 2D distance maps (values are clamped to >= 1e-6 before the log).
        pred_masks: per-view validity masks, sampled with nearest-neighbour lookup.
        extrinsics: per-view world->camera matrices.
        intrinsics: per-view intrinsics for a unit-square image.
        on_view: called once per view at every recursion level.
        on_solve_start: called with ``(width, height)`` right before each lsmr solve.
        on_solve_end: called with ``(width, height)`` right after each lsmr solve.

    Returns:
        ``(pano_depth, pano_mask)``: a float32 ``(height, width)`` distance map and a boolean
        ``(height, width)`` mask that is True where at least one view contributed a valid pixel.
    """

    # Coarse-to-fine: recurse until the long side is at most 256, then upsample
    # the coarse solution to seed this level's solve.
    if max(width, height) > 256:
        coarse_depth, _ = merge_panorama_depth(width // 2, height // 2,
                                               distance_maps, pred_masks, extrinsics, intrinsics,
                                               on_view=on_view,
                                               on_solve_start=on_solve_start,
                                               on_solve_end=on_solve_end)
        t = torch.from_numpy(coarse_depth).unsqueeze(0).unsqueeze(0)
        t = F.interpolate(t, size=(height, width), mode="bilinear", align_corners=False)
        depth_init = t.squeeze().numpy().astype(np.float32)
    else:
        depth_init = None

    # World-space ray for every equirect pixel centre.
    spherical_directions = spherical_uv_to_directions(_uv_grid(height, width))

    pano_log_grad_maps, pano_grad_masks = [], []
    pano_log_lap_maps, pano_lap_masks = [], []
    pano_pred_masks: List[np.ndarray] = []

    for i in range(len(distance_maps)):
        # Where does each panorama ray land in view i? Valid only in front of
        # the camera and strictly inside the unit-square image.
        proj_uv, proj_depth = _project_cv(spherical_directions, extrinsics[i], intrinsics[i])
        proj_valid = (proj_depth > 0) & (proj_uv > 0).all(axis=-1) & (proj_uv < 1).all(axis=-1)

        Hd, Wd = distance_maps[i].shape[:2]
        proj_pixels = np.clip(proj_uv, 0, 1) * np.array([Wd - 1, Hd - 1], dtype=np.float32)

        # Work in log-distance so per-view scale differences become additive
        # offsets that the gradient/Laplacian constraints do not see.
        log_dist = np.log(np.clip(distance_maps[i], 1e-6, None))
        sampled = _scipy_remap_bilinear(log_dist, proj_pixels, mode="bilinear")
        pano_log = np.where(proj_valid, sampled, 0.0).astype(np.float32)

        sampled_mask = _scipy_remap_bilinear(pred_masks[i].astype(np.uint8), proj_pixels, mode="nearest")
        pano_pred = proj_valid & (sampled_mask > 0)

        # Equirect wraps horizontally but not vertically: wrap pad along x, edge pad along y.
        # gx is (H, W); gy is (H - 1, W + 1) because it is taken on the padded
        # array, matching the row layout of _grad_equation(wrap_x=True).
        padded = np.pad(pano_log, ((0, 0), (0, 1)), mode="wrap")
        gx, gy = padded[:, :-1] - padded[:, 1:], padded[:-1, :] - padded[1:, :]
        padded_m = np.pad(pano_pred, ((0, 0), (0, 1)), mode="wrap")
        # A gradient sample is usable only when both of its pixels are valid.
        mx, my = padded_m[:, :-1] & padded_m[:, 1:], padded_m[:-1, :] & padded_m[1:, :]
        pano_log_grad_maps.append((gx, gy))
        pano_grad_masks.append((mx, my))

        padded = np.pad(pano_log, ((1, 1), (0, 0)), mode="edge")
        padded = np.pad(padded, ((0, 0), (1, 1)), mode="wrap")
        lap_kernel = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32)
        lap = convolve(padded, lap_kernel)[1:-1, 1:-1]
        padded_m = np.pad(pano_pred, ((1, 1), (0, 0)), mode="edge")
        padded_m = np.pad(padded_m, ((0, 0), (1, 1)), mode="wrap")
        # Counting kernel: a Laplacian sample is usable only when the pixel and
        # all four neighbours are valid (count == 5).
        m_kernel = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.uint8)
        lap_mask = convolve(padded_m.astype(np.uint8), m_kernel)[1:-1, 1:-1] == 5
        pano_log_lap_maps.append(lap)
        pano_lap_masks.append(lap_mask)
        pano_pred_masks.append(pano_pred)

        if on_view is not None:
            on_view()

    # Average each constraint over the views that cover it; the clip keeps the
    # division finite where no view contributes (those rows are dropped below).
    gx = np.stack([m[0] for m in pano_log_grad_maps], axis=0)
    gy = np.stack([m[1] for m in pano_log_grad_maps], axis=0)
    mx = np.stack([m[0] for m in pano_grad_masks], axis=0)
    my = np.stack([m[1] for m in pano_grad_masks], axis=0)
    gx_avg = (gx * mx).sum(axis=0) / mx.sum(axis=0).clip(1e-3)
    gy_avg = (gy * my).sum(axis=0) / my.sum(axis=0).clip(1e-3)

    laps = np.stack(pano_log_lap_maps, axis=0)
    lap_masks = np.stack(pano_lap_masks, axis=0)
    lap_avg = (laps * lap_masks).sum(axis=0) / lap_masks.sum(axis=0).clip(1e-3)

    # Row selectors, ordered x-gradients then y-gradients to match _grad_equation.
    grad_x_mask = mx.any(axis=0).reshape(-1)
    grad_y_mask = my.any(axis=0).reshape(-1)
    grad_mask = np.concatenate([grad_x_mask, grad_y_mask])
    lap_mask_flat = lap_masks.any(axis=0).reshape(-1)

    # Least-squares system A x = b over log-distance, keeping only constrained rows.
    A = vstack([
        _grad_equation(width, height, wrap_x=True, wrap_y=False)[grad_mask],
        _poisson_equation(width, height, wrap_x=True, wrap_y=False)[lap_mask_flat],
    ])
    b = np.concatenate([
        gx_avg.reshape(-1)[grad_x_mask],
        gy_avg.reshape(-1)[grad_y_mask],
        lap_avg.reshape(-1)[lap_mask_flat],
    ])
    x0 = np.log(np.clip(depth_init, 1e-6, None)).reshape(-1) if depth_init is not None else None

    if on_solve_start is not None:
        on_solve_start(width, height)
    x, *_ = lsmr(A, b, atol=1e-5, btol=1e-5, x0=x0, show=False)
    if on_solve_end is not None:
        on_solve_end(width, height)

    # Back from log-distance to distance.
    pano_depth = np.exp(x).reshape(height, width).astype(np.float32)
    pano_mask = np.any(pano_pred_masks, axis=0)
    return pano_depth, pano_mask
|
||||
@ -82,6 +82,8 @@ class VAEEncodeAudio(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae, audio) -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
raise ValueError("VAEEncodeAudio: input audio is None (source video may have no audio track).")
|
||||
sample_rate = audio["sample_rate"]
|
||||
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
||||
if vae_sample_rate != sample_rate:
|
||||
@ -171,6 +173,8 @@ class SaveAudio(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
raise ValueError("SaveAudio: input audio is None (source video may have no audio track).")
|
||||
return IO.NodeOutput(
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
|
||||
)
|
||||
@ -198,6 +202,8 @@ class SaveAudioMP3(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
raise ValueError("SaveAudioMP3: input audio is None (source video may have no audio track).")
|
||||
return IO.NodeOutput(
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(
|
||||
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
|
||||
@ -226,6 +232,8 @@ class SaveAudioOpus(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
raise ValueError("SaveAudioOpus: input audio is None (source video may have no audio track).")
|
||||
return IO.NodeOutput(
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(
|
||||
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
|
||||
@ -252,6 +260,8 @@ class PreviewAudio(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio) -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
raise ValueError("PreviewAudio: input audio is None (source video may have no audio track).")
|
||||
return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
|
||||
|
||||
save_flac = execute # TODO: remove
|
||||
@ -392,21 +402,26 @@ class TrimAudioDuration(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
return IO.NodeOutput(None)
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
audio_length = waveform.shape[-1]
|
||||
|
||||
if audio_length == 0:
|
||||
return IO.NodeOutput(audio)
|
||||
|
||||
if start_index < 0:
|
||||
start_frame = audio_length + int(round(start_index * sample_rate))
|
||||
else:
|
||||
start_frame = int(round(start_index * sample_rate))
|
||||
start_frame = max(0, min(start_frame, audio_length - 1))
|
||||
start_frame = max(0, min(start_frame, audio_length))
|
||||
|
||||
end_frame = start_frame + int(round(duration * sample_rate))
|
||||
end_frame = max(0, min(end_frame, audio_length))
|
||||
|
||||
if start_frame >= end_frame:
|
||||
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
|
||||
raise ValueError("TrimAudioDuration: Start time must be less than end time and be within the audio length.")
|
||||
|
||||
return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
|
||||
|
||||
@ -433,11 +448,13 @@ class SplitAudioChannels(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio) -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
return IO.NodeOutput(None, None)
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
if waveform.shape[1] != 2:
|
||||
raise ValueError("AudioSplit: Input audio has only one channel.")
|
||||
raise ValueError(f"AudioSplit: Input audio must be stereo (2 channels), got {waveform.shape[1]} channel(s).")
|
||||
|
||||
left_channel = waveform[..., 0:1, :]
|
||||
right_channel = waveform[..., 1:2, :]
|
||||
@ -465,6 +482,12 @@ class JoinAudioChannels(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio_left, audio_right) -> IO.NodeOutput:
|
||||
if audio_left is None and audio_right is None:
|
||||
return IO.NodeOutput(None)
|
||||
if audio_left is None:
|
||||
return IO.NodeOutput(audio_right)
|
||||
if audio_right is None:
|
||||
return IO.NodeOutput(audio_left)
|
||||
waveform_left = audio_left["waveform"]
|
||||
sample_rate_left = audio_left["sample_rate"]
|
||||
waveform_right = audio_right["waveform"]
|
||||
@ -538,6 +561,12 @@ class AudioConcat(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
|
||||
if audio1 is None and audio2 is None:
|
||||
return IO.NodeOutput(None)
|
||||
if audio1 is None:
|
||||
return IO.NodeOutput(audio2)
|
||||
if audio2 is None:
|
||||
return IO.NodeOutput(audio1)
|
||||
waveform_1 = audio1["waveform"]
|
||||
waveform_2 = audio2["waveform"]
|
||||
sample_rate_1 = audio1["sample_rate"]
|
||||
@ -585,6 +614,12 @@ class AudioMerge(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
|
||||
if audio1 is None and audio2 is None:
|
||||
return IO.NodeOutput(None)
|
||||
if audio1 is None:
|
||||
return IO.NodeOutput(audio2)
|
||||
if audio2 is None:
|
||||
return IO.NodeOutput(audio1)
|
||||
waveform_1 = audio1["waveform"]
|
||||
waveform_2 = audio2["waveform"]
|
||||
sample_rate_1 = audio1["sample_rate"]
|
||||
@ -595,6 +630,9 @@ class AudioMerge(IO.ComfyNode):
|
||||
length_1 = waveform_1.shape[-1]
|
||||
length_2 = waveform_2.shape[-1]
|
||||
|
||||
if length_1 == 0 or length_2 == 0:
|
||||
return IO.NodeOutput({"waveform": waveform_1, "sample_rate": output_sample_rate})
|
||||
|
||||
if length_2 > length_1:
|
||||
logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.")
|
||||
waveform_2 = waveform_2[..., :length_1]
|
||||
@ -646,6 +684,8 @@ class AudioAdjustVolume(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, volume) -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
return IO.NodeOutput(None)
|
||||
if volume == 0:
|
||||
return IO.NodeOutput(audio)
|
||||
waveform = audio["waveform"]
|
||||
@ -729,8 +769,14 @@ class AudioEqualizer3Band(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
return IO.NodeOutput(None)
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
if waveform.shape[-1] == 0:
|
||||
return IO.NodeOutput(audio)
|
||||
|
||||
eq_waveform = waveform.clone()
|
||||
|
||||
# 1. Apply Low Shelf (Bass)
|
||||
|
||||
@ -136,7 +136,7 @@ class ImageFromBatch(IO.ComfyNode):
|
||||
category="image/batch",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input("batch_index", default=0, min=0, max=4095),
|
||||
IO.Int.Input("batch_index", default=0, min=-MAX_RESOLUTION, max=MAX_RESOLUTION),
|
||||
IO.Int.Input("length", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
@ -145,7 +145,9 @@ class ImageFromBatch(IO.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, image, batch_index, length) -> IO.NodeOutput:
|
||||
s_in = image
|
||||
batch_index = min(s_in.shape[0] - 1, batch_index)
|
||||
if batch_index < 0:
|
||||
batch_index += s_in.shape[0]
|
||||
batch_index = max(0, min(s_in.shape[0] - 1, batch_index))
|
||||
length = min(s_in.shape[0] - batch_index, length)
|
||||
s = s_in[batch_index:batch_index + length].clone()
|
||||
return IO.NodeOutput(s)
|
||||
|
||||
@ -219,7 +219,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
|
||||
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
|
||||
),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
@ -298,7 +298,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
else:
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||
1.0 - strength,
|
||||
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
@ -318,7 +318,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, cond_length, 1, 1),
|
||||
1.0 - strength,
|
||||
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
|
||||
@ -134,8 +134,11 @@ class ModelSamplingSD3:
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
|
||||
original = m.get_model_object("model_sampling")
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||
model_sampling.set_parameters(shift=shift, multiplier=multiplier)
|
||||
if hasattr(original, "noise_scale"):
|
||||
model_sampling.set_noise_scale(original.noise_scale)
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
return (m, )
|
||||
|
||||
@ -315,7 +318,7 @@ class ModelNoiseScale:
|
||||
|
||||
def patch(self, model, noise_scale):
|
||||
m = model.clone()
|
||||
original = m.model.model_sampling
|
||||
original = m.get_model_object("model_sampling")
|
||||
ms = type(original)(m.model.model_config)
|
||||
ms.set_parameters(shift=original.shift, multiplier=original.multiplier)
|
||||
ms.set_noise_scale(noise_scale)
|
||||
|
||||
406
comfy_extras/nodes_moge.py
Normal file
406
comfy_extras/nodes_moge.py
Normal file
@ -0,0 +1,406 @@
|
||||
"""ComfyUI nodes for the native MoGe (Monocular Geometry Estimation) integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
from comfy_api.latest import ComfyExtension, Types, io
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy.ldm.moge.model import MoGeModel
|
||||
from comfy.ldm.moge.geometry import triangulate_grid_mesh
|
||||
from comfy.ldm.moge.panorama import get_panorama_cameras, split_panorama_image, merge_panorama_depth, spherical_uv_to_directions, _uv_grid
|
||||
import comfy.model_management
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
# Custom socket types shared by the MoGe nodes in this module.
# MOGE_MODEL is what LoadMoGeModel outputs and the inference nodes consume.
MoGeModelType = io.Custom("MOGE_MODEL")
# MOGE_GEOMETRY carries the geometry dict described in the comment below.
MoGeGeometry = io.Custom("MOGE_GEOMETRY")
|
||||
|
||||
|
||||
# MOGE_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them):
|
||||
# "points": torch.Tensor (B, H, W, 3)
|
||||
# "depth": torch.Tensor (B, H, W)
|
||||
# "intrinsics": torch.Tensor (B, 3, 3) -- perspective only
|
||||
# "mask": torch.Tensor (B, H, W) bool
|
||||
# "normal": torch.Tensor (B, H, W, 3) -- v2 only
|
||||
# "image": torch.Tensor (B, H, W, 3) in [0, 1], CPU (always present)
|
||||
|
||||
|
||||
def _turbo(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Anton Mikhailov polynomial approximation of the turbo colormap."""
|
||||
x = x.clamp(0.0, 1.0)
|
||||
x2 = x * x
|
||||
x3 = x2 * x
|
||||
x4 = x2 * x2
|
||||
x5 = x4 * x
|
||||
r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5
|
||||
g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5
|
||||
b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5
|
||||
return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0)
|
||||
|
||||
|
||||
def _normals_from_points(points: torch.Tensor) -> torch.Tensor:
|
||||
"""Camera-space surface normals from a (B, H, W, 3) point map (v1 fallback)."""
|
||||
finite = torch.isfinite(points).all(dim=-1)
|
||||
pts = torch.where(finite.unsqueeze(-1), points, torch.zeros_like(points))
|
||||
dx = pts[..., :, 2:, :] - pts[..., :, :-2, :]
|
||||
dy = pts[..., 2:, :, :] - pts[..., :-2, :, :]
|
||||
dx = torch.nn.functional.pad(dx.permute(0, 3, 1, 2), (1, 1, 0, 0)).permute(0, 2, 3, 1)
|
||||
dy = torch.nn.functional.pad(dy.permute(0, 3, 1, 2), (0, 0, 1, 1)).permute(0, 2, 3, 1)
|
||||
# dy x dx (not dx x dy) so the result is outward-facing in OpenCV (Y-down flips the right-hand rule), matching v2's predicted normals.
|
||||
n = torch.cross(dy, dx, dim=-1)
|
||||
n = torch.nn.functional.normalize(n, dim=-1)
|
||||
return torch.where(finite.unsqueeze(-1), n, torch.zeros_like(n))
|
||||
|
||||
|
||||
def _normalize_disparity(depth: torch.Tensor) -> torch.Tensor:
|
||||
"""Per-batch normalize 1/depth to [0, 1] using 0.1/99.9 percentile clipping."""
|
||||
out = torch.zeros_like(depth)
|
||||
for i in range(depth.shape[0]):
|
||||
d = depth[i]
|
||||
valid = torch.isfinite(d) & (d > 0)
|
||||
if not valid.any():
|
||||
continue
|
||||
disp = torch.where(valid, 1.0 / d.clamp_min(1e-6), torch.zeros_like(d))
|
||||
disp_valid = disp[valid]
|
||||
lo = torch.quantile(disp_valid, 0.001)
|
||||
hi = torch.quantile(disp_valid, 0.999)
|
||||
scale = (hi - lo).clamp_min(1e-6)
|
||||
norm = ((disp - lo) / scale).clamp(0.0, 1.0)
|
||||
out[i] = torch.where(valid, norm, torch.zeros_like(norm))
|
||||
return out
|
||||
|
||||
|
||||
class LoadMoGeModel(io.ComfyNode):
    """Loader node that turns a checkpoint from the "geometry_estimation" folder into a MoGeModel."""

    @classmethod
    def define_schema(cls):
        """Declare the node: one combo of available checkpoint files in, one MOGE_MODEL out."""
        checkpoints = folder_paths.get_filename_list("geometry_estimation")
        model_input = io.Combo.Input("model_name", options=checkpoints)
        return io.Schema(
            node_id="LoadMoGeModel",
            display_name="Load MoGe Model",
            category="loaders",
            inputs=[model_input],
            outputs=[MoGeModelType.Output()],
        )

    @classmethod
    def execute(cls, model_name) -> io.NodeOutput:
        """Resolve ``model_name`` to a file, load its state dict and wrap it in a MoGeModel."""
        checkpoint_path = folder_paths.get_full_path_or_raise("geometry_estimation", model_name)
        state_dict = comfy.utils.load_torch_file(checkpoint_path, safe_load=True)
        model = MoGeModel(state_dict)
        return io.NodeOutput(model)
|
||||
|
||||
|
||||
class MoGePanoramaInference(io.ComfyNode):
    """Equirectangular panorama inference: split into 12 perspective views, run
    MoGe at fov_x=90 on each, merge via multi-scale Poisson + gradient solve.
    v2's predicted normals and metric scale are ignored (per-view scales would not align across seams).
    """

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs (model, panorama, resolution knobs) and its MOGE_GEOMETRY output."""
        return io.Schema(
            node_id="MoGePanoramaInference",
            display_name="MoGe Panorama Inference",
            category="image/geometry_estimation",
            inputs=[
                MoGeModelType.Input("moge_model"),
                io.Image.Input("image", tooltip="Equirectangular panorama (any aspect)."),
                io.Int.Input("resolution_level", default=9, min=0, max=9,
                             tooltip="Per-view detail (0 = fastest, 9 = most detailed)."),
                io.Int.Input("split_resolution", default=512, min=256, max=1024,
                             tooltip="Resolution of each perspective split."),
                io.Int.Input("merge_resolution", default=1920, min=256, max=8192,
                             tooltip="Long-side resolution of the merged equirect distance map."),
                io.Int.Input("batch_size", default=4, min=1, max=12,
                             tooltip="Views per inference batch (12 splits total)."),
            ],
            outputs=[MoGeGeometry.Output(display_name="moge_geometry")],
        )

    @classmethod
    def execute(cls, moge_model, image, resolution_level, split_resolution, merge_resolution, batch_size) -> io.NodeOutput:
        """Run panorama inference and return a MOGE_GEOMETRY dict.

        Steps: split the equirect image into perspective views, infer a
        distance map and mask per view in batches, merge them into one equirect
        distance map, resize to the input resolution, push uncovered pixels to
        a far plane, and convert distances to 3D points along each pixel's ray.

        Raises:
            ValueError: if ``image`` holds more than one batch item.
        """

        if image.shape[0] != 1:
            raise ValueError(f"MoGePanoramaInference takes a single image (got batch of {image.shape[0]})")

        image = image[..., :3]
        H, W = int(image.shape[1]), int(image.shape[2])
        # Never upscale for the merge; keep each side at least 32 pixels.
        scale = min(merge_resolution / max(H, W), 1.0)
        merge_h, merge_w = max(int(H * scale), 32), max(int(W * scale), 32)

        extrinsics, intrinsics = get_panorama_cameras()

        comfy.model_management.load_model_gpu(moge_model.patcher)
        device = moge_model.load_device
        img_chw = image[0].movedim(-1, -3).to(device=device, dtype=moge_model.dtype)
        splits = split_panorama_image(img_chw, extrinsics, intrinsics, split_resolution)

        n_views = splits.shape[0]

        # Weight each lsmr solve by 4^level so the final-resolution solve doesn't leave the bar idle.
        # The loop mirrors merge_panorama_depth's halving so the keys match the
        # (width, height) pairs its solve callbacks report.
        merge_levels: list[tuple[int, int]] = []
        w_, h_ = merge_w, merge_h
        while True:
            merge_levels.append((w_, h_))
            if max(w_, h_) <= 256:
                break
            w_, h_ = w_ // 2, h_ // 2
        merge_levels.reverse()

        solve_weight = {wh: 4 ** i for i, wh in enumerate(merge_levels)}
        n_merge_view_units = n_views * len(merge_levels)
        n_merge_solve_units = sum(solve_weight.values())

        pbar = comfy.utils.ProgressBar(n_views + n_merge_view_units + n_merge_solve_units)
        done = 0

        distance_maps: list = []
        masks: list = []
        with tqdm(total=n_views, desc="MoGe panorama inference") as tq:
            for i in range(0, n_views, batch_size):
                batch = splits[i:i + batch_size]
                # apply_metric_scale=False: per-view scales would not align across overlap seams.
                result = moge_model.infer(batch, resolution_level=resolution_level,
                                          fov_x=90.0, force_projection=True,
                                          apply_mask=False, apply_metric_scale=False)
                # Euclidean distance from the camera, not z-depth.
                distance_maps.extend(list(result["points"].float().norm(dim=-1).cpu().numpy()))
                masks.extend(list(result["mask"].cpu().numpy()))
                n = batch.shape[0]
                done += n
                pbar.update_absolute(done)
                tq.update(n)

        with tqdm(total=n_merge_view_units + n_merge_solve_units, desc="MoGe panorama merge: views") as tq:
            def _on_merge_view():
                nonlocal done
                done += 1
                pbar.update_absolute(done)
                tq.update(1)

            def _on_solve_start(w, h):
                tq.set_description(f"MoGe panorama merge: solving {w}x{h}")

            def _on_solve_end(w, h):
                nonlocal done
                weight = solve_weight[(w, h)]
                done += weight
                pbar.update_absolute(done)
                tq.update(weight)
                tq.set_description("MoGe panorama merge: views")

            pano_depth, pano_mask = merge_panorama_depth(
                merge_w, merge_h, distance_maps, masks, list(extrinsics), intrinsics,
                on_view=_on_merge_view, on_solve_start=_on_solve_start, on_solve_end=_on_solve_end)

        pano_depth = torch.from_numpy(pano_depth)
        pano_mask = torch.from_numpy(pano_mask)

        if (merge_h, merge_w) != (H, W):
            # Index away exactly the two dims added for interpolate; a blanket
            # squeeze() would also drop a genuine size-1 image dimension.
            pano_depth = torch.nn.functional.interpolate(pano_depth[None, None], size=(H, W), mode="bilinear", align_corners=False)[0, 0]
            pano_mask = torch.nn.functional.interpolate(pano_mask[None, None].float(), size=(H, W), mode="nearest")[0, 0] > 0

        # Pixels uncovered by any view's predicted foreground are unconstrained in the lsmr solve and stay at log_depth=0 (depth=1)
        if pano_mask.any() and not pano_mask.all():
            # torch.quantile rejects inputs beyond roughly 16M elements, which a
            # full-resolution 8K panorama exceeds. Take the 95th percentile from
            # an explicit sort with the same linear interpolation instead.
            covered = pano_depth[pano_mask].sort().values
            rank = 0.95 * (covered.numel() - 1)
            lower = int(rank)
            upper = min(lower + 1, covered.numel() - 1)
            far = torch.lerp(covered[lower], covered[upper], rank - lower) * 5.0
            pano_depth = torch.where(pano_mask, pano_depth, far)

        directions = torch.from_numpy(spherical_uv_to_directions(_uv_grid(H, W)))
        points = (directions * pano_depth[..., None]).unsqueeze(0)
        depth = pano_depth.unsqueeze(0)
        mask = pano_mask.unsqueeze(0)

        # Points stay in MoGe spherical coords; MoGePointMapToMesh applies the spherical->glTF rotation after triangulation
        moge_geometry = {"points": points, "depth": depth, "mask": mask, "image": image.cpu()}
        return io.NodeOutput(moge_geometry)
|
||||
|
||||
|
||||
class MoGeInference(io.ComfyNode):
    """Run a MoGe model over an image batch and collect the outputs.

    The batch is processed in slices of ``batch_size`` images; every tensor
    field returned by the model is concatenated along dim 0 and stored,
    together with the (3-channel) source image, in a plain dict.
    """

    @classmethod
    def define_schema(cls):
        """Describe the node's identity, inputs and single geometry output."""
        node_inputs = [
            MoGeModelType.Input("moge_model"),
            io.Image.Input("image"),
            io.Int.Input(
                "resolution_level", default=9, min=0, max=9,
                tooltip="0 = fastest, 9 = most detail.",
            ),
            io.Float.Input(
                "fov_x_degrees", default=0.0, min=0.0, max=170.0, step=0.1, advanced=True,
                tooltip="Horizontal field of view of the source camera. Sets the focal length used to unproject the depth map into 3D. 0 = auto-recover from the predicted points.",
            ),
            io.Int.Input(
                "batch_size", default=4, min=1, max=64,
                tooltip="Images per inference call. Lower if you OOM on a long video / image set.",
            ),
            io.Boolean.Input("force_projection", default=True, advanced=True),
            io.Boolean.Input(
                "apply_mask", default=True, advanced=True,
                tooltip="Set masked-out (sky / invalid) pixels to inf in points and depth so meshing culls them. Disable to keep the raw predicted geometry everywhere; the mask is still returned separately.",
            ),
        ]
        return io.Schema(
            node_id="MoGeInference",
            display_name="MoGe Inference",
            category="image/geometry_estimation",
            inputs=node_inputs,
            outputs=[MoGeGeometry.Output(display_name="moge_geometry")],
        )

    @classmethod
    def execute(cls, moge_model, image, resolution_level, fov_x_degrees, batch_size, force_projection, apply_mask) -> io.NodeOutput:
        """Infer geometry for every image and merge the per-slice results.

        A non-positive ``fov_x_degrees`` is forwarded to the model as
        ``fov_x=None``; any other value is forwarded as a float.
        """
        # Keep only the first three channels of the last axis, then move that
        # axis in front of the two spatial axes for the model call.
        rgb = image[..., :3]
        frames = rgb.movedim(-1, -3).contiguous()
        total = frames.shape[0]

        if fov_x_degrees <= 0:
            fov_x = None
        else:
            fov_x = float(fov_x_degrees)

        progress = comfy.utils.ProgressBar(total)
        predictions: list[dict] = []
        with tqdm(total=total, desc="MoGe inference") as bar:
            for start in range(0, total, batch_size):
                stop = start + batch_size
                batch = frames[start:stop]
                prediction = moge_model.infer(
                    batch,
                    resolution_level=resolution_level,
                    fov_x=fov_x,
                    force_projection=force_projection,
                    apply_mask=apply_mask,
                )
                predictions.append(prediction)
                # The last slice may be short, so clamp the absolute progress.
                progress.update_absolute(min(stop, total))
                bar.update(batch.shape[0])

        moge_geometry = {"image": rgb.cpu()}
        # A field is emitted only if at least one slice produced it.
        for key in ("points", "depth", "intrinsics", "mask", "normal"):
            parts = [pred[key] for pred in predictions if key in pred]
            if parts:
                moge_geometry[key] = torch.cat(parts, dim=0)
        return io.NodeOutput(moge_geometry)
|
||||
|
||||
|
||||
class MoGeRender(io.ComfyNode):
    """Turn one field of a MOGE_GEOMETRY dict into a 3-channel image batch.

    Supported modes: grayscale or turbo-coloured depth, normal maps in the
    OpenGL or DirectX green-channel convention, and the validity mask.
    """

    @classmethod
    def define_schema(cls):
        """Describe the node's identity, its two inputs and the image output."""
        mode_choices = ["depth", "depth_colored", "normal_opengl", "normal_directx", "mask"]
        node_inputs = [
            MoGeGeometry.Input("moge_geometry"),
            io.Combo.Input(
                "output", options=mode_choices, default="depth",
                tooltip="DirectX vs OpenGL controls the normal-map green-channel convention. DirectX: green = -Y down (Unreal). OpenGL: green = +Y up (Blender, Substance, Unity, glTF).",
            ),
        ]
        return io.Schema(
            node_id="MoGeRender",
            display_name="MoGe Render",
            category="image/geometry_estimation",
            inputs=node_inputs,
            outputs=[io.Image.Output()],
        )

    @classmethod
    def execute(cls, moge_geometry, output) -> io.NodeOutput:
        """Render the requested mode one batch item at a time.

        Raises:
            ValueError: if the field needed for ``output`` is missing from
                ``moge_geometry`` or ``output`` is not a known mode.
        """
        wants_depth = output in ("depth", "depth_colored")
        wants_normal = output in ("normal_directx", "normal_opengl")
        use_opengl = output.endswith("_opengl")
        has_normal = "normal" in moge_geometry

        # Select the source tensor, failing early when it is unavailable.
        if wants_depth:
            if "depth" not in moge_geometry:
                raise ValueError("moge_geometry has no depth output.")
            source = moge_geometry["depth"]
        elif wants_normal:
            if has_normal:
                source = moge_geometry["normal"]
            elif "points" in moge_geometry:
                # No predicted normals: they are derived from points below.
                source = moge_geometry["points"]
            else:
                raise ValueError("moge_geometry has neither normals nor points to derive normals from.")
        elif output == "mask":
            if "mask" not in moge_geometry:
                raise ValueError("moge_geometry has no mask output.")
            source = moge_geometry["mask"]
        else:
            raise ValueError(f"Unknown output mode: {output}")

        def _replicate_to_rgb(plane):
            # Append a trailing axis and repeat the values across 3 channels.
            return plane.unsqueeze(-1).expand(*plane.shape, 3).contiguous()

        count = source.shape[0]
        progress = comfy.utils.ProgressBar(count)
        rendered: list[torch.Tensor] = []
        with tqdm(total=count, desc=f"MoGe render: {output}") as bar:
            for idx in range(count):
                item = source[idx:idx + 1].float()
                if wants_depth:
                    disparity = _normalize_disparity(item)
                    if output == "depth_colored":
                        frame = _turbo(disparity)
                    else:
                        frame = _replicate_to_rgb(disparity)
                elif wants_normal:
                    normals = item if has_normal else _normals_from_points(item)
                    # MoGe is OpenCV (Z+ into scene); normal-map convention is
                    # Z+ out of surface, so Z is always negated. Y is negated
                    # only for the OpenGL convention.
                    green_sign = -1.0 if use_opengl else 1.0
                    axis_signs = normals.new_tensor([1.0, green_sign, -1.0])
                    normals = normals * axis_signs
                    # Map [-1, 1] components into the [0, 1] colour range.
                    frame = (normals * 0.5 + 0.5).clamp(0.0, 1.0)
                else:
                    # Only "mask" can reach this point after the checks above.
                    frame = _replicate_to_rgb(item)
                rendered.append(frame)
                progress.update_absolute(idx + 1)
                bar.update(1)

        target_device = comfy.model_management.intermediate_device()
        target_dtype = comfy.model_management.intermediate_dtype()
        result = torch.cat(rendered, dim=0).to(device=target_device, dtype=target_dtype)
        return io.NodeOutput(result)
|
||||
|
||||
|
||||
class MoGePointMapToMesh(io.ComfyNode):
    """Triangulate one image of a MoGe point map into a Types.MESH (UVs + texture)."""

    @classmethod
    def define_schema(cls):
        """Describe the node's identity, inputs and mesh output."""
        return io.Schema(
            node_id="MoGePointMapToMesh",
            display_name="MoGe Point Map to Mesh",
            category="image/geometry_estimation",
            inputs=[
                MoGeGeometry.Input("moge_geometry"),
                io.Int.Input("batch_index", default=0, min=0, max=4096,
                             tooltip="Which image of a batched MoGe geometry to mesh. Per-image vertex counts "
                                     "differ, so batches can't be stacked into a single MESH."),
                io.Int.Input("decimation", default=1, min=1, max=8,
                             tooltip="Vertex stride; 1 = full resolution."),
                io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01,
                               tooltip="Drop pixels whose 3x3 depth span exceeds this fraction. 0 = off."),
                io.Boolean.Input("texture", default=True,
                                 tooltip="Carry the source image through as the baseColor texture."),
            ],
            outputs=[io.Mesh.Output()],
        )

    @classmethod
    def execute(cls, moge_geometry, batch_index, decimation, discontinuity_threshold, texture) -> io.NodeOutput:
        """Build a mesh from the point map at ``batch_index``.

        Geometry without an ``"intrinsics"`` entry is treated as a panorama
        and only has its axes permuted; otherwise the perspective branch
        negates Y/Z and swaps the face winding.

        Raises:
            ValueError: if ``moge_geometry`` has no points, ``batch_index`` is
                past the end of the batch, or triangulation yields no
                vertices or faces.
        """
        if "points" not in moge_geometry:
            raise ValueError("moge_geometry has no points output.")
        points = moge_geometry["points"]
        B = points.shape[0]
        if batch_index >= B:
            raise ValueError(f"batch_index {batch_index} out of range; moge_geometry has batch size {B}.")

        # Pass depth so the rtol edge check sees radial depth -- for panoramas
        # points[..., 2] = cos(phi)*r goes negative below the equator and the rtol clamp would drop the bottom half.
        edge_depth = moge_geometry["depth"][batch_index] if "depth" in moge_geometry else None
        verts, faces, uvs = triangulate_grid_mesh(
            points[batch_index], decimation=decimation,
            discontinuity_threshold=discontinuity_threshold, depth=edge_depth,
        )
        if verts.shape[0] == 0 or faces.shape[0] == 0:
            raise ValueError("MoGe produced an empty mesh; try discontinuity_threshold=0 or apply_mask=False.")

        if "intrinsics" not in moge_geometry:
            # Panorama: rotate MoGe spherical (Z up) -> glTF (Y up, Z back), correct for inside-the-sphere viewing.
            verts = verts[:, [1, 2, 0]].contiguous()
        else:
            # Perspective MoGe (X right, Y down, Z forward) -> glTF; face flip keeps winding CCW after the Y/Z flip.
            # new_tensor creates the sign vector with verts' dtype *and* device, so
            # the multiply cannot hit a CPU/accelerator device mismatch.
            verts = verts * verts.new_tensor([1.0, -1.0, -1.0])
            faces = faces[:, [0, 2, 1]].contiguous()

        # Slice (rather than index) so the texture keeps its leading batch axis.
        tex = moge_geometry["image"][batch_index:batch_index + 1] if texture else None
        mesh = Types.MESH(
            vertices=verts.unsqueeze(0),
            faces=faces.unsqueeze(0),
            uvs=uvs.unsqueeze(0),
            texture=tex,
        )
        return io.NodeOutput(mesh)
|
||||
|
||||
|
||||
class MoGeExtension(ComfyExtension):
    """Extension object that exposes the MoGe node classes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes: list[type[io.ComfyNode]] = [
            LoadMoGeModel,
            MoGeInference,
            MoGePanoramaInference,
            MoGeRender,
            MoGePointMapToMesh,
        ]
        return node_classes
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> MoGeExtension:
    """Create and return a fresh MoGeExtension instance."""
    extension = MoGeExtension()
    return extension
|
||||
@ -56,6 +56,8 @@ folder_names_and_paths["background_removal"] = ([os.path.join(models_dir, "backg
|
||||
|
||||
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geometry_estimation")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
|
||||
1
main.py
1
main.py
@ -474,6 +474,7 @@ def start_comfyui(asyncio_loop=None):
|
||||
comfyui_manager.start()
|
||||
|
||||
hook_breaker_ac10a0.save_functions()
|
||||
logging.info("psst — we're hiring! https://www.comfy.org/careers")
|
||||
asyncio_loop.run_until_complete(nodes.init_extra_nodes(
|
||||
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
||||
init_api_nodes=not args.disable_api_nodes
|
||||
|
||||
7
nodes.py
7
nodes.py
@ -1221,7 +1221,7 @@ class LatentFromBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
|
||||
"batch_index": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION}),
|
||||
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
@ -1232,7 +1232,9 @@ class LatentFromBatch:
|
||||
def frombatch(self, samples, batch_index, length):
|
||||
s = samples.copy()
|
||||
s_in = samples["samples"]
|
||||
batch_index = min(s_in.shape[0] - 1, batch_index)
|
||||
if batch_index < 0:
|
||||
batch_index += s_in.shape[0]
|
||||
batch_index = max(0, min(s_in.shape[0] - 1, batch_index))
|
||||
length = min(s_in.shape[0] - batch_index, length)
|
||||
s["samples"] = s_in[batch_index:batch_index + length].clone()
|
||||
if "noise_mask" in samples:
|
||||
@ -2437,6 +2439,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_wandancer.py",
|
||||
"nodes_hidream_o1.py",
|
||||
"nodes_save_3d.py",
|
||||
"nodes_moge.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
18
openapi.yaml
18
openapi.yaml
@ -6030,6 +6030,24 @@ components:
|
||||
type: string
|
||||
nullable: true
|
||||
description: Minimum required workflow templates version for this ComfyUI build
|
||||
comfy_package_versions:
|
||||
type: array
|
||||
description: Installed and required versions for every comfy* package pinned in requirements.txt
|
||||
items:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
- installed
|
||||
- required
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
installed:
|
||||
type: string
|
||||
nullable: true
|
||||
required:
|
||||
type: string
|
||||
nullable: true
|
||||
devices:
|
||||
type: array
|
||||
items:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.43.18
|
||||
comfyui-workflow-templates==0.9.75
|
||||
comfyui-workflow-templates==0.9.77
|
||||
comfyui-embedded-docs==0.5.0
|
||||
torch
|
||||
torchsde
|
||||
|
||||
@ -656,6 +656,7 @@ class PromptServer():
|
||||
required_frontend_version = FrontendManager.get_required_frontend_version()
|
||||
installed_templates_version = FrontendManager.get_installed_templates_version()
|
||||
required_templates_version = FrontendManager.get_required_templates_version()
|
||||
comfy_package_versions = FrontendManager.get_comfy_package_versions()
|
||||
|
||||
system_stats = {
|
||||
"system": {
|
||||
@ -666,6 +667,7 @@ class PromptServer():
|
||||
"required_frontend_version": required_frontend_version,
|
||||
"installed_templates_version": installed_templates_version,
|
||||
"required_templates_version": required_templates_version,
|
||||
"comfy_package_versions": comfy_package_versions,
|
||||
"python_version": sys.version,
|
||||
"pytorch_version": comfy.model_management.torch_version,
|
||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||
|
||||
@ -52,7 +52,10 @@ def mock_provider(mock_releases):
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
import utils.install_util
|
||||
import app.frontend_management
|
||||
|
||||
utils.install_util.PACKAGE_VERSIONS = {}
|
||||
app.frontend_management.COMFY_PACKAGE_VERSIONS = []
|
||||
|
||||
|
||||
def test_get_release(mock_provider, mock_releases):
|
||||
@ -147,7 +150,7 @@ def test_init_frontend_default_with_mocks():
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch("app.frontend_management.check_frontend_version") as mock_check,
|
||||
patch("app.frontend_management.check_comfy_packages_versions") as mock_check,
|
||||
patch.object(
|
||||
FrontendManager, "default_frontend_path", return_value="/mocked/path"
|
||||
),
|
||||
@ -168,7 +171,7 @@ def test_init_frontend_fallback_on_error():
|
||||
patch.object(
|
||||
FrontendManager, "init_frontend_unsafe", side_effect=Exception("Test error")
|
||||
),
|
||||
patch("app.frontend_management.check_frontend_version") as mock_check,
|
||||
patch("app.frontend_management.check_comfy_packages_versions") as mock_check,
|
||||
patch.object(
|
||||
FrontendManager, "default_frontend_path", return_value="/default/path"
|
||||
),
|
||||
@ -277,7 +280,9 @@ def test_get_installed_templates_version():
|
||||
|
||||
def test_get_installed_templates_version_not_installed():
|
||||
# Act
|
||||
with patch("app.frontend_management.version", side_effect=Exception("Package not found")):
|
||||
with patch(
|
||||
"app.frontend_management.version", side_effect=Exception("Package not found")
|
||||
):
|
||||
version = FrontendManager.get_installed_templates_version()
|
||||
|
||||
# Assert
|
||||
|
||||
@ -1,9 +1,23 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.model_detection import detect_unet_config, model_config_from_unet_config
|
||||
import comfy.supported_models
|
||||
|
||||
|
||||
def _freeze(value):
|
||||
"""Recursively convert a value to a hashable form so configs can be
|
||||
compared/used as dict keys or set members."""
|
||||
if isinstance(value, dict):
|
||||
return frozenset((k, _freeze(v)) for k, v in value.items())
|
||||
if isinstance(value, (list, tuple)):
|
||||
return tuple(_freeze(v) for v in value)
|
||||
if isinstance(value, set):
|
||||
return frozenset(_freeze(v) for v in value)
|
||||
return value
|
||||
|
||||
|
||||
def _make_longcat_comfyui_sd():
|
||||
"""Minimal ComfyUI-format state dict for pre-converted LongCat-Image weights."""
|
||||
sd = {}
|
||||
@ -110,3 +124,21 @@ class TestModelDetection:
|
||||
model_config = model_config_from_unet_config(unet_config, sd)
|
||||
assert model_config is not None
|
||||
assert type(model_config).__name__ == "FluxSchnell"
|
||||
|
||||
def test_unet_config_and_required_keys_combination_is_unique(self):
    """Every registered model needs a distinct (unet_config, required_keys) pair.

    When two models share the same pair, ``BASE.matches`` has nothing to
    tell them apart and whichever comes first in the list always wins.
    """
    names_by_signature = defaultdict(list)
    for model_cls in comfy.supported_models.models:
        # Freeze both parts so the pair is hashable and usable as a key.
        signature = (_freeze(model_cls.unet_config), _freeze(model_cls.required_keys))
        names_by_signature[signature].append(model_cls.__name__)

    clashes = [names for names in names_by_signature.values() if len(names) > 1]
    assert not clashes, (
        "Found models sharing the same (unet_config, required_keys) "
        "combination, which makes detection ambiguous: "
        + "; ".join(", ".join(names) for names in clashes)
    )
|
||||
|
||||
Reference in New Issue
Block a user