mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-23 01:28:21 +08:00
Compare commits
6 Commits
robinjhuan
...
v0.22.2
| Author | SHA1 | Date | |
|---|---|---|---|
| 85abace906 | |||
| f5d678d9ee | |||
| 59cafaf744 | |||
| 13e2d133a6 | |||
| ef46f5de76 | |||
| 7e02881b36 |
@ -1,5 +1,2 @@
|
||||
# Admins
|
||||
* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai
|
||||
|
||||
/CODEOWNERS @comfyanonymous
|
||||
/.ci/ @comfyanonymous
|
||||
/.github/ @comfyanonymous
|
||||
|
||||
@ -433,7 +433,7 @@ See also: [https://www.comfy.org/](https://www.comfy.org/)
|
||||
|
||||
## Frontend Development
|
||||
|
||||
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). The compiled JS files (from TS/Vue) are published to [pypi](https://pypi.org/project/comfyui-frontend-package) and installed as a dependency in ComfyUI.
|
||||
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
|
||||
|
||||
### Reporting Issues and Requesting Features
|
||||
|
||||
|
||||
@ -110,11 +110,13 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent
|
||||
|
||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||
|
||||
CACHE_RAM_AUTO_GB = -1.0
|
||||
|
||||
cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metavar="GB", help="Use RAM pressure caching with the specified headroom thresholds. This is the default caching mode. The first value sets the active-cache threshold; the optional second value sets the inactive-cache/pin threshold. Defaults when no values are provided: active 25%% of system RAM (min 4GB, max 32GB), inactive 75%% of system RAM (min 12GB, max 96GB).")
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
@ -243,9 +245,6 @@ if comfy.options.args_parsing:
|
||||
else:
|
||||
args = parser.parse_args([])
|
||||
|
||||
if args.cache_ram is not None and len(args.cache_ram) > 2:
|
||||
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
|
||||
|
||||
if args.windows_standalone_build:
|
||||
args.auto_launch = True
|
||||
|
||||
|
||||
@ -484,23 +484,16 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
||||
|
||||
return weight
|
||||
|
||||
def prefetch_prepared_value(value, counter, destination, stream, copy):
|
||||
def prefetch_prepared_value(value, allocate_buffer, stream):
|
||||
if isinstance(value, torch.Tensor):
|
||||
size = comfy.memory_management.vram_aligned_size(value)
|
||||
offset = counter[0]
|
||||
counter[0] += size
|
||||
if destination is None:
|
||||
return value
|
||||
|
||||
dest = destination[offset:offset + size]
|
||||
if copy:
|
||||
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
|
||||
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
|
||||
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
|
||||
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
|
||||
elif isinstance(value, weight_adapter.WeightAdapterBase):
|
||||
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy))
|
||||
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
|
||||
elif isinstance(value, tuple):
|
||||
return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value)
|
||||
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
|
||||
elif isinstance(value, list):
|
||||
return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value]
|
||||
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
|
||||
|
||||
return value
|
||||
|
||||
@ -15,7 +15,7 @@ class TensorFileSlice(NamedTuple):
|
||||
size: int
|
||||
|
||||
|
||||
def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None):
|
||||
def read_tensor_file_slice_into(tensor, destination):
|
||||
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
if not isinstance(destination, QuantizedTensor):
|
||||
@ -23,17 +23,12 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
||||
if tensor._layout_cls != destination._layout_cls:
|
||||
return False
|
||||
|
||||
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream,
|
||||
destination2=(destination2._qdata if destination2 is not None else None)):
|
||||
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
|
||||
return False
|
||||
|
||||
dst_orig_dtype = destination._params.orig_dtype
|
||||
destination._params.copy_from(tensor._params, non_blocking=False)
|
||||
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
|
||||
if destination2 is not None:
|
||||
dst_orig_dtype = destination2._params.orig_dtype
|
||||
destination2._params.copy_from(destination._params, non_blocking=True)
|
||||
destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype)
|
||||
return True
|
||||
|
||||
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
|
||||
@ -53,17 +48,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
||||
if info.size == 0:
|
||||
return True
|
||||
|
||||
hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
|
||||
if hostbuf is not None:
|
||||
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
||||
device_ptr = destination2.data_ptr() if destination2 is not None else 0
|
||||
hostbuf.read_file_slice(file_obj, info.offset, info.size,
|
||||
offset=destination.data_ptr() - hostbuf.get_raw_address(),
|
||||
stream=stream_ptr,
|
||||
device_ptr=device_ptr,
|
||||
device=None if destination2 is None else destination2.device.index)
|
||||
return True
|
||||
|
||||
buf_type = ctypes.c_ubyte * info.size
|
||||
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
||||
|
||||
@ -167,7 +151,7 @@ def set_ram_cache_release_state(callback, headroom):
|
||||
extra_ram_release_callback = callback
|
||||
RAM_CACHE_HEADROOM = max(0, int(headroom))
|
||||
|
||||
def extra_ram_release(target, free_active=False):
|
||||
def extra_ram_release(target):
|
||||
if extra_ram_release_callback is None:
|
||||
return 0
|
||||
return extra_ram_release_callback(target, free_active=free_active)
|
||||
return extra_ram_release_callback(target)
|
||||
|
||||
@ -31,7 +31,6 @@ from contextlib import nullcontext
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
import comfy_aimdo.host_buffer
|
||||
import comfy_aimdo.vram_buffer
|
||||
|
||||
class VRAMState(Enum):
|
||||
@ -496,14 +495,6 @@ except:
|
||||
|
||||
current_loaded_models = []
|
||||
|
||||
DIRTY_MMAPS = set()
|
||||
|
||||
PIN_PRESSURE_HYSTERESIS = 256 * 1024 * 1024
|
||||
|
||||
#Freeing registerables on pressure does imply a GPU sync, so go big on
|
||||
#the hysteresis so each expensive sync gives us back a good chunk.
|
||||
REGISTERABLE_PIN_HYSTERESIS = 2048 * 1024 * 1024
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
sd = module.state_dict()
|
||||
@ -512,46 +503,27 @@ def module_size(module):
|
||||
module_mem += t.nbytes
|
||||
return module_mem
|
||||
|
||||
def mark_mmap_dirty(storage):
|
||||
mmap_refs = getattr(storage, "_comfy_tensor_mmap_refs", None)
|
||||
if mmap_refs is not None:
|
||||
DIRTY_MMAPS.add(mmap_refs[0])
|
||||
|
||||
def free_pins(size, evict_active=False):
|
||||
freed_total = 0
|
||||
for loaded_model in reversed(current_loaded_models):
|
||||
if size <= 0:
|
||||
return freed_total
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
|
||||
freed = model.partially_unload_ram(size)
|
||||
freed_total += freed
|
||||
size -= freed
|
||||
return freed_total
|
||||
|
||||
def ensure_pin_budget(size, evict_active=False):
|
||||
shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available
|
||||
if shortfall <= 0:
|
||||
return True
|
||||
|
||||
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
|
||||
return free_pins(to_free, evict_active=evict_active) >= shortfall
|
||||
|
||||
def ensure_pin_registerable(size, evict_active=False):
|
||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
if shortfall <= 0:
|
||||
return True
|
||||
|
||||
shortfall += REGISTERABLE_PIN_HYSTERESIS
|
||||
for loaded_model in reversed(current_loaded_models):
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
|
||||
shortfall -= model.unregister_inactive_pins(shortfall)
|
||||
if shortfall <= 0:
|
||||
return True
|
||||
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
|
||||
def module_mmap_residency(module, free=False):
|
||||
mmap_touched_mem = 0
|
||||
module_mem = 0
|
||||
bounced_mmaps = set()
|
||||
sd = module.state_dict()
|
||||
for k in sd:
|
||||
t = sd[k]
|
||||
module_mem += t.nbytes
|
||||
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
|
||||
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
|
||||
continue
|
||||
mmap_touched_mem += t.nbytes
|
||||
if not free:
|
||||
continue
|
||||
storage._comfy_tensor_mmap_touched = False
|
||||
mmap_obj = storage._comfy_tensor_mmap_refs[0]
|
||||
if mmap_obj in bounced_mmaps:
|
||||
continue
|
||||
mmap_obj.bounce()
|
||||
bounced_mmaps.add(mmap_obj)
|
||||
return mmap_touched_mem, module_mem
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
@ -581,6 +553,9 @@ class LoadedModel:
|
||||
def model_memory(self):
|
||||
return self.model.model_size()
|
||||
|
||||
def model_mmap_residency(self, free=False):
|
||||
return self.model.model_mmap_residency(free=free)
|
||||
|
||||
def model_loaded_memory(self):
|
||||
return self.model.loaded_size()
|
||||
|
||||
@ -660,9 +635,15 @@ WINDOWS = any(platform.win32_ver())
|
||||
|
||||
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||
if WINDOWS:
|
||||
import comfy.windows
|
||||
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
||||
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
|
||||
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
|
||||
def get_free_ram():
|
||||
return comfy.windows.get_free_ram()
|
||||
else:
|
||||
def get_free_ram():
|
||||
return psutil.virtual_memory().available
|
||||
|
||||
if args.reserve_vram is not None:
|
||||
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||
@ -676,6 +657,7 @@ def minimum_inference_memory():
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
|
||||
cleanup_models_gc()
|
||||
comfy.memory_management.extra_ram_release(max(pins_required, ram_required))
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
@ -691,9 +673,11 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
for x in can_unload_sorted:
|
||||
i = x[-1]
|
||||
memory_to_free = 1e32
|
||||
if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None):
|
||||
pins_to_free = 1e32
|
||||
if not DISABLE_SMART_MEMORY or device is None:
|
||||
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
|
||||
if for_dynamic:
|
||||
pins_to_free = pins_required - get_free_ram()
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
#as that works on-demand.
|
||||
memory_required -= current_loaded_models[i].model.loaded_size()
|
||||
@ -701,6 +685,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
unloaded_model.append(i)
|
||||
if pins_to_free > 0:
|
||||
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
|
||||
|
||||
for x in can_unload_sorted:
|
||||
i = x[-1]
|
||||
ram_to_free = ram_required - psutil.virtual_memory().available
|
||||
if ram_to_free <= 0 and i not in unloaded_model:
|
||||
continue
|
||||
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
|
||||
if resident_memory > 0:
|
||||
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
@ -766,16 +762,29 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
model_to_unload.model.detach(unpatch_all=False)
|
||||
model_to_unload.model_finalizer.detach()
|
||||
|
||||
|
||||
total_memory_required = {}
|
||||
total_pins_required = {}
|
||||
total_ram_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
device = loaded_model.device
|
||||
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
|
||||
resident_memory, model_memory = loaded_model.model_mmap_residency()
|
||||
pinned_memory = loaded_model.model.pinned_memory_size()
|
||||
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
|
||||
#make this JIT to keep as much pinned as possible.
|
||||
pins_required = model_memory - pinned_memory
|
||||
ram_required = model_memory - resident_memory
|
||||
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
|
||||
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem,
|
||||
device,
|
||||
for_dynamic=free_for_dynamic)
|
||||
for_dynamic=free_for_dynamic,
|
||||
pins_required=total_pins_required[device],
|
||||
ram_required=total_ram_required[device])
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
@ -1171,7 +1180,6 @@ STREAM_CAST_BUFFERS = {}
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
STREAM_AIMDO_CAST_BUFFERS = {}
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
STREAM_PIN_BUFFERS = {}
|
||||
|
||||
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
|
||||
|
||||
@ -1212,66 +1220,21 @@ def get_aimdo_cast_buffer(offload_stream, device):
|
||||
if cast_buffer is None:
|
||||
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
|
||||
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||
|
||||
return cast_buffer
|
||||
|
||||
def get_pin_buffer(offload_stream):
|
||||
pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
|
||||
if pin_buffer is None:
|
||||
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3))
|
||||
STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
|
||||
elif offload_stream is not None:
|
||||
event = getattr(pin_buffer, "_comfy_event", None)
|
||||
if event is not None:
|
||||
event.synchronize()
|
||||
delattr(pin_buffer, "_comfy_event")
|
||||
return pin_buffer
|
||||
|
||||
def resize_pin_buffer(pin_buffer, size):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
old_size = pin_buffer.size
|
||||
if size <= old_size:
|
||||
return True
|
||||
growth = size - old_size
|
||||
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
|
||||
ensure_pin_budget(growth, evict_active=True)
|
||||
ensure_pin_registerable(growth, evict_active=True)
|
||||
try:
|
||||
pin_buffer.extend(size=size, reallocate=True)
|
||||
except RuntimeError:
|
||||
return False
|
||||
TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
|
||||
return True
|
||||
|
||||
def reset_cast_buffers():
|
||||
global TOTAL_PINNED_MEMORY
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
global LARGEST_AIMDO_CASTED_WEIGHT
|
||||
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS):
|
||||
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
|
||||
if offload_stream is not None:
|
||||
offload_stream.synchronize()
|
||||
synchronize()
|
||||
|
||||
for mmap_obj in DIRTY_MMAPS:
|
||||
mmap_obj.bounce()
|
||||
DIRTY_MMAPS.clear()
|
||||
|
||||
for pin_buffer in STREAM_PIN_BUFFERS.values():
|
||||
TOTAL_PINNED_MEMORY -= pin_buffer.size
|
||||
TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY)
|
||||
|
||||
for loaded_model in current_loaded_models:
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic():
|
||||
model.model.dynamic_pins[model.load_device]["active"] = False
|
||||
model.partially_unload_ram(1e30, subsets=[ "patches" ])
|
||||
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0])
|
||||
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
STREAM_AIMDO_CAST_BUFFERS.clear()
|
||||
STREAM_PIN_BUFFERS.clear()
|
||||
soft_empty_cache()
|
||||
|
||||
def get_offload_stream(device):
|
||||
@ -1317,7 +1280,7 @@ def sync_stream(device, stream):
|
||||
current_stream(device).wait_stream(stream)
|
||||
|
||||
|
||||
def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
|
||||
def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
||||
wf_context = nullcontext()
|
||||
if stream is not None:
|
||||
wf_context = stream
|
||||
@ -1325,20 +1288,17 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
|
||||
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
|
||||
dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
|
||||
with wf_context:
|
||||
for tensor in tensors:
|
||||
dest_view = dest_views.pop(0)
|
||||
dest2_view = dest2_views.pop(0) if dest2_views is not None else None
|
||||
if tensor is None:
|
||||
continue
|
||||
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view, stream=stream, destination2=dest2_view):
|
||||
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
|
||||
continue
|
||||
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
||||
mark_mmap_dirty(storage)
|
||||
if hasattr(storage, "_comfy_tensor_mmap_touched"):
|
||||
storage._comfy_tensor_mmap_touched = True
|
||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||
if dest2_view is not None:
|
||||
dest2_view.copy_(dest_view, non_blocking=non_blocking)
|
||||
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||
@ -1379,18 +1339,14 @@ TOTAL_PINNED_MEMORY = 0
|
||||
MAX_PINNED_MEMORY = -1
|
||||
if not args.disable_pinned_memory:
|
||||
if is_nvidia() or is_amd():
|
||||
ram = get_total_memory(torch.device("cpu"))
|
||||
if WINDOWS:
|
||||
MAX_PINNED_MEMORY = ram * 0.40 # Windows limit is apparently 50%
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.40 # Windows limit is apparently 50%
|
||||
else:
|
||||
MAX_PINNED_MEMORY = ram * 0.90
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.90
|
||||
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||
|
||||
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
||||
|
||||
def pinned_hostbuf_size(size):
|
||||
return max(0, int(min(size, MAX_PINNED_MEMORY) * 2))
|
||||
|
||||
def discard_cuda_async_error():
|
||||
try:
|
||||
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
@ -1422,8 +1378,8 @@ def pin_memory(tensor):
|
||||
return False
|
||||
|
||||
size = tensor.nbytes
|
||||
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
|
||||
ensure_pin_registerable(size)
|
||||
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
|
||||
return False
|
||||
|
||||
ptr = tensor.data_ptr()
|
||||
if ptr == 0:
|
||||
@ -1460,8 +1416,7 @@ def unpin_memory(tensor):
|
||||
return False
|
||||
|
||||
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
|
||||
size = PINNED_MEMORY.pop(ptr)
|
||||
TOTAL_PINNED_MEMORY -= size
|
||||
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
|
||||
return True
|
||||
else:
|
||||
logging.warning("Unpin error.")
|
||||
|
||||
@ -35,7 +35,6 @@ import comfy.model_management
|
||||
import comfy.ops
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
import comfy_aimdo.host_buffer
|
||||
from comfy.comfy_types import UnetWrapperFunction
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||
@ -118,8 +117,6 @@ def string_to_seed(data):
|
||||
return comfy.utils.string_to_seed(data)
|
||||
|
||||
class LowVramPatch:
|
||||
is_lowvram_patch = True
|
||||
|
||||
def __init__(self, key, patches, convert_func=None, set_func=None):
|
||||
self.key = key
|
||||
self.patches = patches
|
||||
@ -127,21 +124,11 @@ class LowVramPatch:
|
||||
self.set_func = set_func
|
||||
self.prepared_patches = None
|
||||
|
||||
def memory_required(self):
|
||||
counter = [0]
|
||||
for patch in self.patches[self.key]:
|
||||
comfy.lora.prefetch_prepared_value(patch[1], counter, None, None, False)
|
||||
return counter[0]
|
||||
|
||||
def prepare(self, destination, stream, copy=True, commit=True):
|
||||
counter = [0]
|
||||
prepared_patches = [
|
||||
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream, copy), patch[2], patch[3], patch[4])
|
||||
def prepare(self, allocate_buffer, stream):
|
||||
self.prepared_patches = [
|
||||
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
|
||||
for patch in self.patches[self.key]
|
||||
]
|
||||
if commit:
|
||||
self.prepared_patches = prepared_patches
|
||||
return prepared_patches
|
||||
|
||||
def clear_prepared(self):
|
||||
self.prepared_patches = None
|
||||
@ -354,6 +341,9 @@ class ModelPatcher:
|
||||
self.size = comfy.model_management.module_size(self.model)
|
||||
return self.size
|
||||
|
||||
def model_mmap_residency(self, free=False):
|
||||
return comfy.model_management.module_mmap_residency(self.model, free=free)
|
||||
|
||||
def loaded_size(self):
|
||||
return self.model.model_loaded_weight_memory
|
||||
|
||||
@ -1128,12 +1118,8 @@ class ModelPatcher:
|
||||
# Pinned memory pressure tracking is only implemented for DynamicVram loading
|
||||
return 0
|
||||
|
||||
def loaded_ram_size(self):
|
||||
# Loaded RAM pressure tracking is only implemented for DynamicVram loading
|
||||
return 0
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
return 0
|
||||
pass
|
||||
|
||||
def detach(self, unpatch_all=True):
|
||||
self.eject_model()
|
||||
@ -1564,16 +1550,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
|
||||
if not hasattr(self.model, "dynamic_vbars"):
|
||||
self.model.dynamic_vbars = {}
|
||||
if not hasattr(self.model, "dynamic_pins"):
|
||||
self.model.dynamic_pins = {}
|
||||
if self.load_device not in self.model.dynamic_pins:
|
||||
self.model.dynamic_pins[self.load_device] = {
|
||||
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
|
||||
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
|
||||
"hostbufs_initialized": False,
|
||||
"failed": False,
|
||||
"active": False,
|
||||
}
|
||||
self.non_dynamic_delegate_model = None
|
||||
assert load_device is not None
|
||||
|
||||
@ -1613,16 +1589,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
#use all ModelPatcherDynamic this is ignored and its all done dynamically.
|
||||
return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
|
||||
|
||||
def restore_loaded_backups(self):
|
||||
restored = self.model.model_loaded_weight_memory
|
||||
for key in list(self.backup.keys()):
|
||||
bk = self.backup.pop(key)
|
||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||
for key in list(self.backup_buffers.keys()):
|
||||
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
return restored
|
||||
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
|
||||
|
||||
@ -1639,20 +1605,12 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
num_patches = 0
|
||||
allocated_size = 0
|
||||
self.restore_loaded_backups()
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
with self.use_ejected():
|
||||
self.unpatch_hooks()
|
||||
|
||||
vbar = self._vbar_get(create=True)
|
||||
pin_state = self.model.dynamic_pins[self.load_device]
|
||||
if not pin_state["hostbufs_initialized"]:
|
||||
hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size())
|
||||
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0])
|
||||
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0])
|
||||
pin_state["hostbufs_initialized"] = True
|
||||
pin_state["failed"] = False
|
||||
pin_state["active"] = True
|
||||
if vbar is not None:
|
||||
vbar.prioritize()
|
||||
|
||||
@ -1678,9 +1636,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
if key in self.patches:
|
||||
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
|
||||
return (True, 0)
|
||||
lowvram_patch = LowVramPatch(key, self.patches)
|
||||
lowvram_patch._pin_state = pin_state
|
||||
setattr(m, param_key + "_lowvram_function", lowvram_patch)
|
||||
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
|
||||
num_patches += 1
|
||||
else:
|
||||
setattr(m, param_key + "_lowvram_function", None)
|
||||
@ -1697,9 +1653,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
def force_load_param(self, param_key, device_to):
|
||||
key = key_param_name_to_key(n, param_key)
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
if weight is None:
|
||||
return
|
||||
if key in self.backup:
|
||||
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
|
||||
self.patch_weight_to_device(key, device_to=device_to, force_cast=True)
|
||||
@ -1709,26 +1662,17 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.comfy_cast_weights = True
|
||||
m.pin_failed = False
|
||||
m.seed_key = n
|
||||
m._pin_state = pin_state
|
||||
set_dirty(m, dirty)
|
||||
|
||||
#Models that mix tiny and giant weights can causing lopsided stream buffer
|
||||
#rotations and stall. force the tinys over.
|
||||
if module_mem > 16 * 1024:
|
||||
force_load, v_weight_size = setup_param(self, m, n, "weight")
|
||||
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
|
||||
force_load = force_load or force_load_bias
|
||||
v_weight_size += v_weight_bias
|
||||
if force_load:
|
||||
logging.info(f"Module {n} has resizing Lora - force loading")
|
||||
else:
|
||||
force_load=True
|
||||
force_load, v_weight_size = setup_param(self, m, n, "weight")
|
||||
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
|
||||
force_load = force_load or force_load_bias
|
||||
v_weight_size += v_weight_bias
|
||||
|
||||
if force_load:
|
||||
if hasattr(m, "_v"):
|
||||
comfy_aimdo.model_vbar.vbar_unpin(m._v)
|
||||
delattr(m, "_v")
|
||||
logging.info(f"Module {n} has resizing Lora - force loading")
|
||||
force_load_param(self, "weight", device_to)
|
||||
force_load_param(self, "bias", device_to)
|
||||
else:
|
||||
@ -1786,62 +1730,33 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||
|
||||
if freed < memory_to_free:
|
||||
freed += self.restore_loaded_backups()
|
||||
for key in list(self.backup.keys()):
|
||||
bk = self.backup.pop(key)
|
||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||
for key in list(self.backup_buffers.keys()):
|
||||
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
|
||||
freed += self.model.model_loaded_weight_memory
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
return freed
|
||||
|
||||
def loaded_ram_size(self):
|
||||
return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
|
||||
self.model.dynamic_pins[self.load_device]["patches"][0].size)
|
||||
|
||||
def pinned_memory_size(self):
|
||||
return (self.model.dynamic_pins[self.load_device]["weights"][3][0] +
|
||||
self.model.dynamic_pins[self.load_device]["patches"][3][0])
|
||||
total = 0
|
||||
loading = self._load_list(for_dynamic=True)
|
||||
for x in loading:
|
||||
_, _, _, _, m, _ = x
|
||||
pin = comfy.pinned_memory.get_pin(m)
|
||||
if pin is not None:
|
||||
total += pin.numel() * pin.element_size()
|
||||
return total
|
||||
|
||||
def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]):
|
||||
freed = 0
|
||||
pin_state = self.model.dynamic_pins[self.load_device]
|
||||
for subset in subsets:
|
||||
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
|
||||
split = stack_split[0]
|
||||
while split >= 0:
|
||||
module, offset = stack[split]
|
||||
split -= 1
|
||||
stack_split[0] = split
|
||||
if not module._pin_registered:
|
||||
continue
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
if torch.cuda.cudart().cudaHostUnregister(module._pin.data_ptr()) != 0:
|
||||
comfy.model_management.discard_cuda_async_error()
|
||||
continue
|
||||
module._pin_registered = False
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size)
|
||||
pinned_size[0] = max(0, pinned_size[0] - size)
|
||||
freed += size
|
||||
ram_to_unload -= size
|
||||
if ram_to_unload <= 0:
|
||||
return freed
|
||||
return freed
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload, subsets=[ "weights", "patches" ]):
|
||||
freed = 0
|
||||
pin_state = self.model.dynamic_pins[self.load_device]
|
||||
for subset in subsets:
|
||||
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
|
||||
while len(stack) > 0:
|
||||
module, offset = stack.pop()
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
del module._pin
|
||||
hostbuf.truncate(offset, do_unregister=module._pin_registered)
|
||||
stack_split[0] = min(stack_split[0], len(stack) - 1)
|
||||
if module._pin_registered:
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size)
|
||||
pinned_size[0] = max(0, pinned_size[0] - size)
|
||||
freed += size
|
||||
ram_to_unload -= size
|
||||
if ram_to_unload <= 0:
|
||||
return freed
|
||||
return freed
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
|
||||
for x in loading:
|
||||
*_, m, _ = x
|
||||
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
|
||||
if ram_to_unload <= 0:
|
||||
return
|
||||
|
||||
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
||||
#This isn't used by the core at all and can only be to load a model out of
|
||||
|
||||
88
comfy/ops.py
88
comfy/ops.py
@ -75,8 +75,6 @@ except:
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
@ -93,9 +91,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
offload_stream = None
|
||||
cast_buffer = None
|
||||
cast_buffer_offset = 0
|
||||
stream_pin_hostbuf = None
|
||||
stream_pin_offset = 0
|
||||
stream_pin_queue = []
|
||||
|
||||
def ensure_offload_stream(module, required_size, check_largest):
|
||||
nonlocal offload_stream
|
||||
@ -129,22 +124,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
cast_buffer_offset += buffer_size
|
||||
return buffer
|
||||
|
||||
def get_stream_pin_buffer_offset(buffer_size):
|
||||
nonlocal stream_pin_hostbuf
|
||||
nonlocal stream_pin_offset
|
||||
|
||||
if buffer_size == 0 or offload_stream is None:
|
||||
return None
|
||||
|
||||
if stream_pin_hostbuf is None:
|
||||
stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
|
||||
if stream_pin_hostbuf is None:
|
||||
return None
|
||||
|
||||
offset = stream_pin_offset
|
||||
stream_pin_offset += buffer_size
|
||||
return offset
|
||||
|
||||
for s in comfy_modules:
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
@ -183,47 +162,23 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
if xfer_dest is None:
|
||||
xfer_dest = get_cast_buffer(dest_size)
|
||||
|
||||
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
|
||||
if xfer_source is not None:
|
||||
if getattr(xfer_source, "is_lowvram_patch", False):
|
||||
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
|
||||
else:
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
|
||||
if signature is None and pin is None:
|
||||
comfy.pinned_memory.pin_memory(s)
|
||||
pin = comfy.pinned_memory.get_pin(s)
|
||||
else:
|
||||
pin = None
|
||||
|
||||
def handle_pin(m, pin, source, dest, subset="weights", size=None):
|
||||
if pin is not None:
|
||||
cast_maybe_lowvram_patch([pin], dest, offload_stream)
|
||||
return
|
||||
if signature is None:
|
||||
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
||||
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
||||
if pin is not None:
|
||||
if isinstance(source, list):
|
||||
comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest)
|
||||
else:
|
||||
cast_maybe_lowvram_patch(source, pin, None)
|
||||
cast_maybe_lowvram_patch([ pin ], dest, offload_stream)
|
||||
return
|
||||
if pin is None:
|
||||
pin_offset = get_stream_pin_buffer_offset(size)
|
||||
if pin_offset is not None:
|
||||
stream_pin_queue.append((source, pin_offset, size, dest))
|
||||
return
|
||||
cast_maybe_lowvram_patch(source, dest, offload_stream)
|
||||
|
||||
handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size)
|
||||
if pin is not None:
|
||||
comfy.model_management.cast_to_gathered(xfer_source, pin)
|
||||
xfer_source = [ pin ]
|
||||
#send it over
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_source = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_source is not None:
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
ensure_offload_stream(s, cast_buffer_offset, False)
|
||||
lowvram_size = lowvram_source.memory_required()
|
||||
lowvram_dest = get_cast_buffer(lowvram_size)
|
||||
lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True)
|
||||
|
||||
pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches")
|
||||
handle_pin(lowvram_source, pin, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size)
|
||||
|
||||
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
|
||||
|
||||
prefetch["xfer_dest"] = xfer_dest
|
||||
prefetch["cast_dest"] = cast_dest
|
||||
@ -231,23 +186,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
prefetch["needs_cast"] = needs_cast
|
||||
s._prefetch = prefetch
|
||||
|
||||
if stream_pin_offset > 0:
|
||||
if stream_pin_hostbuf.size < stream_pin_offset:
|
||||
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
|
||||
for xfer_source, _, _, xfer_dest in stream_pin_queue:
|
||||
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
|
||||
return offload_stream
|
||||
stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf)
|
||||
stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
|
||||
for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
|
||||
pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
|
||||
if isinstance(xfer_source, list):
|
||||
comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest)
|
||||
else:
|
||||
cast_maybe_lowvram_patch(xfer_source, pin, None)
|
||||
comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
stream_pin_hostbuf._comfy_event = offload_stream.record_event()
|
||||
|
||||
return offload_stream
|
||||
|
||||
|
||||
|
||||
@ -2,62 +2,42 @@ import comfy.model_management
|
||||
import comfy.memory_management
|
||||
import comfy_aimdo.host_buffer
|
||||
import comfy_aimdo.torch
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
def get_pin(module, subset="weights"):
|
||||
pin = getattr(module, "_pin", None)
|
||||
if pin is None or module._pin_registered or args.disable_pinned_memory:
|
||||
return pin
|
||||
def get_pin(module):
|
||||
return getattr(module, "_pin", None)
|
||||
|
||||
_, _, stack_split, pinned_size = module._pin_state[subset]
|
||||
size = pin.nbytes
|
||||
comfy.model_management.ensure_pin_registerable(size)
|
||||
|
||||
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
|
||||
comfy.model_management.discard_cuda_async_error()
|
||||
return pin
|
||||
|
||||
module._pin_registered = True
|
||||
stack_split[0] = max(stack_split[0], module._pin_stack_index)
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY += size
|
||||
pinned_size[0] += size
|
||||
return pin
|
||||
|
||||
def pin_memory(module, subset="weights", size=None):
|
||||
pin_state = module._pin_state
|
||||
if args.disable_pinned_memory:
|
||||
def pin_memory(module):
|
||||
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
|
||||
return
|
||||
|
||||
pin = get_pin(module, subset)
|
||||
if pin is not None or pin_state["failed"]:
|
||||
return
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
|
||||
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
|
||||
if size is None:
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
offset = hostbuf.size
|
||||
registerable_size = size + max(0, hostbuf.size - pinned_size[0])
|
||||
|
||||
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
|
||||
if (not comfy.model_management.ensure_pin_budget(size) or
|
||||
not comfy.model_management.ensure_pin_registerable(registerable_size)):
|
||||
pin_state["failed"] = True
|
||||
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
|
||||
module.pin_failed = True
|
||||
return False
|
||||
|
||||
try:
|
||||
hostbuf.extend(size=size)
|
||||
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
|
||||
except RuntimeError:
|
||||
pin_state["failed"] = True
|
||||
module.pin_failed = True
|
||||
return False
|
||||
|
||||
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
|
||||
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
|
||||
stack.append((module, offset))
|
||||
module._pin_registered = True
|
||||
module._pin_stack_index = len(stack) - 1
|
||||
stack_split[0] = max(stack_split[0], module._pin_stack_index)
|
||||
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
|
||||
module._pin_hostbuf = hostbuf
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY += size
|
||||
pinned_size[0] += size
|
||||
return True
|
||||
|
||||
def unpin_memory(module):
|
||||
if get_pin(module) is None:
|
||||
return 0
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY -= size
|
||||
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY = 0
|
||||
|
||||
del module._pin
|
||||
del module._pin_hostbuf
|
||||
return size
|
||||
|
||||
@ -113,6 +113,7 @@ def load_safetensors(ckpt):
|
||||
"_comfy_tensor_file_slice",
|
||||
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
|
||||
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
||||
setattr(storage, "_comfy_tensor_mmap_touched", False)
|
||||
sd[name] = tensor
|
||||
|
||||
return sd, header.get("__metadata__", {}),
|
||||
@ -1019,11 +1020,10 @@ def bislerp(samples, width, height):
|
||||
|
||||
def lanczos(samples, width, height):
|
||||
#the below API is strict and expects grayscale to be squeezed
|
||||
if samples.ndim == 4:
|
||||
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
|
||||
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
|
||||
images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
||||
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
||||
images = [torch.from_numpy(t).movedim(-1, 0) if (t := np.array(image).astype(np.float32) / 255.0).ndim == 3 else torch.from_numpy(t) for image in images]
|
||||
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
|
||||
result = torch.stack(images)
|
||||
return result.to(samples.device, samples.dtype)
|
||||
|
||||
@ -1451,3 +1451,4 @@ def deepcopy_list_dict(obj, memo=None):
|
||||
|
||||
memo[obj_id] = res
|
||||
return res
|
||||
|
||||
|
||||
52
comfy/windows.py
Normal file
52
comfy/windows.py
Normal file
@ -0,0 +1,52 @@
|
||||
import ctypes
|
||||
import logging
|
||||
import psutil
|
||||
from ctypes import wintypes
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
psapi = ctypes.WinDLL("psapi")
|
||||
kernel32 = ctypes.WinDLL("kernel32")
|
||||
|
||||
class PERFORMANCE_INFORMATION(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("cb", wintypes.DWORD),
|
||||
("CommitTotal", ctypes.c_size_t),
|
||||
("CommitLimit", ctypes.c_size_t),
|
||||
("CommitPeak", ctypes.c_size_t),
|
||||
("PhysicalTotal", ctypes.c_size_t),
|
||||
("PhysicalAvailable", ctypes.c_size_t),
|
||||
("SystemCache", ctypes.c_size_t),
|
||||
("KernelTotal", ctypes.c_size_t),
|
||||
("KernelPaged", ctypes.c_size_t),
|
||||
("KernelNonpaged", ctypes.c_size_t),
|
||||
("PageSize", ctypes.c_size_t),
|
||||
("HandleCount", wintypes.DWORD),
|
||||
("ProcessCount", wintypes.DWORD),
|
||||
("ThreadCount", wintypes.DWORD),
|
||||
]
|
||||
|
||||
def get_free_ram():
|
||||
#Windows is way too conservative and chalks recently used uncommitted model RAM
|
||||
#as "in-use". So, calculate free RAM for the sake of general use as the greater of:
|
||||
#
|
||||
#1: What psutil says
|
||||
#2: Total Memory - (Committed Memory - VRAM in use)
|
||||
#
|
||||
#We have to subtract VRAM in use from the comitted memory as WDDM creates a naked
|
||||
#commit charge for all VRAM used just incase it wants to page it all out. This just
|
||||
#isn't realistic so "overcommit" on our calculations by just subtracting it off.
|
||||
|
||||
pi = PERFORMANCE_INFORMATION()
|
||||
pi.cb = ctypes.sizeof(pi)
|
||||
|
||||
if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb):
|
||||
logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal")
|
||||
return psutil.virtual_memory().available
|
||||
|
||||
committed = pi.CommitTotal * pi.PageSize
|
||||
total = pi.PhysicalTotal * pi.PageSize
|
||||
|
||||
return max(psutil.virtual_memory().available,
|
||||
total - (committed - comfy_aimdo.control.get_total_vram_usage()))
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@ -11,44 +9,76 @@ class Rodin3DGenerateRequest(BaseModel):
|
||||
material: str = Field(..., description="The material type.")
|
||||
quality_override: int = Field(..., description="The poly count of the mesh.")
|
||||
mesh_mode: str = Field(..., description="It controls the type of faces of generated models.")
|
||||
TAPose: Optional[bool] = Field(None, description="")
|
||||
TAPose: bool | None = Field(None, description="")
|
||||
|
||||
|
||||
class Rodin3DGen25Request(BaseModel):
|
||||
|
||||
tier: str = Field(..., description="Gen-2.5 tier (e.g. Gen-2.5-High).")
|
||||
prompt: str | None = Field(None, description="Required for Text-to-3D; ignored otherwise.")
|
||||
seed: int | None = Field(None, description="0-65535.")
|
||||
material: str | None = Field(None, description="PBR | Shaded | All | None.")
|
||||
geometry_file_format: str | None = Field(None, description="glb | usdz | fbx | obj | stl.")
|
||||
texture_mode: str | None = Field(None, description="legacy | extreme-low | low | medium | high.")
|
||||
mesh_mode: str | None = Field(None, description="Raw (triangular) | Quad.")
|
||||
quality_override: int | None = Field(None, description="Mesh face count override.")
|
||||
geometry_instruct_mode: str | None = Field(None, description="faithful | creative.")
|
||||
bbox_condition: list[int] | None = Field(None, description="Bounding box [Width(Y), Height(Z), Length(X)] in cm.")
|
||||
height: int | None = Field(None, description="Approximate model height in cm.")
|
||||
TAPose: bool | None = Field(None, description="T/A pose for human-like models.")
|
||||
hd_texture: bool | None = Field(None, description="Enhanced texture quality.")
|
||||
texture_delight: bool | None = Field(None, description="Remove baked lighting from textures.")
|
||||
is_micro: bool | None = Field(None, description="Micro detail (Extreme-High only).")
|
||||
use_original_alpha: bool | None = Field(None, description="Preserve image transparency.")
|
||||
preview_render: bool | None = Field(None, description="Generate high-quality preview render.")
|
||||
addons: list[str] | None = Field(None, description='Optional addons, e.g. ["HighPack"].')
|
||||
|
||||
|
||||
class GenerateJobsData(BaseModel):
|
||||
uuids: List[str] = Field(..., description="str LIST")
|
||||
uuids: list[str] = Field(..., description="str LIST")
|
||||
subscription_key: str = Field(..., description="subscription key")
|
||||
|
||||
|
||||
class Rodin3DGenerateResponse(BaseModel):
|
||||
message: Optional[str] = Field(None, description="Return message.")
|
||||
prompt: Optional[str] = Field(None, description="Generated Prompt from image.")
|
||||
submit_time: Optional[str] = Field(None, description="Submit Time")
|
||||
uuid: Optional[str] = Field(None, description="Task str")
|
||||
jobs: Optional[GenerateJobsData] = Field(None, description="Details of jobs")
|
||||
message: str | None = Field(None, description="Return message.")
|
||||
prompt: str | None = Field(None, description="Generated Prompt from image.")
|
||||
submit_time: str | None = Field(None, description="Submit Time")
|
||||
uuid: str | None = Field(None, description="Task str")
|
||||
jobs: GenerateJobsData | None = Field(None, description="Details of jobs")
|
||||
|
||||
|
||||
class JobStatus(str, Enum):
|
||||
"""
|
||||
Status for jobs
|
||||
"""
|
||||
|
||||
Done = "Done"
|
||||
Failed = "Failed"
|
||||
Generating = "Generating"
|
||||
Waiting = "Waiting"
|
||||
|
||||
|
||||
class Rodin3DCheckStatusRequest(BaseModel):
|
||||
subscription_key: str = Field(..., description="subscription from generate endpoint")
|
||||
|
||||
|
||||
class JobItem(BaseModel):
|
||||
uuid: str = Field(..., description="uuid")
|
||||
status: JobStatus = Field(...,description="Status Currently")
|
||||
status: JobStatus = Field(..., description="Status Currently")
|
||||
|
||||
|
||||
class Rodin3DCheckStatusResponse(BaseModel):
|
||||
jobs: List[JobItem] = Field(..., description="Job status List")
|
||||
jobs: list[JobItem] = Field(..., description="Job status List")
|
||||
|
||||
|
||||
class Rodin3DDownloadRequest(BaseModel):
|
||||
task_uuid: str = Field(..., description="Task str")
|
||||
|
||||
|
||||
class RodinResourceItem(BaseModel):
|
||||
url: str = Field(..., description="Download Url")
|
||||
name: str = Field(..., description="File name with ext")
|
||||
|
||||
|
||||
class Rodin3DDownloadResponse(BaseModel):
|
||||
list: List[RodinResourceItem] = Field(..., description="Source List")
|
||||
items: list[RodinResourceItem] = Field(..., alias="list", description="Source List")
|
||||
|
||||
@ -5,32 +5,37 @@ Rodin API docs: https://developer.hyper3d.ai/
|
||||
|
||||
"""
|
||||
|
||||
from inspect import cleandoc
|
||||
import folder_paths as comfy_paths
|
||||
import os
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from inspect import cleandoc
|
||||
from io import BytesIO
|
||||
from typing_extensions import override
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from PIL import Image
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths as comfy_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Types
|
||||
from comfy_api_nodes.apis.rodin import (
|
||||
Rodin3DGenerateRequest,
|
||||
Rodin3DGenerateResponse,
|
||||
JobStatus,
|
||||
Rodin3DCheckStatusRequest,
|
||||
Rodin3DCheckStatusResponse,
|
||||
Rodin3DDownloadRequest,
|
||||
Rodin3DDownloadResponse,
|
||||
JobStatus,
|
||||
Rodin3DGen25Request,
|
||||
Rodin3DGenerateRequest,
|
||||
Rodin3DGenerateResponse,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
sync_op,
|
||||
poll_op,
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
poll_op,
|
||||
sync_op,
|
||||
validate_string,
|
||||
)
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
|
||||
|
||||
COMMON_PARAMETERS = [
|
||||
IO.Int.Input(
|
||||
@ -51,40 +56,30 @@ COMMON_PARAMETERS = [
|
||||
]
|
||||
|
||||
|
||||
def get_quality_mode(poly_count):
|
||||
polycount = poly_count.split("-")
|
||||
poly = polycount[1]
|
||||
count = polycount[0]
|
||||
if poly == "Triangle":
|
||||
mesh_mode = "Raw"
|
||||
elif poly == "Quad":
|
||||
mesh_mode = "Quad"
|
||||
else:
|
||||
mesh_mode = "Quad"
|
||||
|
||||
if count == "4K":
|
||||
quality_override = 4000
|
||||
elif count == "8K":
|
||||
quality_override = 8000
|
||||
elif count == "18K":
|
||||
quality_override = 18000
|
||||
elif count == "50K":
|
||||
quality_override = 50000
|
||||
elif count == "2K":
|
||||
quality_override = 2000
|
||||
elif count == "20K":
|
||||
quality_override = 20000
|
||||
elif count == "150K":
|
||||
quality_override = 150000
|
||||
elif count == "500K":
|
||||
quality_override = 500000
|
||||
else:
|
||||
quality_override = 18000
|
||||
|
||||
return mesh_mode, quality_override
|
||||
_QUALITY_MESH_OPTIONS: dict[str, tuple[str, int]] = {
|
||||
"4K-Quad": ("Quad", 4000),
|
||||
"8K-Quad": ("Quad", 8000),
|
||||
"18K-Quad": ("Quad", 18000),
|
||||
"50K-Quad": ("Quad", 50000),
|
||||
"200K-Quad": ("Quad", 200000),
|
||||
"2K-Triangle": ("Raw", 2000),
|
||||
"20K-Triangle": ("Raw", 20000),
|
||||
"150K-Triangle": ("Raw", 150000),
|
||||
"200K-Triangle": ("Raw", 200000),
|
||||
"500K-Triangle": ("Raw", 500000),
|
||||
"1M-Triangle": ("Raw", 1000000),
|
||||
}
|
||||
|
||||
|
||||
def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
|
||||
def get_quality_mode(poly_count: str) -> tuple[str, int]:
|
||||
"""Map a polygon-count preset like '18K-Quad' to (mesh_mode, quality_override).
|
||||
|
||||
Falls back to ('Quad', 18000) for unknown labels; legacy parity.
|
||||
"""
|
||||
return _QUALITY_MESH_OPTIONS.get(poly_count, ("Quad", 18000))
|
||||
|
||||
|
||||
def tensor_to_filelike(tensor, max_pixels: int = 2048 * 2048):
|
||||
"""
|
||||
Converts a PyTorch tensor to a file-like object.
|
||||
|
||||
@ -96,8 +91,8 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
|
||||
- io.BytesIO: A file-like object containing the image data.
|
||||
"""
|
||||
array = tensor.cpu().numpy()
|
||||
array = (array * 255).astype('uint8')
|
||||
image = Image.fromarray(array, 'RGB')
|
||||
array = (array * 255).astype("uint8")
|
||||
image = Image.fromarray(array, "RGB")
|
||||
|
||||
original_width, original_height = image.size
|
||||
original_pixels = original_width * original_height
|
||||
@ -112,7 +107,7 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
|
||||
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
img_byte_arr = BytesIO()
|
||||
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
|
||||
image.save(img_byte_arr, format="PNG") # PNG is used for lossless compression
|
||||
img_byte_arr.seek(0)
|
||||
return img_byte_arr
|
||||
|
||||
@ -145,11 +140,9 @@ async def create_generate_task(
|
||||
TAPose=ta_pose,
|
||||
),
|
||||
files=[
|
||||
(
|
||||
"images",
|
||||
open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image)
|
||||
)
|
||||
for image in images if image is not None
|
||||
("images", open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image))
|
||||
for image in images
|
||||
if image is not None
|
||||
],
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
@ -177,6 +170,7 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
|
||||
return "DONE"
|
||||
return "Generating"
|
||||
|
||||
|
||||
def extract_progress(response: Rodin3DCheckStatusResponse) -> int | None:
|
||||
if not response.jobs:
|
||||
return None
|
||||
@ -214,7 +208,7 @@ async def download_files(url_list, task_uuid: str) -> tuple[str | None, Types.Fi
|
||||
model_file_path = None
|
||||
file_3d = None
|
||||
|
||||
for i in url_list.list:
|
||||
for i in url_list.items:
|
||||
file_path = os.path.join(save_path, i.name)
|
||||
if i.name.lower().endswith(".glb"):
|
||||
model_file_path = os.path.join(result_folder_name, i.name)
|
||||
@ -489,7 +483,16 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
IO.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
|
||||
IO.Combo.Input(
|
||||
"Polygon_count",
|
||||
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
|
||||
options=[
|
||||
"4K-Quad",
|
||||
"8K-Quad",
|
||||
"18K-Quad",
|
||||
"50K-Quad",
|
||||
"2K-Triangle",
|
||||
"20K-Triangle",
|
||||
"150K-Triangle",
|
||||
"500K-Triangle",
|
||||
],
|
||||
default="500K-Triangle",
|
||||
optional=True,
|
||||
),
|
||||
@ -542,6 +545,566 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
def _rodin_multipart_parser(data: dict[str, Any]) -> aiohttp.FormData:
|
||||
"""Convert a Rodin request dict to an aiohttp form, fixing bool/list serialization.
|
||||
|
||||
Booleans --> "true"/"false". Lists --> one field per element.
|
||||
"""
|
||||
form = aiohttp.FormData(default_to_multipart=True)
|
||||
for key, value in data.items():
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, bool):
|
||||
form.add_field(key, "true" if value else "false")
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
form.add_field(key, str(item))
|
||||
elif isinstance(value, (bytes, bytearray)):
|
||||
form.add_field(key, value)
|
||||
else:
|
||||
form.add_field(key, str(value))
|
||||
return form
|
||||
|
||||
|
||||
async def _create_gen25_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
request: Rodin3DGen25Request,
|
||||
images: list | None,
|
||||
) -> tuple[str, str]:
|
||||
"""Submit a Gen-2.5 generate job; returns (task_uuid, subscription_key)."""
|
||||
|
||||
if images is not None and len(images) > 5:
|
||||
raise ValueError("Rodin Gen-2.5 supports at most 5 input images.")
|
||||
|
||||
files = None
|
||||
if images:
|
||||
files = [
|
||||
(
|
||||
"images",
|
||||
open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image),
|
||||
)
|
||||
for image in images
|
||||
if image is not None
|
||||
]
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"),
|
||||
response_model=Rodin3DGenerateResponse,
|
||||
data=request,
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
multipart_parser=_rodin_multipart_parser,
|
||||
)
|
||||
|
||||
if not response.uuid or not response.jobs or not response.jobs.subscription_key:
|
||||
raise RuntimeError(f"Rodin Gen-2.5 submit failed: message={response.message!r}")
|
||||
return response.uuid, response.jobs.subscription_key
|
||||
|
||||
|
||||
_PREVIEWABLE_3D_EXTS = {".glb", ".obj", ".fbx", ".stl", ".gltf"}
|
||||
|
||||
|
||||
async def _download_gen25_files(
|
||||
download_list: Rodin3DDownloadResponse,
|
||||
task_uuid: str,
|
||||
geometry_file_format: str,
|
||||
) -> Types.File3D | None:
|
||||
"""Download every file in the list; return the File3D matching the chosen format."""
|
||||
|
||||
folder_name = f"Rodin3D_Gen25_{task_uuid}"
|
||||
save_dir = os.path.join(comfy_paths.get_output_directory(), folder_name)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
target_ext = f".{geometry_file_format.lower().lstrip('.')}"
|
||||
file_3d: Types.File3D | None = None
|
||||
|
||||
for item in download_list.items:
|
||||
file_path = os.path.join(save_dir, item.name)
|
||||
ext = os.path.splitext(item.name.lower())[1]
|
||||
# Prefer the file matching the user's chosen format; fall back below.
|
||||
if file_3d is None and ext == target_ext and ext in _PREVIEWABLE_3D_EXTS:
|
||||
file_3d = await download_url_to_file_3d(item.url, target_ext.lstrip("."))
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_3d.get_bytes())
|
||||
continue
|
||||
await download_url_to_bytesio(item.url, file_path)
|
||||
|
||||
# If the chosen format wasn't found, surface any model file we did get.
|
||||
if file_3d is None:
|
||||
for item in download_list.items:
|
||||
ext = os.path.splitext(item.name.lower())[1]
|
||||
if ext in _PREVIEWABLE_3D_EXTS:
|
||||
file_3d = await download_url_to_file_3d(item.url, ext.lstrip("."))
|
||||
break
|
||||
return file_3d
|
||||
|
||||
|
||||
_MODE_REGULAR = "Regular"
|
||||
_MODE_FAST = "Fast"
|
||||
_MODE_EXTREME_HIGH = "Extreme-High"
|
||||
|
||||
_REGULAR_POLY_OPTIONS = [
|
||||
"Default",
|
||||
"4K-Quad",
|
||||
"8K-Quad",
|
||||
"18K-Quad",
|
||||
"50K-Quad",
|
||||
"2K-Triangle",
|
||||
"20K-Triangle",
|
||||
"150K-Triangle",
|
||||
"500K-Triangle",
|
||||
"1M-Triangle",
|
||||
]
|
||||
|
||||
_TEXTURE_MODE_OPTIONS = ["Default", "legacy", "extreme-low", "low", "medium", "high"]
|
||||
_GEOMETRY_FORMAT_OPTIONS = ["glb", "fbx", "obj", "stl"]
|
||||
_MATERIAL_OPTIONS = ["PBR", "Shaded", "All", "None"]
|
||||
|
||||
|
||||
def _build_mode_input(name: str = "mode") -> IO.DynamicCombo.Input:
|
||||
return IO.DynamicCombo.Input(
|
||||
name,
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
_MODE_REGULAR,
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"tier",
|
||||
options=["Gen-2.5-Low", "Gen-2.5-Medium", "Gen-2.5-High"],
|
||||
default="Gen-2.5-High",
|
||||
tooltip="Quality tier. Higher tiers produce higher-fidelity geometry.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"polygon_count",
|
||||
options=_REGULAR_POLY_OPTIONS,
|
||||
default="Default",
|
||||
tooltip="Preset face count. 'Default' uses the server's default for the selected tier.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"creative",
|
||||
default=False,
|
||||
tooltip="Creative mode (Medium/High only). Enhances generative robustness.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
_MODE_FAST,
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"tier",
|
||||
options=[
|
||||
"Gen-2.5-Extreme-Low",
|
||||
"Gen-2.5-Low",
|
||||
"Gen-2.5-Medium",
|
||||
"Gen-2.5-High",
|
||||
],
|
||||
default="Gen-2.5-Low",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"mesh_faces",
|
||||
default=20000,
|
||||
min=1000,
|
||||
max=20000,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Mesh face count (1K-20K in Fast mode).",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
_MODE_EXTREME_HIGH,
|
||||
[
|
||||
IO.Combo.Input("mesh_mode", options=["Raw", "Quad"], default="Raw"),
|
||||
IO.Int.Input(
|
||||
"mesh_faces",
|
||||
default=1000000,
|
||||
min=20000,
|
||||
max=2000000,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip=(
|
||||
"Mesh face count. Raw mode: 20K-2M. "
|
||||
"Quad mode: keep under 200K (upstream may reject higher values)."
|
||||
),
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"is_micro",
|
||||
default=False,
|
||||
tooltip="Enable micro detail (Extreme-High only).",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"creative",
|
||||
default=False,
|
||||
tooltip="Creative mode. Enhances generative robustness.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip=(
|
||||
"Generation mode. Regular = balanced. Fast = 1K-20K faces for rapid prototyping. "
|
||||
"Extreme-High = 20K-2M faces with optional micro details."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _build_common_inputs(*, include_image_only: bool) -> list:
|
||||
inputs: list = [
|
||||
IO.Combo.Input("material", options=_MATERIAL_OPTIONS, default="Shaded"),
|
||||
IO.Combo.Input("geometry_file_format", options=_GEOMETRY_FORMAT_OPTIONS, default="glb"),
|
||||
IO.Combo.Input(
|
||||
"texture_mode",
|
||||
options=_TEXTURE_MODE_OPTIONS,
|
||||
default="Default",
|
||||
optional=True,
|
||||
tooltip="Texture quality preset. 'Default' uses the server's default for the selected tier.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=65535,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"TAPose", default=False, optional=True, advanced=True, tooltip="T/A pose for human-like models."
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"hd_texture", default=False, optional=True, advanced=True, tooltip="High-quality texture enhancement."
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"texture_delight",
|
||||
default=False,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Remove baked lighting from textures.",
|
||||
),
|
||||
]
|
||||
if include_image_only:
|
||||
inputs.append(
|
||||
IO.Boolean.Input(
|
||||
"use_original_alpha",
|
||||
default=False,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Preserve image transparency.",
|
||||
)
|
||||
)
|
||||
inputs.extend(
|
||||
[
|
||||
IO.Boolean.Input(
|
||||
"addon_highpack",
|
||||
default=False,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="HighPack addon: 4K textures and ~16x faces in Quad mode.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"bbox_width",
|
||||
default=0,
|
||||
min=0,
|
||||
max=300,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Bounding-box width (Y axis). Set to 0 with the others to skip bbox.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"bbox_height",
|
||||
default=0,
|
||||
min=0,
|
||||
max=300,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Bounding-box height (Z axis).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"bbox_length",
|
||||
default=0,
|
||||
min=0,
|
||||
max=300,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Bounding-box length (X axis).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"height_cm",
|
||||
default=0,
|
||||
min=0,
|
||||
max=10000,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Approximate model height in centimeters (0 to skip).",
|
||||
),
|
||||
]
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
_PRICE_EXPR = """
|
||||
(
|
||||
$baseCredits := widgets.mode = "extreme-high" ? 1.0 : 0.5;
|
||||
$addonCredits := widgets.addon_highpack ? 1.0 : 0.0;
|
||||
$total := ($baseCredits * 1.5) + ($addonCredits * 0.8);
|
||||
{"type":"usd","usd": $total}
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
def _resolve_mode_params(mode_input: dict) -> dict:
|
||||
"""Translate the DynamicCombo `mode` payload into Gen-2.5 request fields.
|
||||
|
||||
Returns a dict with: tier, quality_override, mesh_mode, geometry_instruct_mode, is_micro.
|
||||
Missing keys mean "do not send" (so we don't override server defaults).
|
||||
"""
|
||||
selected = mode_input["mode"]
|
||||
out: dict = {}
|
||||
|
||||
if selected == _MODE_REGULAR:
|
||||
out["tier"] = mode_input["tier"]
|
||||
polygon = mode_input.get("polygon_count", "Default")
|
||||
if polygon != "Default":
|
||||
mesh_mode, faces = get_quality_mode(polygon)
|
||||
out["mesh_mode"] = mesh_mode
|
||||
out["quality_override"] = faces
|
||||
if mode_input.get("creative"):
|
||||
out["geometry_instruct_mode"] = "creative"
|
||||
|
||||
elif selected == _MODE_FAST:
|
||||
out["tier"] = mode_input["tier"]
|
||||
out["mesh_mode"] = "Raw"
|
||||
out["quality_override"] = int(mode_input["mesh_faces"])
|
||||
|
||||
elif selected == _MODE_EXTREME_HIGH:
|
||||
out["tier"] = "Gen-2.5-Extreme-High"
|
||||
out["mesh_mode"] = mode_input["mesh_mode"]
|
||||
out["quality_override"] = int(mode_input["mesh_faces"])
|
||||
if mode_input.get("is_micro"):
|
||||
out["is_micro"] = True
|
||||
if mode_input.get("creative"):
|
||||
out["geometry_instruct_mode"] = "creative"
|
||||
return out
|
||||
|
||||
|
||||
def _build_request(
|
||||
*,
|
||||
mode_input: dict,
|
||||
material: str,
|
||||
geometry_file_format: str,
|
||||
texture_mode: str,
|
||||
seed: int,
|
||||
TAPose: bool,
|
||||
hd_texture: bool,
|
||||
texture_delight: bool,
|
||||
addon_highpack: bool,
|
||||
bbox_width: int,
|
||||
bbox_height: int,
|
||||
bbox_length: int,
|
||||
height_cm: int,
|
||||
prompt: str | None = None,
|
||||
use_original_alpha: bool = False,
|
||||
) -> Rodin3DGen25Request:
|
||||
mode_params = _resolve_mode_params(mode_input)
|
||||
|
||||
bbox = None
|
||||
if bbox_width and bbox_height and bbox_length:
|
||||
bbox = [bbox_width, bbox_height, bbox_length]
|
||||
|
||||
return Rodin3DGen25Request(
|
||||
tier=mode_params["tier"],
|
||||
prompt=prompt or None,
|
||||
seed=seed,
|
||||
material=material,
|
||||
geometry_file_format=geometry_file_format,
|
||||
texture_mode=None if texture_mode == "Default" else texture_mode,
|
||||
mesh_mode=mode_params.get("mesh_mode"),
|
||||
quality_override=mode_params.get("quality_override"),
|
||||
geometry_instruct_mode=mode_params.get("geometry_instruct_mode"),
|
||||
bbox_condition=bbox,
|
||||
height=height_cm or None,
|
||||
TAPose=TAPose or None,
|
||||
hd_texture=hd_texture or None,
|
||||
texture_delight=texture_delight or None,
|
||||
is_micro=mode_params.get("is_micro"),
|
||||
use_original_alpha=use_original_alpha or None,
|
||||
addons=["HighPack"] if addon_highpack else None,
|
||||
)
|
||||
|
||||
|
||||
class Rodin3D_Gen25_Image(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Gen25_Image",
|
||||
display_name="Rodin 3D Gen-2.5 - Image to 3D",
|
||||
category="api node/3d/Rodin",
|
||||
description=(
|
||||
"Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. "
|
||||
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
|
||||
),
|
||||
inputs=[
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplatePrefix(IO.Image.Input("image"), prefix="image", min=1, max=5),
|
||||
tooltip="1-5 images. The first image is used for materials when multi-view.",
|
||||
),
|
||||
_build_mode_input(),
|
||||
*_build_common_inputs(include_image_only=True),
|
||||
],
|
||||
outputs=[IO.File3DAny.Output(display_name="model_file")],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode", "addon_highpack"]),
|
||||
expr=_PRICE_EXPR,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
images: IO.Autogrow.Type,
|
||||
mode: dict,
|
||||
material: str,
|
||||
geometry_file_format: str,
|
||||
texture_mode: str,
|
||||
seed: int,
|
||||
TAPose: bool,
|
||||
hd_texture: bool,
|
||||
texture_delight: bool,
|
||||
use_original_alpha: bool,
|
||||
addon_highpack: bool,
|
||||
bbox_width: int,
|
||||
bbox_height: int,
|
||||
bbox_length: int,
|
||||
height_cm: int,
|
||||
) -> IO.NodeOutput:
|
||||
image_tensors = [img for img in images.values() if img is not None]
|
||||
if not image_tensors:
|
||||
raise ValueError("Rodin Gen-2.5 Image-to-3D requires at least one image.")
|
||||
|
||||
# Flatten multi-image tensors into individual frames; the API accepts each as a separate part.
|
||||
flat_images: list = []
|
||||
for tensor in image_tensors:
|
||||
if hasattr(tensor, "shape") and len(tensor.shape) == 4:
|
||||
for i in range(tensor.shape[0]):
|
||||
flat_images.append(tensor[i])
|
||||
else:
|
||||
flat_images.append(tensor)
|
||||
|
||||
if len(flat_images) > 5:
|
||||
raise ValueError(f"Rodin Gen-2.5 accepts at most 5 images; received {len(flat_images)}.")
|
||||
|
||||
request = _build_request(
|
||||
mode_input=mode,
|
||||
material=material,
|
||||
geometry_file_format=geometry_file_format,
|
||||
texture_mode=texture_mode,
|
||||
seed=seed,
|
||||
TAPose=TAPose,
|
||||
hd_texture=hd_texture,
|
||||
texture_delight=texture_delight,
|
||||
addon_highpack=addon_highpack,
|
||||
bbox_width=bbox_width,
|
||||
bbox_height=bbox_height,
|
||||
bbox_length=bbox_length,
|
||||
height_cm=height_cm,
|
||||
prompt=None,
|
||||
use_original_alpha=use_original_alpha,
|
||||
)
|
||||
|
||||
task_uuid, subscription_key = await _create_gen25_task(cls, request, flat_images)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
file_3d = await _download_gen25_files(download_list, task_uuid, geometry_file_format)
|
||||
return IO.NodeOutput(file_3d)
|
||||
|
||||
|
||||
class Rodin3D_Gen25_Text(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Gen25_Text",
|
||||
display_name="Rodin 3D Gen-2.5 - Text to 3D",
|
||||
category="api node/3d/Rodin",
|
||||
description=(
|
||||
"Generate a 3D model from a text prompt via Rodin Gen-2.5. "
|
||||
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
|
||||
),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the 3D model.",
|
||||
),
|
||||
_build_mode_input(),
|
||||
*_build_common_inputs(include_image_only=False),
|
||||
],
|
||||
outputs=[IO.File3DAny.Output(display_name="model_file")],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode", "addon_highpack"]),
|
||||
expr=_PRICE_EXPR,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
mode: dict,
|
||||
material: str,
|
||||
geometry_file_format: str,
|
||||
texture_mode: str,
|
||||
seed: int,
|
||||
TAPose: bool,
|
||||
hd_texture: bool,
|
||||
texture_delight: bool,
|
||||
addon_highpack: bool,
|
||||
bbox_width: int,
|
||||
bbox_height: int,
|
||||
bbox_length: int,
|
||||
height_cm: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, field_name="prompt", min_length=1, max_length=2500)
|
||||
request = _build_request(
|
||||
mode_input=mode,
|
||||
material=material,
|
||||
geometry_file_format=geometry_file_format,
|
||||
texture_mode=texture_mode,
|
||||
seed=seed,
|
||||
TAPose=TAPose,
|
||||
hd_texture=hd_texture,
|
||||
texture_delight=texture_delight,
|
||||
addon_highpack=addon_highpack,
|
||||
bbox_width=bbox_width,
|
||||
bbox_height=bbox_height,
|
||||
bbox_length=bbox_length,
|
||||
height_cm=height_cm,
|
||||
prompt=prompt,
|
||||
)
|
||||
task_uuid, subscription_key = await _create_gen25_task(cls, request, images=None)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
file_3d = await _download_gen25_files(download_list, task_uuid, geometry_file_format)
|
||||
return IO.NodeOutput(file_3d)
|
||||
|
||||
|
||||
class Rodin3DExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -551,6 +1114,8 @@ class Rodin3DExtension(ComfyExtension):
|
||||
Rodin3D_Smooth,
|
||||
Rodin3D_Sketch,
|
||||
Rodin3D_Gen2,
|
||||
Rodin3D_Gen25_Image,
|
||||
Rodin3D_Gen25_Text,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -1,111 +0,0 @@
|
||||
"""Pure-numpy port of MediaPipe's face_geometry (FACE_LANDMARK_PIPELINE mode)
|
||||
+ weighted Procrustes solver. Computes the 4x4 facial transformation matrix.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _solve_weighted_orthogonal_problem(src: np.ndarray, tgt: np.ndarray, weights: np.ndarray) -> np.ndarray:
|
||||
"""Weighted orthogonal Procrustes (similarity). Returns 4x4 M with
|
||||
`target ≈ M @ homogeneous(source)` in the weighted LS sense. fp64 for
|
||||
SVD stability. Port of procrustes_solver.cc."""
|
||||
sqrt_w = np.sqrt(weights.astype(np.float64))
|
||||
w_total = float((sqrt_w ** 2).sum())
|
||||
ws = src.astype(np.float64) * sqrt_w
|
||||
wt = tgt.astype(np.float64) * sqrt_w
|
||||
|
||||
c_w = (ws @ sqrt_w) / w_total
|
||||
centered = ws - np.outer(c_w, sqrt_w)
|
||||
U, _S, Vt = np.linalg.svd(wt @ centered.T, full_matrices=True)
|
||||
# Disallow reflection: flip the least-significant axis when det(U)·det(V)<0.
|
||||
post, pre = U.copy(), Vt.T.copy()
|
||||
if np.linalg.det(post) * np.linalg.det(pre) < 0:
|
||||
post[:, 2] *= -1.0
|
||||
R = post @ pre.T
|
||||
|
||||
denom = float((centered * ws).sum())
|
||||
if denom < 1e-12:
|
||||
raise ValueError("Procrustes denominator collapsed (degenerate source).")
|
||||
scale = float((R @ centered * wt).sum()) / denom
|
||||
translation = ((wt - scale * (R @ ws)) @ sqrt_w) / w_total
|
||||
|
||||
M = np.eye(4, dtype=np.float64)
|
||||
M[:3, :3] = scale * R
|
||||
M[:3, 3] = translation
|
||||
return M
|
||||
|
||||
|
||||
def _estimate_scale(canonical: np.ndarray, runtime: np.ndarray, weights: np.ndarray) -> float:
|
||||
"""scale = ‖first column of M[:3]‖ per geometry_pipeline.cc::EstimateScale."""
|
||||
return float(np.linalg.norm(_solve_weighted_orthogonal_problem(canonical, runtime, weights)[:3, 0]))
|
||||
|
||||
|
||||
def solve_facial_transformation_matrix(
|
||||
landmarks_normalized: np.ndarray,
|
||||
canonical_vertices: np.ndarray,
|
||||
procrustes_indices: np.ndarray,
|
||||
procrustes_weights: np.ndarray,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
# face_geometry_calculator_options.pbtxt defaults
|
||||
vertical_fov_degrees: float = 63.0,
|
||||
near: float = 1.0,
|
||||
) -> np.ndarray:
|
||||
"""4x4 facial transformation matrix via two-pass scale recovery
|
||||
`landmarks_normalized` is (N, 3) in MediaPipe normalized convention: x, y
|
||||
in [0,1] with TOP-LEFT origin, z in width-scaled units.
|
||||
"""
|
||||
|
||||
h_near = 2.0 * near * math.tan(0.5 * math.radians(vertical_fov_degrees))
|
||||
w_near = image_width * h_near / image_height
|
||||
|
||||
sub = procrustes_indices.astype(np.int64)
|
||||
screen = landmarks_normalized[sub].T.astype(np.float64).copy()
|
||||
canon = canonical_vertices[sub].T.astype(np.float64).copy()
|
||||
weights = procrustes_weights.astype(np.float64)
|
||||
|
||||
# ProjectXY (TOP_LEFT y-flip, then scale all 3 axes; z uses x-scale).
|
||||
screen[1] = 1.0 - screen[1]
|
||||
screen[0] = screen[0] * w_near - 0.5 * w_near
|
||||
screen[1] = screen[1] * h_near - 0.5 * h_near
|
||||
screen[2] = screen[2] * w_near
|
||||
depth_offset = float(screen[2].mean())
|
||||
|
||||
def _unproject(s: np.ndarray, scale: float) -> np.ndarray:
|
||||
s = s.copy()
|
||||
s[2] = (s[2] - depth_offset + near) / scale
|
||||
s[0] *= s[2] / near
|
||||
s[1] *= s[2] / near
|
||||
s[2] *= -1.0
|
||||
return s
|
||||
|
||||
first = screen.copy()
|
||||
first[2] *= -1.0
|
||||
s1 = _estimate_scale(canon, first, weights) # 1st pass: Procrustes on projected XY
|
||||
s2 = _estimate_scale(canon, _unproject(screen, s1), weights) # 2nd pass: rescale z by s1, un-project XY
|
||||
return _solve_weighted_orthogonal_problem(canon, _unproject(screen, s1 * s2), weights).astype(np.float32)
|
||||
|
||||
|
||||
def transformation_matrix_from_detection(face_dict: dict, image_width: int, image_height: int, canonical_data: dict) -> np.ndarray:
|
||||
"""Adapt a FaceLandmarker face dict to MP's normalized convention and solve.
|
||||
FaceMesh emits (x, y, z) in 192-canonical units; MP's geometry expects
|
||||
z_norm = z_canonical * scale_x / image_width"""
|
||||
|
||||
lmks_xy, lmks_3d = face_dict["landmarks_xy"], face_dict["landmarks_3d"]
|
||||
aug = np.concatenate([lmks_3d[:, :2].astype(np.float64), np.ones((lmks_xy.shape[0], 1))], axis=1)
|
||||
M, *_ = np.linalg.lstsq(aug, lmks_xy.astype(np.float64), rcond=None)
|
||||
scale_x = float(np.linalg.norm(M[0]))
|
||||
z_scale = scale_x / image_width if scale_x > 1e-6 else 1.0 / image_width
|
||||
|
||||
normalized = np.empty((lmks_xy.shape[0], 3), dtype=np.float32)
|
||||
normalized[:, 0] = lmks_xy[:, 0] / image_width
|
||||
normalized[:, 1] = lmks_xy[:, 1] / image_height
|
||||
normalized[:, 2] = lmks_3d[:, 2] * z_scale
|
||||
return solve_facial_transformation_matrix(
|
||||
normalized, canonical_data["canonical_vertices"],
|
||||
canonical_data["procrustes_indices"], canonical_data["procrustes_weights"],
|
||||
image_width=image_width, image_height=image_height,
|
||||
)
|
||||
@ -1,682 +0,0 @@
|
||||
"""Pure-PyTorch port of MediaPipe's face_landmarker_v2_with_blendshapes.task:
|
||||
BlazeFace detector → FaceMesh v2 → ARKit-52 blendshapes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from functools import lru_cache
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from scipy.special import expit
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
# Values below must stay verbatim with the published face_landmarker_v2 graph
|
||||
|
||||
# face_blendshapes_graph.cc::kLandmarksSubsetIdxs
|
||||
_BS_INPUT_INDICES: Tuple[int, ...] = (
|
||||
0, 1, 4, 5, 6, 7, 8, 10, 13, 14, 17, 21, 33, 37, 39, 40, 46, 52, 53, 54,
|
||||
55, 58, 61, 63, 65, 66, 67, 70, 78, 80, 81, 82, 84, 87, 88, 91, 93, 95,
|
||||
103, 105, 107, 109, 127, 132, 133, 136, 144, 145, 146, 148, 149, 150, 152,
|
||||
153, 154, 155, 157, 158, 159, 160, 161, 162, 163, 168, 172, 173, 176, 178,
|
||||
181, 185, 191, 195, 197, 234, 246, 249, 251, 263, 267, 269, 270, 276, 282,
|
||||
283, 284, 285, 288, 291, 293, 295, 296, 297, 300, 308, 310, 311, 312, 314,
|
||||
317, 318, 321, 323, 324, 332, 334, 336, 338, 356, 361, 362, 365, 373, 374,
|
||||
375, 377, 378, 379, 380, 381, 382, 384, 385, 386, 387, 388, 389, 390, 397,
|
||||
398, 400, 402, 405, 409, 415, 454, 466, 468, 469, 470, 471, 472, 473, 474,
|
||||
475, 476, 477,
|
||||
)
|
||||
|
||||
# face_blendshapes_graph.cc::kCategoryNames
|
||||
BLENDSHAPE_NAMES: Tuple[str, ...] = (
|
||||
"_neutral", "browDownLeft", "browDownRight", "browInnerUp", "browOuterUpLeft",
|
||||
"browOuterUpRight", "cheekPuff", "cheekSquintLeft", "cheekSquintRight",
|
||||
"eyeBlinkLeft", "eyeBlinkRight", "eyeLookDownLeft", "eyeLookDownRight",
|
||||
"eyeLookInLeft", "eyeLookInRight", "eyeLookOutLeft", "eyeLookOutRight",
|
||||
"eyeLookUpLeft", "eyeLookUpRight", "eyeSquintLeft", "eyeSquintRight",
|
||||
"eyeWideLeft", "eyeWideRight", "jawForward", "jawLeft", "jawOpen",
|
||||
"jawRight", "mouthClose", "mouthDimpleLeft", "mouthDimpleRight",
|
||||
"mouthFrownLeft", "mouthFrownRight", "mouthFunnel", "mouthLeft",
|
||||
"mouthLowerDownLeft", "mouthLowerDownRight", "mouthPressLeft",
|
||||
"mouthPressRight", "mouthPucker", "mouthRight", "mouthRollLower",
|
||||
"mouthRollUpper", "mouthShrugLower", "mouthShrugUpper", "mouthSmileLeft",
|
||||
"mouthSmileRight", "mouthStretchLeft", "mouthStretchRight",
|
||||
"mouthUpperUpLeft", "mouthUpperUpRight", "noseSneerLeft", "noseSneerRight",
|
||||
)
|
||||
|
||||
# face_detection.pbtxt — short-range BlazeFace.
|
||||
_BF_NUM_LAYERS = 4
|
||||
_BF_INPUT_SIZE = 128
|
||||
_BF_STRIDES = (8, 16, 16, 16)
|
||||
_BF_ANCHOR_OFFSET_X = 0.5
|
||||
_BF_ANCHOR_OFFSET_Y = 0.5
|
||||
_BF_ASPECT_RATIOS = (1.0,)
|
||||
_BF_INTERP_SCALE_AR = 1.0
|
||||
_BF_BOX_SCALE = 128.0
|
||||
_BF_KP_OFFSET = 4
|
||||
_BF_SCORE_CLIP = 100.0
|
||||
_BF_MIN_SCORE = 0.5
|
||||
|
||||
# face_detection_full_range.pbtxt — 48x48 grid at stride 4, 1 anchor/cell.
|
||||
_BF_FR_INPUT_SIZE = 192
|
||||
_BF_FR_GRID = 48
|
||||
_BF_FR_NUM_ANCHORS = _BF_FR_GRID * _BF_FR_GRID
|
||||
_BF_FR_BOX_SCALE = 192.0
|
||||
_BF_FR_SCORE_CLIP = 100.0
|
||||
|
||||
_FM_INPUT_SIZE = 192
|
||||
|
||||
# Face ROI: 1.5xbbox rect warped anisotropically into 192x192.
|
||||
_FACE_LEFT_EYE_KP = 0
|
||||
_FACE_RIGHT_EYE_KP = 1
|
||||
_FACE_ROI_SCALE_X = 1.5
|
||||
_FACE_ROI_SCALE_Y = 1.5
|
||||
_FACE_ROI_TARGET_ANGLE = 0.0
|
||||
|
||||
|
||||
def _tf_same_pad(x: Tensor, kernel: int, stride: int) -> Tensor:
|
||||
"""TF SAME pad (asymmetric on stride-2; PyTorch's symmetric pad undershoots by 1 px)."""
|
||||
H, W = x.shape[-2], x.shape[-1]
|
||||
pad_h = max(((H + stride - 1) // stride - 1) * stride + kernel - H, 0)
|
||||
pad_w = max(((W + stride - 1) // stride - 1) * stride + kernel - W, 0)
|
||||
if pad_h == 0 and pad_w == 0:
|
||||
return x
|
||||
return F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
|
||||
|
||||
|
||||
# BlazeFace short-range: stem 5x5/s2 → 16 BlazeBlocks → parallel heads at
|
||||
# 16²x88 (2 anchors/cell) and 8²x96 (6/cell) = 896 anchors. (in, out, stride):
|
||||
_BLAZEFACE_BLOCKS = [
|
||||
(24, 24, 1), (24, 28, 1), (28, 32, 2), (32, 36, 1),
|
||||
(36, 42, 1), (42, 48, 2), (48, 56, 1), (56, 64, 1),
|
||||
(64, 72, 1), (72, 80, 1), (80, 88, 1), (88, 96, 2),
|
||||
(96, 96, 1), (96, 96, 1), (96, 96, 1), (96, 96, 1),
|
||||
]
|
||||
|
||||
|
||||
class BlazeFaceBlock(nn.Module):
|
||||
"""DW 3x3 + PW + residual. Residual max-pools on stride>1, channel-pads on out_ch>in_ch."""
|
||||
|
||||
def __init__(self, in_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
ops = operations if operations is not None else nn
|
||||
self.in_ch, self.out_ch, self.stride = in_ch, out_ch, stride
|
||||
self.depthwise = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, device=device, dtype=dtype)
|
||||
self.pointwise = ops.Conv2d(in_ch, out_ch, 1, padding=0, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
residual = F.max_pool2d(x, 2, 2) if self.stride > 1 else x
|
||||
if self.out_ch > self.in_ch:
|
||||
residual = F.pad(residual, (0, 0, 0, 0, 0, self.out_ch - self.in_ch))
|
||||
x = _tf_same_pad(x, 3, self.stride) if self.stride > 1 else F.pad(x, (1, 1, 1, 1))
|
||||
return F.relu(self.pointwise(self.depthwise(x)) + residual)
|
||||
|
||||
|
||||
class BlazeFace(nn.Module):
|
||||
"""Short-range BlazeFace: (B, 3, 128, 128) in [-1, 1] → 896 anchors x 17."""
|
||||
|
||||
def __init__(self, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
ops = operations if operations is not None else nn
|
||||
kw = dict(device=device, dtype=dtype)
|
||||
self.stem = ops.Conv2d(3, 24, 5, stride=2, padding=0, bias=True, **kw)
|
||||
self.blocks = nn.ModuleList(BlazeFaceBlock(i, o, s, device=device, dtype=dtype, operations=operations)
|
||||
for (i, o, s) in _BLAZEFACE_BLOCKS)
|
||||
# 16²x2 + 8²x6 = 512 + 384 = 896 anchors.
|
||||
self.cls_16 = ops.Conv2d(88, 2, 1, padding=0, bias=True, **kw)
|
||||
self.cls_8 = ops.Conv2d(96, 6, 1, padding=0, bias=True, **kw)
|
||||
self.reg_16 = ops.Conv2d(88, 32, 1, padding=0, bias=True, **kw)
|
||||
self.reg_8 = ops.Conv2d(96, 96, 1, padding=0, bias=True, **kw)
|
||||
|
||||
def forward(self, image_chw_normalized: Tensor) -> tuple[Tensor, Tensor]:
|
||||
x = F.relu(self.stem(_tf_same_pad(image_chw_normalized, 5, 2)))
|
||||
# 16x16 tap is block-10 output (before the 88→96 stride-2 in block 11).
|
||||
for i in range(11):
|
||||
x = self.blocks[i](x)
|
||||
feat_16 = x
|
||||
for i in range(11, 16):
|
||||
x = self.blocks[i](x)
|
||||
feat_8 = x
|
||||
|
||||
def flat(t, a, k): # NHWC flatten → (B, H*W*A, K)
|
||||
B, _, H, W = t.shape
|
||||
return t.permute(0, 2, 3, 1).reshape(B, H * W * a, k)
|
||||
|
||||
cls = torch.cat([flat(self.cls_16(feat_16), 2, 1), flat(self.cls_8(feat_8), 6, 1)], dim=1)
|
||||
reg = torch.cat([flat(self.reg_16(feat_16), 2, 16), flat(self.reg_8(feat_8), 6, 16)], dim=1)
|
||||
return reg, cls
|
||||
|
||||
|
||||
# BlazeFace full-range (face_detection_full_range_sparse.tflite): MobileNetV2-ish
|
||||
# backbone + top-down FPN, 192² input → 2304 anchors at the 48x48 grid.
|
||||
class FRBlock(nn.Module):
|
||||
"""Double inverted residual: DW → PW(mid) → DW → PW(out) [+ residual].
|
||||
|
||||
Per source tflite: dw* have no fused activation, pw1 is always ReLU, pw2
|
||||
is ReLU only when no residual (else ReLU fuses into the ADD).
|
||||
"""
|
||||
|
||||
def __init__(self, in_ch: int, mid_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
ops = operations if operations is not None else nn
|
||||
kw = dict(device=device, dtype=dtype)
|
||||
self.has_residual = (in_ch == out_ch and stride == 1)
|
||||
self.dw1 = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, **kw)
|
||||
self.pw1 = ops.Conv2d(in_ch, mid_ch, 1, padding=0, bias=True, **kw)
|
||||
self.dw2 = ops.Conv2d(mid_ch, mid_ch, 3, stride=1, padding=0, groups=mid_ch, bias=True, **kw)
|
||||
self.pw2 = ops.Conv2d(mid_ch, out_ch, 1, padding=0, bias=True, **kw)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
residual = x if self.has_residual else None
|
||||
x = F.relu(self.pw1(self.dw1(F.pad(x, (1, 1, 1, 1)))))
|
||||
x = self.pw2(self.dw2(F.pad(x, (1, 1, 1, 1))))
|
||||
return F.relu(x + residual) if residual is not None else F.relu(x)
|
||||
|
||||
|
||||
# (in_ch, mid_ch, out_ch, stride). Stages downsample 96²x32 → 48²x64 → 24²x128
|
||||
# → 12²x192 → 6²x384. Lateral taps at indices 4, 7, 10 (see _FR_LATERAL_*).
|
||||
_FR_BACKBONE_BLOCKS = [
|
||||
(32, 8, 32, 1), (32, 8, 32, 1), # 96²x32
|
||||
(32, 16, 64, 2), (64, 16, 64, 1), (64, 16, 64, 1), # 48²x64 — tap[0]
|
||||
(64, 32, 128, 2), (128, 32, 128, 1), (128, 32, 128, 1), # 24²x128 — tap[1]
|
||||
(128, 48, 192, 2), (192, 48, 192, 1), (192, 48, 192, 1), # 12²x192 — tap[2]
|
||||
(192, 96, 384, 2), (384, 96, 384, 1), (384, 96, 384, 1), (384, 96, 384, 1), # 6²x384
|
||||
]
|
||||
_FR_LATERAL_TAP_INDICES = (4, 7, 10)
|
||||
_FR_LATERAL_CHANNELS = ((64, 48), (128, 64), (192, 96)) # (in, out) per side-conv
|
||||
|
||||
# Decoder blocks per FPN level (after upsample-and-merge with the lateral).
|
||||
_FR_DECODER_BLOCKS = [
|
||||
[(96, 48, 96, 1), (96, 48, 96, 1)], # 12²x96
|
||||
[(64, 32, 64, 1), (64, 32, 64, 1)], # 24²x64
|
||||
[(48, 24, 48, 1)], # 48²x48 — feeds the heads
|
||||
]
|
||||
|
||||
|
||||
def _dcr_depth_to_space(t: Tensor, r: int, c_out: int) -> Tensor:
|
||||
"""TF DEPTH_TO_SPACE in DCR layout (input channels = (i, j, c_out)).
|
||||
pixel_shuffle uses CRD which permutes output channels for c_out > 1."""
|
||||
B_, _, H_, W_ = t.shape
|
||||
t = t.reshape(B_, r, r, c_out, H_, W_)
|
||||
t = t.permute(0, 3, 4, 1, 5, 2).contiguous()
|
||||
return t.reshape(B_, c_out, H_ * r, W_ * r)
|
||||
|
||||
|
||||
class BlazeFaceFullRange(nn.Module):
|
||||
"""Full-range face detector: (B, 3, 192, 192) in [-1, 1] → 2304 anchors x 17 values."""
|
||||
|
||||
def __init__(self, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
ops = operations if operations is not None else nn
|
||||
kw = dict(device=device, dtype=dtype)
|
||||
mk_block = lambda i, m, o, s: FRBlock(i, m, o, s, device=device, dtype=dtype, operations=operations)
|
||||
self.stem = ops.Conv2d(3, 32, 3, stride=2, padding=0, bias=True, **kw)
|
||||
self.backbone = nn.ModuleList(mk_block(i, m, o, s) for (i, m, o, s) in _FR_BACKBONE_BLOCKS)
|
||||
self.lateral_convs = nn.ModuleList(ops.Conv2d(i, o, 1, padding=0, bias=True, **kw) for (i, o) in _FR_LATERAL_CHANNELS)
|
||||
self.top_conv = ops.Conv2d(384, 96, 1, padding=0, bias=True, **kw)
|
||||
self.decoder_levels = nn.ModuleList(
|
||||
nn.ModuleList(mk_block(i, m, o, s) for (i, m, o, s) in lvl) for lvl in _FR_DECODER_BLOCKS
|
||||
)
|
||||
# 96→64 before 12→24, 64→48 before 24→48.
|
||||
self.decoder_reduce_convs = nn.ModuleList([
|
||||
ops.Conv2d(96, 64, 1, padding=0, bias=True, **kw),
|
||||
ops.Conv2d(64, 48, 1, padding=0, bias=True, **kw),
|
||||
])
|
||||
# Heads mix 2x2-cell info via DW-stride-2 + depth_to_space block_size=2.
|
||||
self.cls_conv = ops.Conv2d(48, 4, 1, padding=0, bias=True, **kw)
|
||||
self.cls_dw = ops.Conv2d(4, 4, 3, stride=2, padding=0, groups=4, bias=True, **kw)
|
||||
self.reg_conv = ops.Conv2d(48, 64, 1, padding=0, bias=True, **kw)
|
||||
self.reg_dw = ops.Conv2d(64, 64, 3, stride=2, padding=0, groups=64, bias=True, **kw)
|
||||
|
||||
def forward(self, image_chw_normalized: Tensor) -> tuple[Tensor, Tensor]:
|
||||
# Symmetric pad-1 throughout (full-range tflite uses explicit TF PAD, not SAME).
|
||||
x = F.relu(self.stem(F.pad(image_chw_normalized, (1, 1, 1, 1))))
|
||||
tap_set = set(_FR_LATERAL_TAP_INDICES)
|
||||
laterals: list[Tensor] = []
|
||||
for i, blk in enumerate(self.backbone):
|
||||
x = blk(x)
|
||||
if i in tap_set:
|
||||
laterals.append(x)
|
||||
|
||||
# top_conv / lateral_convs / decoder_reduce_convs all have fused ReLU in the tflite.
|
||||
p = F.relu(self.top_conv(x))
|
||||
laterals_rev = list(reversed(laterals))
|
||||
lateral_convs_rev = list(reversed(self.lateral_convs))
|
||||
for level in range(len(self.decoder_levels)):
|
||||
lateral = laterals_rev[level]
|
||||
p = F.interpolate(p, size=lateral.shape[-2:], mode="bilinear", align_corners=False)
|
||||
p = p + F.relu(lateral_convs_rev[level](lateral))
|
||||
for blk in self.decoder_levels[level]:
|
||||
p = blk(p)
|
||||
if level < len(self.decoder_reduce_convs):
|
||||
p = F.relu(self.decoder_reduce_convs[level](p))
|
||||
|
||||
c = self.cls_dw(F.pad(self.cls_conv(p), (1, 1, 1, 1)))
|
||||
c = _dcr_depth_to_space(c, r=2, c_out=1)
|
||||
r = self.reg_dw(F.pad(self.reg_conv(p), (1, 1, 1, 1)))
|
||||
r = _dcr_depth_to_space(r, r=2, c_out=16)
|
||||
B = c.shape[0]
|
||||
cls_out = c.permute(0, 2, 3, 1).reshape(B, _BF_FR_NUM_ANCHORS, 1)
|
||||
reg_out = r.permute(0, 2, 3, 1).reshape(B, _BF_FR_NUM_ANCHORS, 16)
|
||||
return reg_out, cls_out
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _blazeface_full_range_anchors() -> np.ndarray:
|
||||
"""2304 anchors over 48x48; anchor_w=anchor_h=1 (fixed_anchor_size)."""
|
||||
feat = _BF_FR_GRID
|
||||
yy, xx = np.meshgrid(np.arange(feat, dtype=np.float32), np.arange(feat, dtype=np.float32), indexing="ij")
|
||||
cx, cy, ones = (xx + 0.5) / feat, (yy + 0.5) / feat, np.ones_like(xx)
|
||||
return np.stack([cx, cy, ones, ones], axis=-1).reshape(_BF_FR_NUM_ANCHORS, 4)
|
||||
|
||||
|
||||
def _decode_blazeface_full_range(regressors: np.ndarray, classificators: np.ndarray,
|
||||
score_thresh: float = _BF_MIN_SCORE) -> np.ndarray:
|
||||
"""Same decode as short-range with 2304-anchor grid and box_scale=192."""
|
||||
scores = expit(np.clip(classificators[:, 0], -_BF_FR_SCORE_CLIP, _BF_FR_SCORE_CLIP))
|
||||
keep = scores >= score_thresh
|
||||
if not keep.any():
|
||||
return np.empty((0, 17), dtype=np.float32)
|
||||
r = regressors[keep] / _BF_FR_BOX_SCALE
|
||||
a = _blazeface_full_range_anchors()[keep]
|
||||
cxs, cys, aws, ahs = a[:, 0:1], a[:, 1:2], a[:, 2:3], a[:, 3:4]
|
||||
xc, yc = r[:, 0:1] * aws + cxs, r[:, 1:2] * ahs + cys
|
||||
w, h = r[:, 2:3] * aws, r[:, 3:4] * ahs
|
||||
out = np.empty((r.shape[0], 17), dtype=np.float32)
|
||||
out[:, 0:1], out[:, 1:2], out[:, 2:3], out[:, 3:4] = xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2
|
||||
out[:, 4:16:2] = r[:, _BF_KP_OFFSET::2] * aws + cxs
|
||||
out[:, 5:16:2] = r[:, _BF_KP_OFFSET + 1::2] * ahs + cys
|
||||
out[:, 16] = scores[keep]
|
||||
return out
|
||||
|
||||
|
||||
# FaceMesh (face_landmarks_detector.tflite): PReLU variant of BlazeBlock,
|
||||
# 17 blocks, heads for 478x3 landmarks + presence.
|
||||
_FACEMESH_BLOCKS = [ # (in_ch, out_ch, stride)
|
||||
(16, 16, 1), (16, 16, 1), (16, 32, 2), (32, 32, 1), (32, 32, 1), (32, 64, 2),
|
||||
(64, 64, 1), (64, 64, 1), (64, 128, 2), (128, 128, 1), (128, 128, 1), (128, 128, 2),
|
||||
(128, 128, 1), (128, 128, 1), (128, 128, 2), (128, 128, 1), (128, 128, 1),
|
||||
]
|
||||
|
||||
|
||||
class FaceMeshBlock(nn.Module):
|
||||
"""PReLU BlazeBlock: PReLU between DW and PW, and after the residual add."""
|
||||
|
||||
def __init__(self, in_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
ops = operations if operations is not None else nn
|
||||
kw = dict(device=device, dtype=dtype)
|
||||
self.in_ch, self.out_ch, self.stride = in_ch, out_ch, stride
|
||||
self.depthwise = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, **kw)
|
||||
self.prelu_dwise = nn.PReLU(num_parameters=in_ch, **kw)
|
||||
self.pointwise = ops.Conv2d(in_ch, out_ch, 1, padding=0, bias=True, **kw)
|
||||
self.prelu_out = nn.PReLU(num_parameters=out_ch, **kw)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
residual = F.max_pool2d(x, 2, 2) if self.stride > 1 else x
|
||||
if self.out_ch > self.in_ch:
|
||||
residual = F.pad(residual, (0, 0, 0, 0, 0, self.out_ch - self.in_ch))
|
||||
x = _tf_same_pad(x, 3, self.stride) if self.stride > 1 else F.pad(x, (1, 1, 1, 1))
|
||||
return self.prelu_out(self.pointwise(self.prelu_dwise(self.depthwise(x))) + residual)
|
||||
|
||||
|
||||
class FaceMesh(nn.Module):
|
||||
NUM_LANDMARKS = 478
|
||||
|
||||
def __init__(self, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
ops = operations if operations is not None else nn
|
||||
kw = dict(device=device, dtype=dtype)
|
||||
self.stem = ops.Conv2d(3, 16, 3, stride=2, padding=0, bias=True, **kw)
|
||||
self.prelu_stem = nn.PReLU(num_parameters=16, **kw)
|
||||
self.blocks = nn.ModuleList(FaceMeshBlock(i, o, s, device=device, dtype=dtype, operations=operations)
|
||||
for (i, o, s) in _FACEMESH_BLOCKS)
|
||||
self.head_reduce = ops.Conv2d(128, 8, 1, padding=0, bias=True, **kw)
|
||||
self.prelu_head_reduce = nn.PReLU(num_parameters=8, **kw)
|
||||
self.head_block = FaceMeshBlock(8, 8, 1, device=device, dtype=dtype, operations=operations)
|
||||
self.head_presence = ops.Conv2d(8, 1, 3, padding=0, bias=True, **kw)
|
||||
self.head_landmarks = ops.Conv2d(8, self.NUM_LANDMARKS * 3, 3, padding=0, bias=True, **kw)
|
||||
|
||||
def forward(self, face_chw_normalized: Tensor) -> tuple[Tensor, Tensor]:
|
||||
"""(B, 3, 192, 192) in [0, 1] → ((B, 478, 3) landmarks in 192-canonical, (B,) presence)."""
|
||||
x = self.prelu_stem(self.stem(_tf_same_pad(face_chw_normalized, 3, 2)))
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
x = self.prelu_head_reduce(self.head_reduce(x))
|
||||
x = self.head_block(x)
|
||||
B = x.shape[0]
|
||||
presence = self.head_presence(x).reshape(B)
|
||||
lmks = self.head_landmarks(x).reshape(B, self.NUM_LANDMARKS, 3)
|
||||
return lmks, presence
|
||||
|
||||
|
||||
# FaceBlendshapes (MLP-Mixer "GhumMarkerPoserMlpMixerGeneral"):
|
||||
# 146x2 → token-reduce 146→96 → embed 2→64 → +cls token → 4x mixer → cls→52.
|
||||
_BS_NUM_INPUT_LANDMARKS = 146
|
||||
_BS_NUM_TOKENS_REDUCED = 96
|
||||
_BS_NUM_TOKENS = 97 # +1 cls
|
||||
_BS_TOKEN_DIM = 64
|
||||
_BS_TOKEN_MIX_HIDDEN = 384
|
||||
_BS_CHANNEL_MIX_HIDDEN = 256
|
||||
_BS_NUM_BLENDSHAPES = 52
|
||||
_BS_LN_EPS = 1e-6
|
||||
|
||||
|
||||
class MlpMixerBlock(nn.Module):
|
||||
"""MLP-Mixer block: token-mixing MLP (over tokens) → channel-mixing MLP (over dim).
|
||||
Both pre-LN, both residual. LN has no beta (bias=False) to match MP."""
|
||||
|
||||
def __init__(self, num_tokens: int, token_dim: int, token_hidden: int, channel_hidden: int,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
ops = operations if operations is not None else nn
|
||||
kw = dict(device=device, dtype=dtype)
|
||||
# bias=False → no LN beta (matches MP).
|
||||
self.ln1 = ops.LayerNorm(token_dim, eps=_BS_LN_EPS, bias=False, **kw)
|
||||
self.ln2 = ops.LayerNorm(token_dim, eps=_BS_LN_EPS, bias=False, **kw)
|
||||
self.token_mlp1 = ops.Linear(num_tokens, token_hidden, bias=True, **kw)
|
||||
self.token_mlp2 = ops.Linear(token_hidden, num_tokens, bias=True, **kw)
|
||||
self.channel_mlp1 = ops.Linear(token_dim, channel_hidden, bias=True, **kw)
|
||||
self.channel_mlp2 = ops.Linear(channel_hidden, token_dim, bias=True, **kw)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
y = self.ln1(x).transpose(1, 2)
|
||||
x = x + self.token_mlp2(F.relu(self.token_mlp1(y))).transpose(1, 2)
|
||||
return x + self.channel_mlp2(F.relu(self.channel_mlp1(self.ln2(x))))
|
||||
|
||||
|
||||
class FaceBlendshapes(nn.Module):
|
||||
def __init__(self, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
ops = operations if operations is not None else nn
|
||||
kw = dict(device=device, dtype=dtype)
|
||||
self.token_reduce = ops.Linear(_BS_NUM_INPUT_LANDMARKS, _BS_NUM_TOKENS_REDUCED, bias=True, **kw)
|
||||
self.token_embed = ops.Linear(2, _BS_TOKEN_DIM, bias=True, **kw)
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, _BS_TOKEN_DIM, **kw))
|
||||
self.blocks = nn.ModuleList(
|
||||
MlpMixerBlock(_BS_NUM_TOKENS, _BS_TOKEN_DIM, _BS_TOKEN_MIX_HIDDEN, _BS_CHANNEL_MIX_HIDDEN,
|
||||
device=device, dtype=dtype, operations=operations) for _ in range(4)
|
||||
)
|
||||
self.head = ops.Linear(_BS_TOKEN_DIM, _BS_NUM_BLENDSHAPES, bias=True, **kw)
|
||||
|
||||
@staticmethod
|
||||
def _input_normalize(landmarks_2d: Tensor) -> Tensor:
|
||||
# Centroid-subtract → L2 scale → x0.5. The 0.5 is baked into training.
|
||||
centroid = landmarks_2d.mean(dim=1, keepdim=True)
|
||||
x = landmarks_2d - centroid
|
||||
mag = torch.sqrt((x * x).sum(dim=-1, keepdim=True))
|
||||
scale = mag.mean(dim=1, keepdim=True)
|
||||
return (x / scale.clamp(min=1e-12)) * 0.5
|
||||
|
||||
def forward(self, landmarks_2d: Tensor) -> Tensor:
|
||||
"""(B, 146, 2) → (B, 52) in [0, 1]. Input units don't matter (centroid + L2 normalize)."""
|
||||
x = self._input_normalize(landmarks_2d)
|
||||
x = self.token_reduce(x.transpose(1, 2)).transpose(1, 2)
|
||||
x = self.token_embed(x)
|
||||
cls = self.cls_token.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat([cls, x], dim=1)
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
return torch.sigmoid(self.head(x[:, 0]))
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _blazeface_anchors() -> np.ndarray:
|
||||
"""896 anchors per SsdAnchorsCalculator (fixed_anchor_size → anchor_w=anchor_h=1)."""
|
||||
per_ar = len(_BF_ASPECT_RATIOS) + (1 if _BF_INTERP_SCALE_AR > 0 else 0)
|
||||
layer_anchors: List[np.ndarray] = []
|
||||
layer = 0
|
||||
while layer < _BF_NUM_LAYERS:
|
||||
stride = _BF_STRIDES[layer]
|
||||
last = layer
|
||||
while last < _BF_NUM_LAYERS and _BF_STRIDES[last] == stride:
|
||||
last += 1
|
||||
per_cell = per_ar * (last - layer)
|
||||
feat = (_BF_INPUT_SIZE + stride - 1) // stride
|
||||
yy, xx = np.meshgrid(np.arange(feat, dtype=np.float32), np.arange(feat, dtype=np.float32), indexing="ij")
|
||||
cx, cy, ones = (xx + _BF_ANCHOR_OFFSET_X) / feat, (yy + _BF_ANCHOR_OFFSET_Y) / feat, np.ones_like(xx)
|
||||
cell = np.stack([cx, cy, ones, ones], axis=-1).reshape(-1, 4)
|
||||
layer_anchors.append(np.repeat(cell, per_cell, axis=0))
|
||||
layer = last
|
||||
out = np.concatenate(layer_anchors, axis=0)
|
||||
assert out.shape == (896, 4), out.shape
|
||||
return out
|
||||
|
||||
|
||||
def _decode_blazeface(regressors: np.ndarray, classificators: np.ndarray,
|
||||
score_thresh: float = _BF_MIN_SCORE) -> np.ndarray:
|
||||
"""Decode (regs (896,16), cls (896,1)) → (N, 17) = [xyxy, kp0x..kp5y, score] in [0, 1]."""
|
||||
scores = expit(np.clip(classificators[:, 0], -_BF_SCORE_CLIP, _BF_SCORE_CLIP))
|
||||
keep = scores >= score_thresh
|
||||
if not keep.any():
|
||||
return np.empty((0, 17), dtype=np.float32)
|
||||
r = regressors[keep] / _BF_BOX_SCALE
|
||||
a = _blazeface_anchors()[keep] # (N, 4) cx, cy, 1, 1
|
||||
cxs, cys, aws, ahs = a[:, 0:1], a[:, 1:2], a[:, 2:3], a[:, 3:4]
|
||||
xc, yc = r[:, 0:1] * aws + cxs, r[:, 1:2] * ahs + cys
|
||||
w, h = r[:, 2:3] * aws, r[:, 3:4] * ahs
|
||||
out = np.empty((r.shape[0], 17), dtype=np.float32)
|
||||
out[:, 0:1], out[:, 1:2], out[:, 2:3], out[:, 3:4] = xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2
|
||||
out[:, 4:16:2] = r[:, _BF_KP_OFFSET::2] * aws + cxs
|
||||
out[:, 5:16:2] = r[:, _BF_KP_OFFSET + 1::2] * ahs + cys
|
||||
out[:, 16] = scores[keep]
|
||||
return out
|
||||
|
||||
|
||||
def _weighted_nms(detections: np.ndarray, iou_thresh: float = 0.5) -> np.ndarray:
|
||||
"""MP weighted NMS — kept boxes are score-weighted averages of overlapping detections."""
|
||||
if detections.shape[0] == 0:
|
||||
return detections
|
||||
dets = detections[np.argsort(-detections[:, 16])]
|
||||
N = dets.shape[0]
|
||||
areas = np.clip(dets[:, 2] - dets[:, 0], 0, None) * np.clip(dets[:, 3] - dets[:, 1], 0, None)
|
||||
kept: List[np.ndarray] = []
|
||||
used = np.zeros(N, dtype=bool)
|
||||
for i in range(N):
|
||||
if used[i]:
|
||||
continue
|
||||
ax1, ay1, ax2, ay2 = dets[i, 0:4]
|
||||
merge_idx = [i]
|
||||
for j in range(i + 1, N):
|
||||
if used[j]:
|
||||
continue
|
||||
bx1, by1, bx2, by2 = dets[j, 0:4]
|
||||
iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))
|
||||
ih = max(0.0, min(ay2, by2) - max(ay1, by1))
|
||||
inter = iw * ih
|
||||
union = areas[i] + areas[j] - inter
|
||||
if union > 0 and inter / union > iou_thresh: # strict > matches MP
|
||||
merge_idx.append(j)
|
||||
used[j] = True
|
||||
used[i] = True
|
||||
cluster = dets[merge_idx]
|
||||
ws = cluster[:, 16:17]
|
||||
ws_sum = ws.sum()
|
||||
merged = np.copy(cluster[0])
|
||||
if ws_sum > 0:
|
||||
merged[:16] = (cluster[:, :16] * ws).sum(axis=0) / ws_sum
|
||||
kept.append(merged)
|
||||
return np.stack(kept, axis=0) if kept else np.empty((0, 17), dtype=np.float32)
|
||||
|
||||
|
||||
def _detection_to_face_rect(detection: np.ndarray, image_w: int, image_h: int) -> Tuple[float, float, float, float, float]:
|
||||
"""Detection (normalized) → rotated 1.5xbbox ROI in image pixels (anisotropic)."""
|
||||
xmin, ymin, xmax, ymax = detection[0:4]
|
||||
lx = detection[4 + _FACE_LEFT_EYE_KP * 2 + 0] * image_w
|
||||
ly = detection[4 + _FACE_LEFT_EYE_KP * 2 + 1] * image_h
|
||||
rx = detection[4 + _FACE_RIGHT_EYE_KP * 2 + 0] * image_w
|
||||
ry = detection[4 + _FACE_RIGHT_EYE_KP * 2 + 1] * image_h
|
||||
# Image-y-down convention: angle = target - atan2(-dy, dx).
|
||||
angle = _FACE_ROI_TARGET_ANGLE - math.atan2(ly - ry, rx - lx)
|
||||
return (float((xmin + xmax) * 0.5 * image_w),
|
||||
float((ymin + ymax) * 0.5 * image_h),
|
||||
float((xmax - xmin) * image_w * _FACE_ROI_SCALE_X),
|
||||
float((ymax - ymin) * image_h * _FACE_ROI_SCALE_Y),
|
||||
float(angle))
|
||||
|
||||
|
||||
def _sample_warp(image_chw: Tensor, src_x: Tensor, src_y: Tensor, padding_mode: str) -> Tensor:
|
||||
"""Bilinear-sample image_chw at corner-aligned (src_x, src_y)."""
|
||||
H, W = int(image_chw.shape[-2]), int(image_chw.shape[-1])
|
||||
grid = torch.stack([(2.0 * src_x + 1.0) / W - 1.0,
|
||||
(2.0 * src_y + 1.0) / H - 1.0], dim=-1).unsqueeze(0)
|
||||
return F.grid_sample(image_chw.unsqueeze(0), grid, mode="bilinear",
|
||||
align_corners=False, padding_mode=padding_mode).squeeze(0)
|
||||
|
||||
|
||||
def _warp_face_crop(image_chw: Tensor, cx: float, cy: float, width: float, height: float,
|
||||
angle: float, output_size: int = _FM_INPUT_SIZE) -> Tensor:
|
||||
"""Rotated rect → output_size² with BORDER_REPLICATE. image_chw must be in [0, 1]."""
|
||||
s_x, s_y = width / output_size, height / output_size
|
||||
cos_a, sin_a = math.cos(angle), math.sin(angle)
|
||||
arange = torch.arange(output_size, dtype=image_chw.dtype, device=image_chw.device) - output_size * 0.5
|
||||
v_grid, u_grid = torch.meshgrid(arange, arange, indexing="ij")
|
||||
src_x = cx + u_grid * s_x * cos_a - v_grid * s_y * sin_a
|
||||
src_y = cy + u_grid * s_x * sin_a + v_grid * s_y * cos_a
|
||||
return _sample_warp(image_chw, src_x, src_y, "border")
|
||||
|
||||
|
||||
def _blazeface_input_warp(image_chw_raw: Tensor, target: int = _BF_INPUT_SIZE) -> Tuple[Tensor, float, float, float]:
|
||||
"""Centered max(W,H) square → target² with BORDER_ZERO + [-1, 1] norm.
|
||||
|
||||
Sub-pixel grid_sample matters; integer-pad-then-resize drifts the bbox ~5%.
|
||||
Returns (warped, sub_rect_cx, sub_rect_cy, sub_rect_size) — the triplet maps
|
||||
tensor-normalized [0,1] detections back to image pixels.
|
||||
"""
|
||||
H, W = int(image_chw_raw.shape[1]), int(image_chw_raw.shape[2])
|
||||
sub_rect_size = float(max(W, H))
|
||||
sub_rect_cx, sub_rect_cy = W * 0.5, H * 0.5
|
||||
s = sub_rect_size / target
|
||||
arange = torch.arange(target, dtype=image_chw_raw.dtype, device=image_chw_raw.device) - target * 0.5
|
||||
v_grid, u_grid = torch.meshgrid(arange, arange, indexing="ij")
|
||||
out = _sample_warp(image_chw_raw, sub_rect_cx + u_grid * s, sub_rect_cy + v_grid * s, "zeros")
|
||||
return (out / 127.5) - 1.0, sub_rect_cx, sub_rect_cy, sub_rect_size
|
||||
|
||||
|
||||
class FaceLandmarker(nn.Module):
|
||||
"""BlazeFace → FaceMesh v2 → blendshapes. `detector_variant` selects 'short'
|
||||
(128², ≤2m) or 'full' (192² FPN, ≤5m). State dict uses inner-module prefixes
|
||||
`detector.*` / `mesh.*` / `blendshapes.*`; the outer FaceLandmarkerModel
|
||||
wrapper rewrites `detector_{variant}.*` keys to `detector.*` before loading.
|
||||
"""
|
||||
|
||||
def __init__(self, device=None, dtype=None, operations=None, detector_variant: str = "short"):
|
||||
super().__init__()
|
||||
det_cls = {"short": BlazeFace, "full": BlazeFaceFullRange}.get(detector_variant)
|
||||
|
||||
self.detector_variant = detector_variant
|
||||
self.detector = det_cls(device=device, dtype=dtype, operations=operations)
|
||||
self.mesh = FaceMesh(device=device, dtype=dtype, operations=operations)
|
||||
self.blendshapes = FaceBlendshapes(device=device, dtype=dtype, operations=operations)
|
||||
self.register_buffer("_bs_idx", torch.tensor(_BS_INPUT_INDICES, dtype=torch.long), persistent=False)
|
||||
|
||||
def run_detector_batch(self, images_rgb_uint8: List[np.ndarray],
|
||||
score_thresh: float = _BF_MIN_SCORE,
|
||||
iou_thresh: float = 0.5):
|
||||
"""Batched detector pass. Returns (img_raws, sub_rects, sizes, per_frame_decoded)
|
||||
where per_frame_decoded[b] is (N, 17) in tensor-normalized [0,1] coords."""
|
||||
if not images_rgb_uint8:
|
||||
return [], [], [], []
|
||||
device, dtype = self.detector.stem.weight.device, self.detector.stem.weight.dtype
|
||||
det_input_size, decode_fn = ((_BF_FR_INPUT_SIZE, _decode_blazeface_full_range)
|
||||
if self.detector_variant == "full"
|
||||
else (_BF_INPUT_SIZE, _decode_blazeface))
|
||||
|
||||
# Same-size frames: stack once and transfer once. Variable size falls back
|
||||
# to per-image (only triggers for SAM3DBody's head crops).
|
||||
sizes = [tuple(img.shape[:2]) for img in images_rgb_uint8]
|
||||
if len(set(sizes)) == 1:
|
||||
batch_chw = torch.from_numpy(np.stack(images_rgb_uint8, axis=0)).to(device, dtype).movedim(-1, -3).contiguous()
|
||||
img_raws = [batch_chw[bi] for bi in range(batch_chw.shape[0])]
|
||||
else:
|
||||
img_raws = [torch.from_numpy(img).to(device, dtype).movedim(-1, -3).contiguous() for img in images_rgb_uint8]
|
||||
|
||||
warps = [_blazeface_input_warp(img_raw, det_input_size) for img_raw in img_raws]
|
||||
det_crops = [w[0] for w in warps]
|
||||
sub_rects = [(w[1], w[2], w[3]) for w in warps]
|
||||
|
||||
regs_b, cls_b = self.detector(torch.stack(det_crops, dim=0))
|
||||
regs_np, cls_np = regs_b.float().cpu().numpy(), cls_b.float().cpu().numpy()
|
||||
per_frame = []
|
||||
for b in range(len(images_rgb_uint8)):
|
||||
decoded = decode_fn(regs_np[b], cls_np[b], score_thresh=score_thresh)
|
||||
per_frame.append(_weighted_nms(decoded, iou_thresh=iou_thresh) if decoded.shape[0] > 0 else decoded)
|
||||
return img_raws, sub_rects, sizes, per_frame
|
||||
|
||||
def detect_batch(self, images_rgb_uint8: List[np.ndarray], num_faces: int = 1,
|
||||
score_thresh: float = _BF_MIN_SCORE) -> List[List[dict]]:
|
||||
"""Full pipeline batched across `images_rgb_uint8`. Returns one face-dict
|
||||
list per image (empty if nothing detected). Face dict:
|
||||
bbox_xyxy (4,) image pixels, blendshapes {52} ∈ [0,1],
|
||||
landmarks_xy (478, 2) image pixels, landmarks_3d (478, 3) in
|
||||
192-canonical (pre-transformation) units, presence float (raw logit).
|
||||
"""
|
||||
img_raws, sub_rects, sizes, per_frame_dets = self.run_detector_batch(
|
||||
images_rgb_uint8, score_thresh=score_thresh,
|
||||
)
|
||||
# tensor-normalized → image-normalized [0,1] for _detection_to_face_rect.
|
||||
for b, decoded in enumerate(per_frame_dets):
|
||||
if decoded.shape[0] == 0:
|
||||
continue
|
||||
cx, cy, size = sub_rects[b]
|
||||
H, W = sizes[b]
|
||||
sx0, sy0 = cx - size * 0.5, cy - size * 0.5
|
||||
decoded[:, 0:16:2] = (sx0 + size * decoded[:, 0:16:2]) / W
|
||||
decoded[:, 1:16:2] = (sy0 + size * decoded[:, 1:16:2]) / H
|
||||
if num_faces > 0:
|
||||
per_frame_dets[b] = decoded[: int(num_faces)]
|
||||
|
||||
# Collect every detected face across all frames into one mesh input.
|
||||
face_params: List[Tuple[int, float, float, float, float, float, float]] = []
|
||||
mesh_crops: List[Tensor] = []
|
||||
for b, dets in enumerate(per_frame_dets):
|
||||
if dets.shape[0] == 0:
|
||||
continue
|
||||
H, W = sizes[b]
|
||||
img_for_mesh = img_raws[b] / 255.0
|
||||
for det in dets:
|
||||
cx, cy, w, h, angle = _detection_to_face_rect(det, W, H)
|
||||
mesh_crops.append(_warp_face_crop(img_for_mesh, cx, cy, w, h, angle, _FM_INPUT_SIZE))
|
||||
face_params.append((b, float(det[16]), cx, cy, w, h, angle))
|
||||
|
||||
results: List[List[dict]] = [[] for _ in range(len(images_rgb_uint8))]
|
||||
if not mesh_crops:
|
||||
return results
|
||||
|
||||
lmks_canon_b, presence_b = self.mesh(torch.stack(mesh_crops, dim=0))
|
||||
bs_out_b = self.blendshapes(lmks_canon_b[:, self._bs_idx, :2])
|
||||
|
||||
# Batched canonical→image affine
|
||||
params_t = torch.tensor(
|
||||
[(cx, cy, w, h, math.cos(a), math.sin(a)) for (_b, _s, cx, cy, w, h, a) in face_params],
|
||||
device=lmks_canon_b.device, dtype=lmks_canon_b.dtype,
|
||||
)
|
||||
cxs, cys, ws, hs, cos_a, sin_a = params_t.unbind(dim=1)
|
||||
inv = 1.0 / _FM_INPUT_SIZE
|
||||
u = lmks_canon_b[..., 0] - _FM_INPUT_SIZE * 0.5
|
||||
v = lmks_canon_b[..., 1] - _FM_INPUT_SIZE * 0.5
|
||||
lmks_xy_t = torch.stack([
|
||||
cxs[:, None] + u * (ws * inv * cos_a)[:, None] - v * (hs * inv * sin_a)[:, None],
|
||||
cys[:, None] + u * (ws * inv * sin_a)[:, None] + v * (hs * inv * cos_a)[:, None],
|
||||
], dim=-1)
|
||||
|
||||
lmks_xy_np = lmks_xy_t.float().cpu().numpy()
|
||||
lmks_canon_np = lmks_canon_b.float().cpu().numpy()
|
||||
presence_np = presence_b.float().cpu().numpy()
|
||||
bs_np = bs_out_b.float().cpu().numpy()
|
||||
|
||||
for i, (b, score, *_) in enumerate(face_params):
|
||||
lmks_xy = lmks_xy_np[i]
|
||||
mn, mx = lmks_xy.min(0), lmks_xy.max(0)
|
||||
results[b].append({
|
||||
"bbox_xyxy": np.array([mn[0], mn[1], mx[0], mx[1]], dtype=np.float32),
|
||||
"blendshapes": dict(zip(BLENDSHAPE_NAMES, bs_np[i].tolist())),
|
||||
"landmarks_xy": lmks_xy,
|
||||
"landmarks_3d": lmks_canon_np[i],
|
||||
"presence": float(presence_np[i]),
|
||||
"score": score,
|
||||
})
|
||||
return results
|
||||
@ -543,7 +543,7 @@ class AudioConcat(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="AudioConcat",
|
||||
search_aliases=["join audio", "combine audio", "append audio"],
|
||||
display_name="Concatenate Audio",
|
||||
display_name="Audio Concat",
|
||||
description="Concatenates the audio1 to audio2 in the specified direction.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -597,7 +597,7 @@ class AudioMerge(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="AudioMerge",
|
||||
search_aliases=["mix audio", "overlay audio", "layer audio"],
|
||||
display_name="Merge Audio",
|
||||
display_name="Audio Merge",
|
||||
description="Combine two audio tracks by overlaying their waveforms.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -667,9 +667,8 @@ class AudioAdjustVolume(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="AudioAdjustVolume",
|
||||
search_aliases=["audio gain", "loudness", "audio level"],
|
||||
display_name="Adjust Audio Volume",
|
||||
display_name="Audio Adjust Volume",
|
||||
category="audio",
|
||||
description="Adjust the volume of the audio by a specified amount in decibels (dB).",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.Int.Input(
|
||||
|
||||
@ -47,10 +47,8 @@ class LoadImageDataSetFromFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadImageDataSetFromFolder",
|
||||
search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
|
||||
display_name="Load Image (from Folder)",
|
||||
category="image",
|
||||
description="Load a dataset of images from a specified folder and return a list of images. Supported formats: PNG, JPG, JPEG, WEBP.",
|
||||
display_name="Load Image Dataset from Folder",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
@ -86,16 +84,14 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadImageTextDataSetFromFolder",
|
||||
search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
|
||||
display_name="Load Image-Text (from Folder)",
|
||||
category="image",
|
||||
description="Load a dataset of pairs of images and text captions from a specified folder and return them as a list. Supported formats: PNG, JPG, JPEG, WEBP.",
|
||||
display_name="Load Image and Text Dataset from Folder",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"folder",
|
||||
options=folder_paths.get_input_subfolders(),
|
||||
tooltip="The folder to load images and text captions from.",
|
||||
tooltip="The folder to load images from.",
|
||||
)
|
||||
],
|
||||
outputs=[
|
||||
@ -210,10 +206,8 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveImageDataSetToFolder",
|
||||
search_aliases=["save folder", "save to folder", "save dataset", "save images", "export dataset"],
|
||||
display_name="Save Image (to Folder) (DEPRECATED)",
|
||||
category="image",
|
||||
description="Save a dataset of images to a specified folder. Supported formats: PNG.",
|
||||
display_name="Save Image Dataset to Folder",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
is_input_list=True, # Receive images as list
|
||||
@ -232,7 +226,6 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[],
|
||||
is_deprecated=True, # This node is redundant and superseded by existing Save Image nodes where the target folder can be specified in the filename_prefix
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -253,20 +246,14 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveImageTextDataSetToFolder",
|
||||
search_aliases=["save folder", "save to folder", "save dataset", "save images", "save text", "export dataset"],
|
||||
display_name="Save Image-Text (to Folder)",
|
||||
category="image",
|
||||
description="Save a dataset of pairs of images and text captions to a specified folder. Images are saved as PNG files and captions are saved as TXT files with the same filename_prefix.",
|
||||
display_name="Save Image and Text Dataset to Folder",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
is_input_list=True, # Receive both images and texts as lists
|
||||
inputs=[
|
||||
io.Image.Input("images", tooltip="List of images to save."),
|
||||
io.String.Input("texts",
|
||||
optional=True,
|
||||
force_input=True,
|
||||
tooltip="List of text captions to save."
|
||||
),
|
||||
io.String.Input("texts", tooltip="List of text captions to save."),
|
||||
io.String.Input(
|
||||
"folder_name",
|
||||
default="dataset",
|
||||
@ -283,7 +270,7 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, folder_name, filename_prefix, texts=None):
|
||||
def execute(cls, images, texts, folder_name, filename_prefix):
|
||||
# Extract scalar values
|
||||
folder_name = folder_name[0]
|
||||
filename_prefix = filename_prefix[0]
|
||||
@ -292,12 +279,11 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
||||
saved_files = save_images_to_folder(images, output_dir, filename_prefix)
|
||||
|
||||
# Save captions
|
||||
if texts:
|
||||
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
|
||||
caption_filename = filename.replace(".png", ".txt")
|
||||
caption_path = os.path.join(output_dir, caption_filename)
|
||||
with open(caption_path, "w", encoding="utf-8") as f:
|
||||
f.write(caption)
|
||||
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
|
||||
caption_filename = filename.replace(".png", ".txt")
|
||||
caption_path = os.path.join(output_dir, caption_filename)
|
||||
with open(caption_path, "w", encoding="utf-8") as f:
|
||||
f.write(caption)
|
||||
|
||||
logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")
|
||||
return io.NodeOutput()
|
||||
@ -328,13 +314,11 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
|
||||
Child classes should set:
|
||||
node_id: Unique node identifier (required)
|
||||
search_aliases: List of search aliases (optional)
|
||||
display_name: Display name (optional, defaults to node_id)
|
||||
description: Node description (optional)
|
||||
extra_inputs: List of additional io.Input objects beyond "images" (optional)
|
||||
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
|
||||
is_output_list: True (list output) or False (single output) (optional, default True)
|
||||
is_deprecated: True if the node is deprecated (optional, default False)
|
||||
|
||||
Child classes must implement ONE of:
|
||||
_process(cls, image, **kwargs) -> tensor (for single-item processing)
|
||||
@ -342,13 +326,12 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
"""
|
||||
|
||||
node_id = None
|
||||
search_aliases = []
|
||||
display_name = None
|
||||
description = None
|
||||
extra_inputs = []
|
||||
is_group_process = None # None = auto-detect, True/False = explicit
|
||||
is_output_list = None # None = auto-detect based on processing mode
|
||||
is_deprecated = False
|
||||
|
||||
@classmethod
|
||||
def _detect_processing_mode(cls):
|
||||
"""Detect whether this node uses group or individual processing.
|
||||
@ -419,10 +402,8 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
|
||||
return io.Schema(
|
||||
node_id=cls.node_id,
|
||||
search_aliases=cls.search_aliases,
|
||||
display_name=cls.display_name or cls.node_id,
|
||||
category=cls.category,
|
||||
description=cls.description,
|
||||
category="dataset/image",
|
||||
is_experimental=True,
|
||||
is_input_list=is_group, # True for group, False for individual
|
||||
inputs=inputs,
|
||||
@ -491,13 +472,11 @@ class TextProcessingNode(io.ComfyNode):
|
||||
|
||||
Child classes should set:
|
||||
node_id: Unique node identifier (required)
|
||||
search_aliases: List of search aliases (optional)
|
||||
display_name: Display name (optional, defaults to node_id)
|
||||
description: Node description (optional)
|
||||
extra_inputs: List of additional io.Input objects beyond "texts" (optional)
|
||||
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
|
||||
is_output_list: True (list output) or False (single output) (optional, default True)
|
||||
is_deprecated: True if the node is deprecated (optional, default False)
|
||||
|
||||
Child classes must implement ONE of:
|
||||
_process(cls, text, **kwargs) -> str (for single-item processing)
|
||||
@ -505,13 +484,12 @@ class TextProcessingNode(io.ComfyNode):
|
||||
"""
|
||||
|
||||
node_id = None
|
||||
search_aliases = []
|
||||
display_name = None
|
||||
description = None
|
||||
extra_inputs = []
|
||||
is_group_process = None # None = auto-detect, True/False = explicit
|
||||
is_output_list = None # None = auto-detect based on processing mode
|
||||
is_deprecated = False
|
||||
|
||||
@classmethod
|
||||
def _detect_processing_mode(cls):
|
||||
"""Detect whether this node uses group or individual processing.
|
||||
@ -649,17 +627,15 @@ class TextProcessingNode(io.ComfyNode):
|
||||
|
||||
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
||||
node_id = "ResizeImagesByShorterEdge"
|
||||
display_name = "Resize Images by Shorter Edge (DEPRECATED)"
|
||||
category = "image/transform"
|
||||
description = "Resize images so that the shorter edge matches the specified dimension while preserving aspect ratio."
|
||||
is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale shorter dimension
|
||||
display_name = "Resize Images by Shorter Edge"
|
||||
description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio."
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"shorter_edge",
|
||||
default=512,
|
||||
min=1,
|
||||
max=8192,
|
||||
tooltip="Target dimension for the shorter edge.",
|
||||
tooltip="Target length for the shorter edge.",
|
||||
),
|
||||
]
|
||||
|
||||
@ -679,17 +655,15 @@ class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
||||
|
||||
class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
||||
node_id = "ResizeImagesByLongerEdge"
|
||||
display_name = "Resize Images by Longer Edge (DEPRECATED)"
|
||||
category = "image/transform"
|
||||
description = "Resize images so that the longer edge matches the specified dimension while preserving aspect ratio."
|
||||
is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale longer dimension
|
||||
display_name = "Resize Images by Longer Edge"
|
||||
description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio."
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"longer_edge",
|
||||
default=1024,
|
||||
min=1,
|
||||
max=8192,
|
||||
tooltip="Target dimension for the longer edge.",
|
||||
tooltip="Target length for the longer edge.",
|
||||
),
|
||||
]
|
||||
|
||||
@ -712,10 +686,8 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
||||
|
||||
class CenterCropImagesNode(ImageProcessingNode):
|
||||
node_id = "CenterCropImages"
|
||||
search_aliases=["crop", "cut", "trim"]
|
||||
display_name="Crop Image (Center)"
|
||||
category="image/transform"
|
||||
description = "Center crop an image to the specified dimensions."
|
||||
display_name = "Center Crop Images"
|
||||
description = "Center crop all images to the specified dimensions."
|
||||
extra_inputs = [
|
||||
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
|
||||
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
|
||||
@ -734,11 +706,10 @@ class CenterCropImagesNode(ImageProcessingNode):
|
||||
|
||||
class RandomCropImagesNode(ImageProcessingNode):
|
||||
node_id = "RandomCropImages"
|
||||
search_aliases=["crop", "cut", "trim"]
|
||||
display_name = "Crop Image (Random)"
|
||||
category="image/transform"
|
||||
description = "Randomly crop an image to the specified dimensions."
|
||||
|
||||
display_name = "Random Crop Images"
|
||||
description = (
|
||||
"Randomly crop all images to the specified dimensions (for data augmentation)."
|
||||
)
|
||||
extra_inputs = [
|
||||
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
|
||||
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
|
||||
@ -763,9 +734,7 @@ class RandomCropImagesNode(ImageProcessingNode):
|
||||
|
||||
class NormalizeImagesNode(ImageProcessingNode):
|
||||
node_id = "NormalizeImages"
|
||||
search_aliases=["normalize", "normalize colors"]
|
||||
display_name = "Normalize Image Colors"
|
||||
category = "image/color"
|
||||
display_name = "Normalize Images"
|
||||
description = "Normalize images using mean and standard deviation."
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
@ -793,10 +762,8 @@ class NormalizeImagesNode(ImageProcessingNode):
|
||||
|
||||
class AdjustBrightnessNode(ImageProcessingNode):
|
||||
node_id = "AdjustBrightness"
|
||||
search_aliases=["brightness"]
|
||||
display_name = "Adjust Brightness"
|
||||
category="image/adjustments"
|
||||
description = "Adjust the brightness of an image."
|
||||
description = "Adjust brightness of all images."
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
"factor",
|
||||
@ -814,10 +781,8 @@ class AdjustBrightnessNode(ImageProcessingNode):
|
||||
|
||||
class AdjustContrastNode(ImageProcessingNode):
|
||||
node_id = "AdjustContrast"
|
||||
search_aliases=["contrast"]
|
||||
display_name = "Adjust Contrast"
|
||||
category="image/adjustments"
|
||||
description = "Adjust the contrast of an image."
|
||||
description = "Adjust contrast of all images."
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
"factor",
|
||||
@ -835,10 +800,8 @@ class AdjustContrastNode(ImageProcessingNode):
|
||||
|
||||
class ShuffleDatasetNode(ImageProcessingNode):
|
||||
node_id = "ShuffleDataset"
|
||||
search_aliases=["shuffle", "randomize", "mix"]
|
||||
display_name = "Shuffle Images List"
|
||||
category = "image/batch"
|
||||
description = "Randomly shuffle the order of images in a list."
|
||||
display_name = "Shuffle Image Dataset"
|
||||
description = "Randomly shuffle the order of images in the dataset."
|
||||
is_group_process = True # Requires full list to shuffle
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
@ -860,15 +823,13 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ShuffleImageTextDataset",
|
||||
search_aliases=["shuffle", "randomize", "mix"],
|
||||
display_name = "Shuffle Pairs of Image-Text",
|
||||
category = "image/batch",
|
||||
description = "Randomly shuffle the order of pairs of image-text in a list.",
|
||||
display_name="Shuffle Image-Text Dataset",
|
||||
category="dataset/image",
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
io.Image.Input("images", tooltip="List of images to shuffle."),
|
||||
io.String.Input("texts", tooltip="List of texts to shuffle.", force_input=True),
|
||||
io.String.Input("texts", tooltip="List of texts to shuffle."),
|
||||
io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
@ -904,11 +865,8 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
|
||||
|
||||
class TextToLowercaseNode(TextProcessingNode):
|
||||
node_id = "TextToLowercase"
|
||||
search_aliases=["lowercase"]
|
||||
display_name = "Convert Text to Lowercase (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Convert text to lowercase."
|
||||
is_deprecated = True # This node is superseded by the Convert Text Case node
|
||||
display_name = "Text to Lowercase"
|
||||
description = "Convert all texts to lowercase."
|
||||
|
||||
@classmethod
|
||||
def _process(cls, text):
|
||||
@ -917,11 +875,8 @@ class TextToLowercaseNode(TextProcessingNode):
|
||||
|
||||
class TextToUppercaseNode(TextProcessingNode):
|
||||
node_id = "TextToUppercase"
|
||||
search_aliases=["uppercase"]
|
||||
display_name = "Convert Text to Uppercase (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Convert text to uppercase."
|
||||
is_deprecated = True # This node is superseded by the Convert Text Case node
|
||||
display_name = "Text to Uppercase"
|
||||
description = "Convert all texts to uppercase."
|
||||
|
||||
@classmethod
|
||||
def _process(cls, text):
|
||||
@ -930,10 +885,8 @@ class TextToUppercaseNode(TextProcessingNode):
|
||||
|
||||
class TruncateTextNode(TextProcessingNode):
|
||||
node_id = "TruncateText"
|
||||
search_aliases=["truncate", "cut", "shorten"]
|
||||
display_name = "Truncate Text"
|
||||
category = "text"
|
||||
description = "Truncate text to a maximum length."
|
||||
description = "Truncate all texts to a maximum length."
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"max_length", default=77, min=1, max=10000, tooltip="Maximum text length."
|
||||
@ -947,10 +900,8 @@ class TruncateTextNode(TextProcessingNode):
|
||||
|
||||
class AddTextPrefixNode(TextProcessingNode):
|
||||
node_id = "AddTextPrefix"
|
||||
display_name = "Add Text Prefix (DEPRECATED)"
|
||||
category = "text"
|
||||
display_name = "Add Text Prefix"
|
||||
description = "Add a prefix to all texts."
|
||||
is_deprecated = True # This node is superseded by the Concatenate Text node
|
||||
extra_inputs = [
|
||||
io.String.Input("prefix", default="", tooltip="Prefix to add."),
|
||||
]
|
||||
@ -962,10 +913,8 @@ class AddTextPrefixNode(TextProcessingNode):
|
||||
|
||||
class AddTextSuffixNode(TextProcessingNode):
|
||||
node_id = "AddTextSuffix"
|
||||
display_name = "Add Text Suffix (DEPRECATED)"
|
||||
category = "text"
|
||||
display_name = "Add Text Suffix"
|
||||
description = "Add a suffix to all texts."
|
||||
is_deprecated = True # This node is superseded by the Concatenate Text node
|
||||
extra_inputs = [
|
||||
io.String.Input("suffix", default="", tooltip="Suffix to add."),
|
||||
]
|
||||
@ -977,10 +926,8 @@ class AddTextSuffixNode(TextProcessingNode):
|
||||
|
||||
class ReplaceTextNode(TextProcessingNode):
|
||||
node_id = "ReplaceText"
|
||||
display_name = "Replace Text (DEPRECATED)"
|
||||
category = "text"
|
||||
display_name = "Replace Text"
|
||||
description = "Replace text in all texts."
|
||||
is_deprecated = True # This node is superseded by the other Replace Text node
|
||||
extra_inputs = [
|
||||
io.String.Input("find", default="", tooltip="Text to find."),
|
||||
io.String.Input("replace", default="", tooltip="Text to replace with."),
|
||||
@ -993,10 +940,8 @@ class ReplaceTextNode(TextProcessingNode):
|
||||
|
||||
class StripWhitespaceNode(TextProcessingNode):
|
||||
node_id = "StripWhitespace"
|
||||
display_name = "Strip Whitespace (DEPRECATED)"
|
||||
category = "text"
|
||||
display_name = "Strip Whitespace"
|
||||
description = "Strip leading and trailing whitespace from all texts."
|
||||
is_deprecated = True # This node is superseded by the Trim Text node
|
||||
|
||||
@classmethod
|
||||
def _process(cls, text):
|
||||
@ -1007,13 +952,11 @@ class StripWhitespaceNode(TextProcessingNode):
|
||||
|
||||
|
||||
class ImageDeduplicationNode(ImageProcessingNode):
|
||||
"""Remove duplicate or very similar images from a list using perceptual hashing."""
|
||||
"""Remove duplicate or very similar images from the dataset using perceptual hashing."""
|
||||
|
||||
node_id = "ImageDeduplication"
|
||||
search_aliases=["deduplicate", "remove duplicates", "similarity filter"]
|
||||
display_name = "Deduplicate Images"
|
||||
category = "image/batch"
|
||||
description = "Remove duplicate or very similar images from a list."
|
||||
display_name = "Image Deduplication"
|
||||
description = "Remove duplicate or very similar images from the dataset."
|
||||
is_group_process = True # Requires full list to compare images
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
@ -1083,9 +1026,7 @@ class ImageGridNode(ImageProcessingNode):
|
||||
"""Combine multiple images into a single grid/collage."""
|
||||
|
||||
node_id = "ImageGrid"
|
||||
search_aliases=["grid", "collage", "combine"]
|
||||
display_name = "Make Image Grid"
|
||||
category="image/batch"
|
||||
display_name = "Image Grid"
|
||||
description = "Arrange multiple images into a grid layout."
|
||||
is_group_process = True # Requires full list to create grid
|
||||
is_output_list = False # Outputs single grid image
|
||||
@ -1161,12 +1102,9 @@ class MergeImageListsNode(ImageProcessingNode):
|
||||
"""Merge multiple image lists into a single list."""
|
||||
|
||||
node_id = "MergeImageLists"
|
||||
search_aliases=["list", "merge list", "make list"]
|
||||
display_name = "Merge Image Lists (DEPRECATED)"
|
||||
category = "image/batch"
|
||||
display_name = "Merge Image Lists"
|
||||
description = "Concatenate multiple image lists into one."
|
||||
is_group_process = True # Receives images as list
|
||||
is_deprecated = True # This node is superseded by the Create List node
|
||||
|
||||
@classmethod
|
||||
def _group_process(cls, images):
|
||||
@ -1181,11 +1119,9 @@ class MergeTextListsNode(TextProcessingNode):
|
||||
"""Merge multiple text lists into a single list."""
|
||||
|
||||
node_id = "MergeTextLists"
|
||||
display_name = "Merge Text Lists (DEPRECATED)"
|
||||
category = "text"
|
||||
display_name = "Merge Text Lists"
|
||||
description = "Concatenate multiple text lists into one."
|
||||
is_group_process = True # Receives texts as list
|
||||
is_deprecated = True # This node is superseded by the Create List node
|
||||
|
||||
@classmethod
|
||||
def _group_process(cls, texts):
|
||||
@ -1206,10 +1142,8 @@ class ResolutionBucket(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ResolutionBucket",
|
||||
search_aliases=["bucket by resolution", "group by resolution", "batch by resolution"],
|
||||
display_name="Resolution Bucket",
|
||||
category="training",
|
||||
description="Group latents and conditionings into buckets",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
@ -1302,8 +1236,7 @@ class MakeTrainingDataset(io.ComfyNode):
|
||||
node_id="MakeTrainingDataset",
|
||||
search_aliases=["encode dataset"],
|
||||
display_name="Make Training Dataset",
|
||||
category="training",
|
||||
description="Encode images with VAE and texts with CLIP to create a training dataset of latents and conditionings.",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_input_list=True, # images and texts as lists
|
||||
inputs=[
|
||||
@ -1318,7 +1251,6 @@ class MakeTrainingDataset(io.ComfyNode):
|
||||
"texts",
|
||||
optional=True,
|
||||
tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).",
|
||||
force_input=True
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
@ -1388,10 +1320,9 @@ class SaveTrainingDataset(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveTrainingDataset",
|
||||
search_aliases=["export dataset", "save dataset"],
|
||||
search_aliases=["export training data"],
|
||||
display_name="Save Training Dataset",
|
||||
category="training",
|
||||
description="Save encoded training dataset (latents + conditioning) to disk for efficient loading during training.",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
is_input_list=True, # Receive lists
|
||||
@ -1493,8 +1424,7 @@ class LoadTrainingDataset(io.ComfyNode):
|
||||
node_id="LoadTrainingDataset",
|
||||
search_aliases=["import dataset", "training data"],
|
||||
display_name="Load Training Dataset",
|
||||
category="training",
|
||||
description="Load encoded training dataset (latents + conditioning) from disk for use in training.",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.String.Input(
|
||||
|
||||
@ -419,17 +419,15 @@ class VoxelToMeshBasic(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VoxelToMeshBasic",
|
||||
display_name="Voxel to Mesh (Basic) (DEPRECATED)",
|
||||
display_name="Voxel to Mesh (Basic)",
|
||||
category="3d",
|
||||
description="Converts a voxel grid to a mesh.",
|
||||
is_deprecated=True, # This node is superseded by the Voxel To Mesh node
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output(),
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -455,10 +453,9 @@ class VoxelToMesh(IO.ComfyNode):
|
||||
node_id="VoxelToMesh",
|
||||
display_name="Voxel to Mesh",
|
||||
category="3d",
|
||||
description="Converts a voxel grid to a mesh.",
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
IO.Combo.Input("algorithm", options=["surface net", "basic"]),
|
||||
IO.Combo.Input("algorithm", options=["surface net", "basic"], advanced=True),
|
||||
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
|
||||
@ -55,10 +55,9 @@ class ImageCropV2(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCropV2",
|
||||
search_aliases=["crop", "cut", "trim"],
|
||||
search_aliases=["trim"],
|
||||
display_name="Crop Image",
|
||||
category="image/transform",
|
||||
description = "Crop an image to the specified dimensions.",
|
||||
essentials_category="Image Tools",
|
||||
has_intermediate_output=True,
|
||||
inputs=[
|
||||
|
||||
@ -11,8 +11,8 @@ class LTXVAudioVAELoader(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAELoader",
|
||||
display_name="Load LTXV Audio VAE",
|
||||
category="loaders",
|
||||
display_name="LTXV Audio VAE Loader",
|
||||
category="audio",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"ckpt_name",
|
||||
@ -40,7 +40,7 @@ class LTXVAudioVAEEncode(VAEEncodeAudio):
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAEEncode",
|
||||
display_name="LTXV Audio VAE Encode",
|
||||
category="latent/audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
io.Audio.Input("audio", tooltip="The audio to be encoded."),
|
||||
io.Vae.Input(
|
||||
@ -63,7 +63,7 @@ class LTXVAudioVAEDecode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAEDecode",
|
||||
display_name="LTXV Audio VAE Decode",
|
||||
category="latent/audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
io.Latent.Input("samples", tooltip="The latent to be decoded."),
|
||||
io.Vae.Input(
|
||||
|
||||
@ -1,509 +0,0 @@
|
||||
"""ComfyUI nodes for the pure-PyTorch MediaPipe Face Landmarker port.
|
||||
|
||||
Custom IO types:
|
||||
FACE_LANDMARKER — FaceLandmarkerModel wrapper (ModelPatcher inside)
|
||||
FACE_LANDMARKS — {"frames": List[List[face_dict]], "image_size": (H, W),
|
||||
"connection_sets": dict[str, frozenset[(int, int)]]}
|
||||
face_dict: bbox_xyxy, blendshapes, landmarks_xy,
|
||||
landmarks_3d, presence, score, transformation_matrix
|
||||
|
||||
MediaPipeFaceLandmarker also emits the core BOUNDING_BOX type — pair with DrawBBoxes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageColor, ImageDraw
|
||||
from tqdm.auto import tqdm
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
from comfy_extras.mediapipe.face_landmarker import FaceLandmarker
|
||||
from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection
|
||||
|
||||
|
||||
FaceDetectionType = io.Custom("FACE_DETECTION_MODEL")
|
||||
FaceLandmarksType = io.Custom("FACE_LANDMARKS")
|
||||
|
||||
_CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights")
|
||||
_CONTOUR_PARTS = ("face_oval", "left_eye", "right_eye", "left_eyebrow", "right_eyebrow", "lips")
|
||||
|
||||
|
||||
class FaceLandmarkerModel:
|
||||
"""Loaded FaceLandmarker variants + ModelPatcher per variant.
|
||||
|
||||
Safetensors layout: `detector_short.*` / `detector_full.*` plus shared
|
||||
`mesh.*`, `blendshapes.*`, `canonical_*`, and `topology.*`.
|
||||
PReLU forces plain-nn / fp32 (manual_cast strands buffers across devices).
|
||||
"""
|
||||
|
||||
def __init__(self, state_dict: dict):
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = torch.float32
|
||||
|
||||
# FACEMESH_* connection sets, embedded as int32 (N, 2) under topology.*.
|
||||
base: dict[str, frozenset] = {}
|
||||
for k in [k for k in state_dict if k.startswith("topology.")]:
|
||||
base[k[len("topology."):]] = frozenset(map(tuple, state_dict.pop(k).tolist()))
|
||||
base["contours"] = frozenset().union(*(base[p] for p in _CONTOUR_PARTS))
|
||||
base["all"] = base["contours"] | base["irises"] | base["nose"]
|
||||
|
||||
self.connection_sets: dict[str, frozenset] = base
|
||||
self.canonical_data: dict[str, np.ndarray] = {k: state_dict.pop(k).numpy() for k in _CANONICAL_KEYS}
|
||||
|
||||
shared = {k: v for k, v in state_dict.items() if k.startswith(("mesh.", "blendshapes."))}
|
||||
|
||||
self.models: dict[str, FaceLandmarker] = {}
|
||||
self.patchers: dict[str, comfy.model_patcher.ModelPatcher] = {}
|
||||
for variant in ("short", "full"):
|
||||
prefix = f"detector_{variant}."
|
||||
sub = dict(shared)
|
||||
sub.update({f"detector.{k[len(prefix):]}": v for k, v in state_dict.items() if k.startswith(prefix)})
|
||||
fl = FaceLandmarker(device=offload_device, dtype=self.dtype, operations=None, detector_variant=variant).eval()
|
||||
fl.load_state_dict(sub, strict=False)
|
||||
|
||||
self.models[variant] = fl
|
||||
self.patchers[variant] = comfy.model_patcher.CoreModelPatcher(
|
||||
fl, load_device=self.load_device, offload_device=offload_device,
|
||||
size=comfy.model_management.module_size(fl),
|
||||
)
|
||||
|
||||
def detect_batch(self, images, num_faces: int, score_thresh: float, variant: str):
|
||||
comfy.model_management.load_model_gpu(self.patchers[variant])
|
||||
return self.models[variant].detect_batch(images, num_faces=num_faces, score_thresh=score_thresh)
|
||||
|
||||
|
||||
def _image_to_uint8(image: torch.Tensor) -> np.ndarray:
|
||||
return image[..., :3].mul(255.0).add_(0.5).clamp_(0, 255).to(torch.uint8).cpu().numpy()
|
||||
|
||||
|
||||
def _parse_color(color: str) -> tuple[int, int, int]:
|
||||
try:
|
||||
return ImageColor.getrgb(color)[:3]
|
||||
except ValueError:
|
||||
return (0, 255, 0)
|
||||
|
||||
|
||||
def _copy_face(face: dict) -> dict:
|
||||
"""Shallow copy of a face_dict with array-fields cloned so callers can mutate."""
|
||||
return {
|
||||
"bbox_xyxy": face["bbox_xyxy"].copy(),
|
||||
"blendshapes": dict(face["blendshapes"]),
|
||||
"landmarks_xy": face["landmarks_xy"].copy(),
|
||||
"landmarks_3d": face["landmarks_3d"].copy(),
|
||||
"presence": face["presence"],
|
||||
"score": face["score"],
|
||||
}
|
||||
|
||||
|
||||
def _lerp_face(a: dict, b: dict, t: float) -> dict:
|
||||
return {
|
||||
"bbox_xyxy": (1 - t) * a["bbox_xyxy"] + t * b["bbox_xyxy"],
|
||||
"blendshapes": {k: (1 - t) * a["blendshapes"][k] + t * b["blendshapes"][k] for k in a["blendshapes"]},
|
||||
"landmarks_xy": (1 - t) * a["landmarks_xy"] + t * b["landmarks_xy"],
|
||||
"landmarks_3d": (1 - t) * a["landmarks_3d"] + t * b["landmarks_3d"],
|
||||
"presence": (1 - t) * a["presence"] + t * b["presence"],
|
||||
"score": (1 - t) * a["score"] + t * b["score"],
|
||||
}
|
||||
|
||||
|
||||
def _match_faces(a: list[dict], b: list[dict]) -> list[tuple[int, int]]:
|
||||
"""Greedy nearest-neighbour pairing of faces between two frames by bbox
|
||||
centre distance. Unmatched (when counts differ) are dropped."""
|
||||
if not a or not b:
|
||||
return []
|
||||
centers_a = np.array([(0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]),
|
||||
0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3])) for f in a])
|
||||
centers_b = np.array([(0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]),
|
||||
0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3])) for f in b])
|
||||
dists = np.linalg.norm(centers_a[:, None] - centers_b[None], axis=-1)
|
||||
pairs: list[tuple[int, int]] = []
|
||||
used_a: set[int] = set()
|
||||
used_b: set[int] = set()
|
||||
candidates = sorted((dists[ia, ib], ia, ib) for ia in range(len(a)) for ib in range(len(b)))
|
||||
for _, ia, ib in candidates:
|
||||
if ia in used_a or ib in used_b:
|
||||
continue
|
||||
pairs.append((ia, ib))
|
||||
used_a.add(ia)
|
||||
used_b.add(ib)
|
||||
return pairs
|
||||
|
||||
|
||||
def _fill_missing_frames(frames: list[list[dict]], mode: str) -> None:
|
||||
"""In-place fill empty frame slots from neighbouring detections. Multi-face
|
||||
aware: pairs faces across bracketing frames by greedy bbox-centre NN.
|
||||
When counts differ, unmatched faces are dropped from the synthesised frame."""
|
||||
if mode == "empty":
|
||||
return
|
||||
valid = [i for i, fr in enumerate(frames) if fr]
|
||||
if not valid:
|
||||
return # nothing to fill from
|
||||
if mode == "previous":
|
||||
last: list[dict] = []
|
||||
for i, fr in enumerate(frames):
|
||||
if fr:
|
||||
last = fr
|
||||
elif last:
|
||||
frames[i] = [_copy_face(f) for f in last]
|
||||
return
|
||||
# interpolate: lerp between bracketing valid frames; clamp at ends.
|
||||
for i in range(len(frames)):
|
||||
if frames[i]:
|
||||
continue
|
||||
prev_i = max((v for v in valid if v < i), default=None)
|
||||
next_i = min((v for v in valid if v > i), default=None)
|
||||
if prev_i is None:
|
||||
frames[i] = [_copy_face(f) for f in frames[next_i]]
|
||||
elif next_i is None:
|
||||
frames[i] = [_copy_face(f) for f in frames[prev_i]]
|
||||
else:
|
||||
t = (i - prev_i) / (next_i - prev_i)
|
||||
pairs = _match_faces(frames[prev_i], frames[next_i])
|
||||
frames[i] = [_lerp_face(frames[prev_i][a], frames[next_i][b], t) for a, b in pairs]
|
||||
|
||||
|
||||
def _ordered_rings(edges: frozenset[tuple[int, int]]) -> list[list[int]]:
|
||||
"""Walk an unordered edge set into one or more closed-loop vertex rings
|
||||
(handles multi-loop sets like FACEMESH_LIPS: outer + inner)."""
|
||||
adj: dict[int, set[int]] = {}
|
||||
for a, b in edges:
|
||||
adj.setdefault(a, set()).add(b)
|
||||
adj.setdefault(b, set()).add(a)
|
||||
visited: set[int] = set()
|
||||
rings: list[list[int]] = []
|
||||
for start in adj:
|
||||
if start in visited:
|
||||
continue
|
||||
ring = [start]
|
||||
visited.add(start)
|
||||
prev, cur = -1, start
|
||||
while True:
|
||||
nxt = next((v for v in adj[cur] if v != prev), None)
|
||||
if nxt is None or nxt == start:
|
||||
break
|
||||
ring.append(nxt)
|
||||
visited.add(nxt)
|
||||
prev, cur = cur, nxt
|
||||
rings.append(ring)
|
||||
return rings
|
||||
|
||||
|
||||
class LoadMediaPipeFaceLandmarker(io.ComfyNode):
|
||||
"""Load MediaPipe Face Landmarker v2 weights. Contains both detector variants
|
||||
(short / full), shared mesh, blendshapes, and canonical geometry."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadMediaPipeFaceLandmarker",
|
||||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
|
||||
display_name="Load Face Detection Model (MediaPipe)",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
io.Combo.Input("model_name", options=folder_paths.get_filename_list("detection"),
|
||||
tooltip="Face detection model from models/detection/."),
|
||||
],
|
||||
outputs=[FaceDetectionType.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_name) -> io.NodeOutput:
|
||||
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("detection", model_name), safe_load=True)
|
||||
wrapper = FaceLandmarkerModel(sd)
|
||||
return io.NodeOutput(wrapper)
|
||||
|
||||
|
||||
# Per-frame fallback modes for detection failures in a batch.
|
||||
_FALLBACK_MODES = ("empty", "previous", "interpolate")
|
||||
|
||||
|
||||
class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
"""BlazeFace → FaceMesh v2 → ARKit-52 blendshapes, batched across the
|
||||
input. Also emits a BOUNDING_BOX list (landmark-extent bbox per face) —
|
||||
pair with DrawBBoxes for detector-only viz or MediaPipeFaceMeshVisualize
|
||||
for the mesh overlay."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MediaPipeFaceLandmarker",
|
||||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
|
||||
display_name="Detect Face Landmarks (MediaPipe)",
|
||||
category="image/detection",
|
||||
description="Detects facial landmarks using MediaPipe model.",
|
||||
inputs=[
|
||||
FaceDetectionType.Input("face_detection_model"),
|
||||
io.Image.Input("image"),
|
||||
io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short",
|
||||
tooltip="Face detector range. 'short' is tuned for close-up faces "
|
||||
"(within ~2 m of the camera); 'full' covers farther / smaller "
|
||||
"faces (up to ~5 m) but is slower. 'both' runs both detectors and "
|
||||
"keeps whichever found more faces per frame (~2× detection cost)."),
|
||||
io.Int.Input("num_faces", default=1, min=0, max=16, step=1,
|
||||
tooltip="Maximum faces to return per frame. 0 = no cap (return all detected)."),
|
||||
io.Float.Input("min_confidence", default=0.5, min=0.0, max=1.0, step=0.01, advanced=True,
|
||||
tooltip="BlazeFace score threshold. Lower to catch small/occluded faces."),
|
||||
io.Combo.Input("missing_frame_fallback", options=list(_FALLBACK_MODES), default="empty", advanced=True,
|
||||
tooltip="Per-frame behaviour when detection fails in a batch. "
|
||||
"'empty' leaves the frame faceless. 'previous' copies the most recent successful "
|
||||
"detection. 'interpolate' lerps landmarks/bbox/blendshapes between bracketing "
|
||||
"successful frames. Multi-face: pairs faces across frames by greedy bbox-centre NN."),
|
||||
],
|
||||
outputs=[
|
||||
FaceLandmarksType.Output(display_name="face_landmarks"),
|
||||
io.BoundingBox.Output("bboxes"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, face_detection_model, image, detector_variant, num_faces, min_confidence,
|
||||
missing_frame_fallback) -> io.NodeOutput:
|
||||
canonical = face_detection_model.canonical_data
|
||||
img_np = _image_to_uint8(image)
|
||||
B, H, W = img_np.shape[:3]
|
||||
chunk = 16
|
||||
is_both = detector_variant == "both"
|
||||
total_work = 2 * B if is_both else B
|
||||
pbar = comfy.utils.ProgressBar(total_work)
|
||||
|
||||
def _run(variant: str) -> list[list[dict]]:
|
||||
res: list[list[dict]] = []
|
||||
with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq:
|
||||
for i in range(0, B, chunk):
|
||||
end = min(i + chunk, B)
|
||||
res.extend(face_detection_model.detect_batch(
|
||||
[img_np[bi] for bi in range(i, end)],
|
||||
num_faces=int(num_faces),
|
||||
score_thresh=float(min_confidence),
|
||||
variant=variant,
|
||||
))
|
||||
pbar.update_absolute(min(pbar.current + (end - i), total_work))
|
||||
tq.update(end - i)
|
||||
return res
|
||||
|
||||
if is_both:
|
||||
short_res = _run("short")
|
||||
full_res = _run("full")
|
||||
# Per-frame keep whichever found more faces (tie → short).
|
||||
frames: list[list[dict]] = [
|
||||
short_res[bi] if len(short_res[bi]) >= len(full_res[bi]) else full_res[bi]
|
||||
for bi in range(B)
|
||||
]
|
||||
else:
|
||||
frames = _run(detector_variant)
|
||||
_fill_missing_frames(frames, missing_frame_fallback)
|
||||
bboxes = []
|
||||
for per_frame in frames:
|
||||
per_bb = []
|
||||
for f in per_frame:
|
||||
f["transformation_matrix"] = transformation_matrix_from_detection(f, W, H, canonical)
|
||||
x1, y1, x2, y2 = (float(v) for v in f["bbox_xyxy"])
|
||||
per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])})
|
||||
bboxes.append(per_bb)
|
||||
return io.NodeOutput({"frames": frames, "image_size": (H, W),
|
||||
"connection_sets": face_detection_model.connection_sets}, bboxes)
|
||||
|
||||
|
||||
# Topology keys unioned by the 'all' connections preset (contour parts + irises + nose).
|
||||
_ALL_CONNECTION_PARTS: tuple[str, ...] = (*_CONTOUR_PARTS, "irises", "nose")
|
||||
_CUSTOM_FEATURES: tuple[tuple[str, bool], ...] = (
|
||||
("face_oval", True),
|
||||
("lips", True),
|
||||
("left_eye", True),
|
||||
("right_eye", True),
|
||||
("left_eyebrow", True),
|
||||
("right_eyebrow", True),
|
||||
("irises", True),
|
||||
("nose", True),
|
||||
("tesselation", False),
|
||||
)
|
||||
|
||||
|
||||
class MediaPipeFaceMeshVisualize(io.ComfyNode):
|
||||
"""Draw a FACEMESH_* subset over an image. Topology travels with the
|
||||
FACE_LANDMARKS payload (set at detection time)."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MediaPipeFaceMeshVisualize",
|
||||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection", "visualize"],
|
||||
display_name="Visualize Face Landmarks (MediaPipe)",
|
||||
category="image/detection",
|
||||
description="Draws face landmarks mesh on the input image.",
|
||||
inputs=[
|
||||
FaceLandmarksType.Input("face_landmarks"),
|
||||
io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."),
|
||||
io.DynamicCombo.Input(
|
||||
"connections",
|
||||
tooltip="'all' = oval+eyes+brows+lips+irises+nose. 'fill' = solid face_oval polygon (silhouette mask). 'custom' = toggle each feature individually (including 'tesselation', the full 2547-edge wireframe).",
|
||||
options=[
|
||||
io.DynamicCombo.Option("all", []),
|
||||
io.DynamicCombo.Option("fill", []),
|
||||
io.DynamicCombo.Option("custom", [
|
||||
io.Boolean.Input(feat, default=default,
|
||||
tooltip=f"Draw the '{feat}' connection set.")
|
||||
for feat, default in _CUSTOM_FEATURES
|
||||
]),
|
||||
],
|
||||
),
|
||||
io.Color.Input("color", default="#00ff00"),
|
||||
io.Int.Input("thickness", default=1, min=0, max=8, step=1,
|
||||
tooltip="Edge line thickness in pixels. 0 disables edge drawing."),
|
||||
io.Int.Input("point_size", default=2, min=0, max=16, step=1,
|
||||
tooltip="Landmark dot radius in pixels. 0 disables point drawing."),
|
||||
],
|
||||
outputs=[io.Image.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, face_landmarks, connections, color, thickness, point_size, image=None) -> io.NodeOutput:
|
||||
sets = face_landmarks["connection_sets"]
|
||||
sel = connections["connections"]
|
||||
fill_rings: list[list[int]] | None = None
|
||||
if sel == "fill":
|
||||
fill_rings = _ordered_rings(sets["face_oval"])
|
||||
edges = frozenset()
|
||||
elif sel == "custom":
|
||||
parts = [feat for feat, _ in _CUSTOM_FEATURES if connections.get(feat, False)]
|
||||
edges = frozenset().union(*(sets[p] for p in parts))
|
||||
else: # "all"
|
||||
edges = frozenset().union(*(sets[p] for p in _ALL_CONNECTION_PARTS))
|
||||
rgb, thick, psize = _parse_color(color), int(thickness), int(point_size)
|
||||
frames = face_landmarks["frames"]
|
||||
if image is None:
|
||||
H, W = face_landmarks["image_size"]
|
||||
img_np = np.zeros((len(frames), H, W, 3), dtype=np.uint8)
|
||||
else:
|
||||
img_np = _image_to_uint8(image)
|
||||
B = img_np.shape[0]
|
||||
n_frames = len(frames)
|
||||
pbar = comfy.utils.ProgressBar(B)
|
||||
out = np.empty_like(img_np)
|
||||
for bi in range(B):
|
||||
faces = frames[bi] if bi < n_frames else []
|
||||
out[bi] = _draw_mesh(img_np[bi], faces, edges, rgb, thick, psize, fill_rings)
|
||||
pbar.update_absolute(bi + 1)
|
||||
return io.NodeOutput(torch.from_numpy(out).to(
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
dtype=comfy.model_management.intermediate_dtype(),
|
||||
).div_(255.0))
|
||||
|
||||
|
||||
def _draw_mesh(image_rgb: np.ndarray, faces: list, edges,
|
||||
rgb: tuple[int, int, int], thickness: int,
|
||||
point_size: int, fill_rings: list[list[int]] | None = None) -> np.ndarray:
|
||||
draw_edges = thickness > 0 and edges
|
||||
if not faces or (fill_rings is None and not draw_edges and point_size <= 0):
|
||||
return image_rgb.copy()
|
||||
pil = Image.fromarray(image_rgb)
|
||||
draw = ImageDraw.Draw(pil)
|
||||
r = point_size * 0.5
|
||||
if fill_rings is not None:
|
||||
for f in faces:
|
||||
lmks = f["landmarks_xy"]
|
||||
for ring in fill_rings:
|
||||
draw.polygon([(float(lmks[i, 0]), float(lmks[i, 1])) for i in ring], fill=rgb)
|
||||
return np.asarray(pil)
|
||||
for f in faces:
|
||||
lmks = f["landmarks_xy"]
|
||||
n = lmks.shape[0]
|
||||
if draw_edges:
|
||||
for a, b in edges:
|
||||
if a < n and b < n:
|
||||
draw.line([(float(lmks[a, 0]), float(lmks[a, 1])),
|
||||
(float(lmks[b, 0]), float(lmks[b, 1]))], fill=rgb, width=thickness)
|
||||
if point_size == 1:
|
||||
draw.point(lmks.flatten().tolist(), fill=rgb)
|
||||
elif point_size > 1:
|
||||
for x, y in lmks:
|
||||
draw.ellipse((float(x) - r, float(y) - r, float(x) + r, float(y) + r), fill=rgb)
|
||||
return np.asarray(pil)
|
||||
|
||||
|
||||
# Mask region presets — closed-loop topologies only.
|
||||
_MASK_REGIONS: tuple[str, ...] = ("face_oval", "lips", "left_eye", "right_eye", "irises")
|
||||
_MASK_CUSTOM_FEATURES: tuple[tuple[str, bool], ...] = (
|
||||
("face_oval", True),
|
||||
("lips", False),
|
||||
("left_eye", False),
|
||||
("right_eye", False),
|
||||
("irises", False),
|
||||
)
|
||||
|
||||
|
||||
class MediaPipeFaceMask(io.ComfyNode):
|
||||
"""Binary mask from face landmarks, filled polygon per face. One mask per
|
||||
frame in the batch; faces in the same frame composite (union)."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MediaPipeFaceMask",
|
||||
search_aliases=["face", "facial", "mediapipe", "face mask", "blazeface", "face detection", "visualize"],
|
||||
display_name="Draw Face Mask (MediaPipe)",
|
||||
category="image/detection",
|
||||
description="Draws a mask from face landmarks.",
|
||||
inputs=[
|
||||
FaceLandmarksType.Input("face_landmarks"),
|
||||
io.DynamicCombo.Input(
|
||||
"regions",
|
||||
tooltip="'all' = union of face_oval+lips+eyes+irises (which collapses to face_oval since it encloses the rest). 'custom' = toggle each region individually for combos like lips+eyes.",
|
||||
options=[
|
||||
io.DynamicCombo.Option("all", []),
|
||||
io.DynamicCombo.Option("custom", [
|
||||
io.Boolean.Input(reg, default=default,
|
||||
tooltip=f"Include the '{reg}' region in the mask.")
|
||||
for reg, default in _MASK_CUSTOM_FEATURES
|
||||
]),
|
||||
],
|
||||
),
|
||||
],
|
||||
outputs=[io.Mask.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, face_landmarks, regions) -> io.NodeOutput:
|
||||
sets = face_landmarks["connection_sets"]
|
||||
sel = regions["regions"]
|
||||
if sel == "custom":
|
||||
picked = [reg for reg, _ in _MASK_CUSTOM_FEATURES if regions.get(reg, False)]
|
||||
else:
|
||||
picked = list(_MASK_REGIONS)
|
||||
rings = [r for reg in picked for r in _ordered_rings(sets[reg])]
|
||||
frames = face_landmarks["frames"]
|
||||
H, W = face_landmarks["image_size"]
|
||||
masks = np.zeros((len(frames), H, W), dtype=np.uint8)
|
||||
pbar = comfy.utils.ProgressBar(len(frames))
|
||||
for bi, per_frame in enumerate(frames):
|
||||
if per_frame:
|
||||
pil = Image.new("L", (W, H), 0)
|
||||
draw = ImageDraw.Draw(pil)
|
||||
for f in per_frame:
|
||||
lmks = f["landmarks_xy"]
|
||||
for ring in rings:
|
||||
draw.polygon([(float(lmks[i, 0]), float(lmks[i, 1])) for i in ring], fill=255)
|
||||
masks[bi] = np.asarray(pil)
|
||||
pbar.update_absolute(bi + 1)
|
||||
return io.NodeOutput(torch.from_numpy(masks).to(
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
dtype=comfy.model_management.intermediate_dtype(),
|
||||
).div_(255.0))
|
||||
|
||||
|
||||
class MediaPipeFaceExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [LoadMediaPipeFaceLandmarker, MediaPipeFaceLandmarker, MediaPipeFaceMeshVisualize, MediaPipeFaceMask]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> MediaPipeFaceExtension:
|
||||
return MediaPipeFaceExtension()
|
||||
@ -103,10 +103,8 @@ class MoGePanoramaInference(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGePanoramaInference",
|
||||
search_aliases=["moge", "panorama", "depth", "geometry", "depth estimation", "geometry estimation"],
|
||||
display_name="Run MoGe Panorama Inference",
|
||||
display_name="MoGe Panorama Inference",
|
||||
category="image/geometry_estimation",
|
||||
description="Run MoGe on an equirectangular panorama by splitting it into 12 perspective views, running inference on each, and merging the results into a single depth map.",
|
||||
inputs=[
|
||||
MoGeModelType.Input("moge_model"),
|
||||
io.Image.Input("image", tooltip="Equirectangular panorama (any aspect)."),
|
||||
@ -224,9 +222,7 @@ class MoGeInference(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGeInference",
|
||||
search_aliases=["moge", "depth", "geometry", "depth estimation", "geometry estimation"],
|
||||
display_name="Run MoGe Inference",
|
||||
description="Run MoGe on a single image to estimate depth and geometry.",
|
||||
display_name="MoGe Inference",
|
||||
category="image/geometry_estimation",
|
||||
inputs=[
|
||||
MoGeModelType.Input("moge_model"),
|
||||
@ -281,9 +277,7 @@ class MoGeRender(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGeRender",
|
||||
search_aliases=["moge", "render", "geometry", "depth", "normal"],
|
||||
display_name="Render MoGe Geometry",
|
||||
description="Render a depth map or normal map from geometry data",
|
||||
display_name="MoGe Render",
|
||||
category="image/geometry_estimation",
|
||||
inputs=[
|
||||
MoGeGeometry.Input("moge_geometry"),
|
||||
@ -348,9 +342,7 @@ class MoGePointMapToMesh(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGePointMapToMesh",
|
||||
search_aliases=["moge", "mesh", "geometry", "point map"],
|
||||
display_name="Convert MoGe Point Map to Mesh",
|
||||
description="Convert a MoGe point map into a 3D mesh.",
|
||||
display_name="MoGe Point Map to Mesh",
|
||||
category="image/geometry_estimation",
|
||||
inputs=[
|
||||
MoGeGeometry.Input("moge_geometry"),
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.22.0"
|
||||
__version__ = "0.22.2"
|
||||
|
||||
12
execution.py
12
execution.py
@ -2,7 +2,6 @@ import copy
|
||||
import heapq
|
||||
import inspect
|
||||
import logging
|
||||
import psutil
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@ -728,7 +727,6 @@ class PromptExecutor:
|
||||
|
||||
self._notify_prompt_lifecycle("start", prompt_id)
|
||||
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
|
||||
ram_inactive_headroom = int(self.cache_args["ram_inactive"] * (1024 ** 3))
|
||||
ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None
|
||||
comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom)
|
||||
|
||||
@ -782,14 +780,8 @@ class PromptExecutor:
|
||||
execution_list.complete_node_execution()
|
||||
|
||||
if self.cache_type == CacheType.RAM_PRESSURE:
|
||||
ram_release_callback(ram_inactive_headroom)
|
||||
ram_shortfall = ram_headroom - psutil.virtual_memory().available
|
||||
freed = comfy.model_management.free_pins(ram_shortfall + 512 * (1024 ** 2))
|
||||
if freed < ram_shortfall:
|
||||
if freed > 64 * (1024 ** 2):
|
||||
# AIMDO MEM_DECOMMIT can outrun psutil.available catching up.
|
||||
time.sleep(0.05)
|
||||
ram_release_callback(ram_headroom, free_active=True)
|
||||
comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
|
||||
ram_release_callback(ram_headroom, free_active=True)
|
||||
else:
|
||||
# Only execute when the while-loop ends without break
|
||||
# Send cached UI for intermediate output nodes that weren't executed
|
||||
|
||||
@ -60,8 +60,6 @@ folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geom
|
||||
|
||||
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["detection"] = ([os.path.join(models_dir, "detection")], supported_pt_extensions)
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
temp_directory = os.path.join(base_path, "temp")
|
||||
input_directory = os.path.join(base_path, "input")
|
||||
|
||||
20
main.py
20
main.py
@ -283,25 +283,19 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]:
|
||||
|
||||
def prompt_worker(q, server_instance):
|
||||
current_time: float = 0.0
|
||||
cache_ram = 0
|
||||
cache_ram_inactive = 0
|
||||
if not args.cache_classic and not args.cache_none and args.cache_lru <= 0:
|
||||
cache_ram = args.cache_ram
|
||||
if cache_ram < 0:
|
||||
cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0))
|
||||
cache_ram_inactive = min(96.0, max(12.0, comfy.model_management.total_ram * 0.75 / 1024.0))
|
||||
if len(args.cache_ram) > 0:
|
||||
cache_ram = args.cache_ram[0]
|
||||
if len(args.cache_ram) > 1:
|
||||
cache_ram_inactive = args.cache_ram[1]
|
||||
|
||||
cache_type = execution.CacheType.RAM_PRESSURE
|
||||
if args.cache_classic:
|
||||
cache_type = execution.CacheType.CLASSIC
|
||||
elif args.cache_lru > 0:
|
||||
cache_type = execution.CacheType.CLASSIC
|
||||
if args.cache_lru > 0:
|
||||
cache_type = execution.CacheType.LRU
|
||||
elif cache_ram > 0:
|
||||
cache_type = execution.CacheType.RAM_PRESSURE
|
||||
elif args.cache_none:
|
||||
cache_type = execution.CacheType.NONE
|
||||
|
||||
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram, "ram_inactive" : cache_ram_inactive } )
|
||||
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } )
|
||||
last_gc_collect = 0
|
||||
need_gc = False
|
||||
gc_collect_interval = 10.0
|
||||
|
||||
1
nodes.py
1
nodes.py
@ -2444,7 +2444,6 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_hidream_o1.py",
|
||||
"nodes_save_3d.py",
|
||||
"nodes_moge.py",
|
||||
"nodes_mediapipe.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
673
openapi.yaml
673
openapi.yaml
@ -1556,6 +1556,12 @@ paths:
|
||||
type: string
|
||||
enum: [asc, desc]
|
||||
description: Sort direction
|
||||
- name: job_ids
|
||||
in: query
|
||||
schema:
|
||||
type: string
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Comma-separated UUIDs to filter assets by associated job."
|
||||
- name: include_public
|
||||
in: query
|
||||
schema:
|
||||
@ -2508,25 +2514,37 @@ paths:
|
||||
|
||||
/api/assets/import:
|
||||
post:
|
||||
operationId: importPublishedAssets
|
||||
operationId: importAssets
|
||||
tags: [assets]
|
||||
summary: "[cloud-only] Import published assets into the caller's library"
|
||||
description: |
|
||||
[cloud-only] Imports the specified published assets into the caller's asset library. New DB records reference the same storage objects; no file copying occurs. Assets the caller already owns (by hash) are deduplicated. The `id` field on each returned `AssetInfo` is the caller's newly-created private asset ID, not the published asset ID supplied in the request.
|
||||
summary: Import assets from external URLs
|
||||
description: "[cloud-only] Imports one or more assets from external URLs into the cloud asset store."
|
||||
x-runtime: [cloud]
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ImportPublishedAssetsRequest"
|
||||
type: object
|
||||
required:
|
||||
- imports
|
||||
properties:
|
||||
imports:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/AssetImportRequest"
|
||||
description: Assets to import
|
||||
responses:
|
||||
"200":
|
||||
description: Successfully imported assets
|
||||
description: Import initiated
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ImportPublishedAssetsResponse"
|
||||
type: object
|
||||
properties:
|
||||
assets:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/Asset"
|
||||
"400":
|
||||
description: Bad request
|
||||
content:
|
||||
@ -3772,295 +3790,6 @@ paths:
|
||||
schema:
|
||||
$ref: "#/components/schemas/JwksResponse"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OAuth 2.1 / RFC 7591 Dynamic Client Registration (cloud)
|
||||
# ---------------------------------------------------------------------------
|
||||
/.well-known/oauth-authorization-server:
|
||||
get:
|
||||
operationId: getOAuthAuthorizationServer
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] OAuth 2.1 authorization-server metadata (RFC 8414)"
|
||||
description: "[cloud-only] Public metadata document for OAuth 2.1 clients. Cached 5 minutes."
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
responses:
|
||||
"200":
|
||||
description: Authorization-server metadata
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthAuthorizationServerMetadata"
|
||||
"404":
|
||||
description: OAuth disabled
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
|
||||
/.well-known/oauth-protected-resource:
|
||||
get:
|
||||
operationId: getOAuthProtectedResource
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] OAuth 2.1 protected-resource metadata (RFC 9728)"
|
||||
description: "[cloud-only] Public metadata describing the currently advertised protected resource. Cached 5 minutes."
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
responses:
|
||||
"200":
|
||||
description: Protected-resource metadata
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthProtectedResourceMetadata"
|
||||
"404":
|
||||
description: OAuth disabled or no active resource configured
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
|
||||
/oauth/authorize:
|
||||
get:
|
||||
operationId: getOAuthAuthorize
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] Begin or resume an OAuth 2.1 authorization request"
|
||||
description: |
|
||||
[cloud-only] Two modes:
|
||||
- **Initial entry** (OAuth params present): validates client/redirect/resource/scopes, persists a server-side authorization-request row, and either redirects (no session / unverified email) to the configured frontend login URL carrying only the opaque `oauth_request_id`, or returns the JSON consent challenge for the frontend to render.
|
||||
- **Resume** (`oauth_request_id` present): loads the server-side row, fails closed if expired/consumed/unknown, returns the JSON consent challenge. Browser-replayed OAuth params are intentionally ignored.
|
||||
|
||||
The frontend renders the consent UI from the JSON payload and POSTs the user's decision back to this endpoint.
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
parameters:
|
||||
- { name: response_type, in: query, required: false, schema: { type: string } }
|
||||
- { name: client_id, in: query, required: false, schema: { type: string } }
|
||||
- { name: redirect_uri, in: query, required: false, schema: { type: string } }
|
||||
- { name: scope, in: query, required: false, schema: { type: string } }
|
||||
- name: state
|
||||
in: query
|
||||
required: false
|
||||
schema: { type: string }
|
||||
description: |
|
||||
RFC 6749 §10.12 marks `state` as RECOMMENDED. Cloud hardening makes it REQUIRED on the initial-entry path (omitted only on the resume path where `oauth_request_id` is supplied instead). This parameter is `required: false` at the spec level only because the operation is dual-mode (initial entry vs. resume); the runtime rejects empty `state` on the initial-entry path with a stable `invalid_request` 400.
|
||||
- { name: code_challenge, in: query, required: false, schema: { type: string } }
|
||||
- { name: code_challenge_method, in: query, required: false, schema: { type: string } }
|
||||
- { name: resource, in: query, required: false, schema: { type: string } }
|
||||
- { name: oauth_request_id, in: query, required: false, schema: { type: string } }
|
||||
responses:
|
||||
"200":
|
||||
description: Consent challenge payload (session present, email verified). Frontend renders the consent UI from this payload and POSTs back to /oauth/authorize.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthConsentChallenge"
|
||||
"302":
|
||||
description: Redirect to login (no session / unverified email) or to registered redirect_uri (pre-validated client error)
|
||||
headers:
|
||||
Location:
|
||||
schema:
|
||||
type: string
|
||||
"400":
|
||||
description: Invalid authorize request (pre-redirect failure — unknown client, redirect mismatch, malformed params)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
"404":
|
||||
description: OAuth disabled
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
post:
|
||||
operationId: postOAuthAuthorize
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] Submit OAuth consent decision"
|
||||
description: |
|
||||
[cloud-only] JSON-only consent submission. The handler verifies the per-row CSRF token, atomically marks the authorization request consumed (single-use covers both allow and deny paths), then returns the redirect URL the browser must navigate to. The URL contains either `code` + original `state` for allow, or the RFC 6749 §5.2 error and `state` for deny.
|
||||
|
||||
Workspace membership is re-checked at submission time. Consent is persisted keyed by `(user_id, client_id, resource_id, workspace_id)`; broadening the previously approved scope set requires a fresh consent flow.
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
required: [oauth_request_id, csrf_token, decision, workspace_id]
|
||||
properties:
|
||||
oauth_request_id: { type: string, format: uuid }
|
||||
csrf_token: { type: string }
|
||||
decision: { type: string, enum: [allow, deny] }
|
||||
workspace_id: { type: string }
|
||||
responses:
|
||||
"200":
|
||||
description: Redirect URL for the frontend to navigate to (allow → with code+state; deny → with error+state)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthAuthorizeRedirectResponse"
|
||||
"400":
|
||||
description: Bad request (CSRF mismatch, expired/consumed request, inaccessible workspace)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
"403":
|
||||
description: Scope broadening on consent re-grant — fresh consent flow required
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
"404":
|
||||
description: OAuth disabled
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
|
||||
/oauth/token:
|
||||
post:
|
||||
operationId: postOAuthToken
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] Exchange authorization code or refresh token for a resource-bound access token"
|
||||
description: |
|
||||
[cloud-only] OAuth 2.1 token endpoint (RFC 6749 §3.2). Public clients only — `client_secret` is rejected.
|
||||
|
||||
Two grant types are supported:
|
||||
- `authorization_code` — exchanges the code minted by `/oauth/authorize` (with PKCE verifier) for an access token + first refresh token. Single-use; reuse fails closed.
|
||||
- `refresh_token` — rotates the refresh token. Old token immediately invalid; presenting an already-rotated token revokes the entire token family and emits a security metric.
|
||||
|
||||
Both grant types re-validate canonical user state, current workspace membership, and the resource's active flag at every mint. A code or refresh token bound to a deactivated resource fails closed.
|
||||
|
||||
Errors follow RFC 6749 §5.2. Logs never contain raw codes, refresh tokens, or minted tokens.
|
||||
|
||||
Per RFC 6749 §5.1, every 200 and 400 response carries `Cache-Control: no-store` and `Pragma: no-cache` so intermediaries cannot cache token-bearing or state-change-reason responses.
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/x-www-form-urlencoded:
|
||||
schema:
|
||||
type: object
|
||||
required: [grant_type, client_id]
|
||||
properties:
|
||||
grant_type: { type: string, enum: [authorization_code, refresh_token] }
|
||||
client_id: { type: string }
|
||||
code: { type: string }
|
||||
redirect_uri: { type: string }
|
||||
code_verifier: { type: string }
|
||||
refresh_token: { type: string }
|
||||
scope: { type: string }
|
||||
client_secret: { type: string }
|
||||
responses:
|
||||
"200":
|
||||
description: New token pair
|
||||
headers:
|
||||
Cache-Control:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-store" per RFC 6749 §5.1'
|
||||
Pragma:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-cache" per RFC 6749 §5.1'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthTokenResponse"
|
||||
"400":
|
||||
description: RFC 6749 §5.2 error
|
||||
headers:
|
||||
Cache-Control:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-store" per RFC 6749 §5.1'
|
||||
Pragma:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-cache" per RFC 6749 §5.1'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthTokenError"
|
||||
"404":
|
||||
description: OAuth disabled
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
|
||||
/oauth/register:
|
||||
post:
|
||||
operationId: postOAuthRegister
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] Dynamic Client Registration (RFC 7591)"
|
||||
description: |
|
||||
[cloud-only] Public, unauthenticated, insert-only RFC 7591 §3.1 client registration. Used by MCP-spec-compliant clients to self-register a public OAuth client without operator involvement.
|
||||
|
||||
Policy:
|
||||
|
||||
- Public clients only — `token_endpoint_auth_method` is forced to `none`. Confidential-client registration is out of scope this phase.
|
||||
- Server-owned `resource_grants`. Caller-supplied `scope` or `resource_grants` is rejected as `invalid_client_metadata` (would be a privilege-escalation surface). Dynamic clients receive the same scopes the active resource publishes.
|
||||
- Application-type-aware redirect URI policy. `application_type=native` accepts loopback (`127.0.0.1`, `::1`, `localhost`) and reverse-DNS-shaped custom schemes; `application_type=web` accepts HTTPS to hosts in an operator-controlled allowlist only. `application_type` is REQUIRED on the request — missing or empty rejects with `invalid_client_metadata`.
|
||||
- Anti-impersonation: reserved client names are rejected from third parties via NFKC-folded compare.
|
||||
- Generated `client_id` carries a stable prefix to distinguish dynamic from seeded clients in audit logs.
|
||||
- Cache-Control: `no-store` on every 201 and 400 response (the response carries fresh credentials and rejection reasons).
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthRegisterRequest"
|
||||
responses:
|
||||
"201":
|
||||
description: Registered. Body echoes the metadata RFC 7591 §3.2.1 requires.
|
||||
headers:
|
||||
Cache-Control:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-store"'
|
||||
Pragma:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-cache"'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthRegisterResponse"
|
||||
"400":
|
||||
description: RFC 7591 §3.2.2 invalid client metadata
|
||||
headers:
|
||||
Cache-Control:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-store"'
|
||||
Pragma:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-cache"'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthRegisterError"
|
||||
"404":
|
||||
description: OAuth disabled
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
"503":
|
||||
description: No active resource is configured — DCR cannot mint a usable client until an active resource row is seeded.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Billing (cloud)
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -7361,35 +7090,24 @@ components:
|
||||
type: string
|
||||
description: Target path on the runtime filesystem
|
||||
|
||||
ImportPublishedAssetsRequest:
|
||||
AssetImportRequest:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Request body for importing published assets into the caller's library."
|
||||
description: "[cloud-only] A single asset to import from an external URL."
|
||||
required:
|
||||
- published_asset_ids
|
||||
- url
|
||||
properties:
|
||||
published_asset_ids:
|
||||
url:
|
||||
type: string
|
||||
format: uri
|
||||
description: URL of the asset to import
|
||||
name:
|
||||
type: string
|
||||
description: Display name for the imported asset
|
||||
tags:
|
||||
type: array
|
||||
description: IDs of published assets (inputs and models) to import.
|
||||
items:
|
||||
type: string
|
||||
share_id:
|
||||
type: string
|
||||
nullable: true
|
||||
description: |
|
||||
Optional. Share ID of the published workflow these assets belong to. When provided (non-null, non-empty): all `published_asset_ids` must belong to this share's workflow version; returns 400 if the share is not found or any asset does not belong to it. When omitted, null, or empty string: no share-scoped validation is performed and the assets are validated only against global rules (preserved for clients that have not yet adopted `share_id`).
|
||||
|
||||
ImportPublishedAssetsResponse:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Response after importing published assets. Each returned `AssetInfo.id` is the caller's newly-created private asset ID, not the published asset ID supplied in the request."
|
||||
required:
|
||||
- assets
|
||||
properties:
|
||||
assets:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/AssetInfo"
|
||||
|
||||
RemoteAssetMetadata:
|
||||
type: object
|
||||
@ -7706,325 +7424,6 @@ components:
|
||||
description: RSA exponent (base64url)
|
||||
additionalProperties: true
|
||||
|
||||
OAuthAuthorizationServerMetadata:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] OAuth 2.1 authorization-server metadata (RFC 8414)."
|
||||
required:
|
||||
- issuer
|
||||
- authorization_endpoint
|
||||
- token_endpoint
|
||||
- jwks_uri
|
||||
- response_types_supported
|
||||
- grant_types_supported
|
||||
- code_challenge_methods_supported
|
||||
- token_endpoint_auth_methods_supported
|
||||
properties:
|
||||
issuer:
|
||||
type: string
|
||||
format: uri
|
||||
authorization_endpoint:
|
||||
type: string
|
||||
format: uri
|
||||
token_endpoint:
|
||||
type: string
|
||||
format: uri
|
||||
jwks_uri:
|
||||
type: string
|
||||
format: uri
|
||||
registration_endpoint:
|
||||
type: string
|
||||
format: uri
|
||||
description: "[cloud-only] RFC 7591 §3.1 Dynamic Client Registration endpoint. Advertised so MCP-spec-compliant clients can auto-discover and self-register without operator involvement. Present only when DCR is enabled."
|
||||
response_types_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
grant_types_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
code_challenge_methods_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
token_endpoint_auth_methods_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
scopes_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
|
||||
OAuthProtectedResourceMetadata:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] OAuth 2.1 protected-resource metadata (RFC 9728)."
|
||||
required:
|
||||
- resource
|
||||
- authorization_servers
|
||||
- scopes_supported
|
||||
properties:
|
||||
resource:
|
||||
type: string
|
||||
format: uri
|
||||
authorization_servers:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
format: uri
|
||||
scopes_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
bearer_methods_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
|
||||
OAuthConsentChallenge:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Server-side state describing the OAuth consent decision the user is being asked to make. Returned by GET /oauth/authorize when a valid session exists; the frontend renders the consent UI from this payload and POSTs the decision back. Browser never sees the original OAuth params on resume."
|
||||
required:
|
||||
- oauth_request_id
|
||||
- csrf_token
|
||||
- client_display_name
|
||||
- resource_display_name
|
||||
- scopes
|
||||
- workspaces
|
||||
properties:
|
||||
oauth_request_id:
|
||||
type: string
|
||||
format: uuid
|
||||
description: Opaque server-side identifier for the authorization-request row. Carried back unchanged in the consent submission.
|
||||
csrf_token:
|
||||
type: string
|
||||
description: Per-row CSRF token bound to this authorization request (not to the session). Must be echoed back on POST.
|
||||
client_display_name:
|
||||
type: string
|
||||
description: Human-readable name of the OAuth client requesting authorization.
|
||||
resource_display_name:
|
||||
type: string
|
||||
description: Human-readable name of the protected resource.
|
||||
scopes:
|
||||
type: array
|
||||
description: Scopes the client is requesting for this resource. The frontend should present these for the user to approve.
|
||||
items:
|
||||
type: string
|
||||
workspaces:
|
||||
type: array
|
||||
description: Workspaces the user can select from. Membership is re-checked on POST.
|
||||
items:
|
||||
$ref: "#/components/schemas/OAuthConsentChallengeWorkspace"
|
||||
|
||||
OAuthConsentChallengeWorkspace:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] One workspace option presented in the OAuth consent challenge."
|
||||
required: [id, name, type, role]
|
||||
properties:
|
||||
id: { type: string }
|
||||
name: { type: string }
|
||||
type: { type: string, enum: [personal, team] }
|
||||
role: { type: string, enum: [owner, member] }
|
||||
|
||||
OAuthAuthorizeRedirectResponse:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Redirect target produced after a JSON consent submission. The frontend must navigate the browser to this URL so custom-scheme client callbacks work without relying on fetch-visible 302 headers."
|
||||
required:
|
||||
- redirect_url
|
||||
properties:
|
||||
redirect_url:
|
||||
type: string
|
||||
format: uri
|
||||
description: OAuth client redirect URI with either code+state for allow, or error+state for deny.
|
||||
|
||||
OAuthTokenResponse:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] RFC 6749 §5.1 successful token response."
|
||||
required: [access_token, token_type, expires_in, refresh_token, scope]
|
||||
properties:
|
||||
access_token:
|
||||
type: string
|
||||
description: Resource-bound access token (audience matches the protected resource).
|
||||
token_type:
|
||||
type: string
|
||||
enum: [Bearer]
|
||||
expires_in:
|
||||
type: integer
|
||||
description: Access token lifetime in seconds.
|
||||
refresh_token:
|
||||
type: string
|
||||
description: Opaque refresh token. Rotates on every successful refresh; presenting an already-rotated token revokes the entire family.
|
||||
scope:
|
||||
type: string
|
||||
description: Space-delimited scopes granted with this token.
|
||||
|
||||
OAuthTokenError:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] RFC 6749 §5.2 error response."
|
||||
required: [error]
|
||||
properties:
|
||||
error:
|
||||
type: string
|
||||
description: 'RFC 6749 §5.2 error code: invalid_request, invalid_client, invalid_grant, unauthorized_client, unsupported_grant_type, invalid_scope.'
|
||||
error_description:
|
||||
type: string
|
||||
description: Human-readable, no leak of internal storage state.
|
||||
|
||||
OAuthRegisterRequest:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
additionalProperties: false
|
||||
description: "[cloud-only] RFC 7591 §2 client metadata document. Only the fields the server honors are listed; presence of `scope` or `resource_grants` in the request is rejected (`invalid_client_metadata`) because those are server-owned for dynamic clients."
|
||||
required:
|
||||
- redirect_uris
|
||||
- application_type
|
||||
properties:
|
||||
redirect_uris:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
minItems: 1
|
||||
maxItems: 5
|
||||
description: 1–5 redirect URIs. Validated against `application_type` policy.
|
||||
client_name:
|
||||
type: string
|
||||
maxLength: 100
|
||||
description: Human-readable name shown in the consent UI. Reserved-name list rejects impersonation of major clients.
|
||||
application_type:
|
||||
type: string
|
||||
enum: [native, web]
|
||||
description: |
|
||||
RFC 7591 §2 application_type. **REQUIRED** — clients MUST declare intent; the server does not default this field. `native` for desktop / CLI / MCP-spec-strict clients (loopback redirects); `web` for hosted clients (HTTPS only, host must be allowlisted). A missing or explicitly empty `application_type` rejects with `invalid_client_metadata`.
|
||||
token_endpoint_auth_method:
|
||||
type: string
|
||||
enum: [none]
|
||||
description: 'Public clients only this phase — must be `none` if present. The server forces `none` regardless.'
|
||||
grant_types:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
enum: [authorization_code, refresh_token]
|
||||
description: Optional. Defaults to `["authorization_code","refresh_token"]`.
|
||||
response_types:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
enum: [code]
|
||||
description: Optional. Defaults to `["code"]`.
|
||||
scope:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Dynamic clients do not pick scopes — the server assigns scopes from the active resource's published list. Sending `scope` in the registration body is treated as a privilege-escalation attempt and returns `invalid_client_metadata`."
|
||||
resource_grants:
|
||||
type: object
|
||||
nullable: true
|
||||
additionalProperties:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: "**REJECTED IF PRESENT.** Same reason as `scope`. The set of resources and scopes a dynamic client may request is server-policy, not request-driven."
|
||||
client_uri:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
logo_uri:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
tos_uri:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
policy_uri:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
software_id:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
software_version:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
contacts:
|
||||
type: array
|
||||
nullable: true
|
||||
items:
|
||||
type: string
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
jwks:
|
||||
type: object
|
||||
nullable: true
|
||||
additionalProperties: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
jwks_uri:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
|
||||
OAuthRegisterResponse:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] RFC 7591 §3.2.1 successful registration response."
|
||||
required:
|
||||
- client_id
|
||||
- client_id_issued_at
|
||||
- redirect_uris
|
||||
- grant_types
|
||||
- response_types
|
||||
- token_endpoint_auth_method
|
||||
- application_type
|
||||
properties:
|
||||
client_id:
|
||||
type: string
|
||||
description: Server-generated client_id.
|
||||
client_id_issued_at:
|
||||
type: integer
|
||||
format: int64
|
||||
description: Unix timestamp (seconds) when the client was registered.
|
||||
client_name:
|
||||
type: string
|
||||
redirect_uris:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
grant_types:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
response_types:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
token_endpoint_auth_method:
|
||||
type: string
|
||||
enum: [none]
|
||||
application_type:
|
||||
type: string
|
||||
enum: [native, web]
|
||||
|
||||
OAuthRegisterError:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] RFC 7591 §3.2.2 error response."
|
||||
required:
|
||||
- error
|
||||
properties:
|
||||
error:
|
||||
type: string
|
||||
enum: [invalid_redirect_uri, invalid_client_metadata]
|
||||
error_description:
|
||||
type: string
|
||||
nullable: true
|
||||
|
||||
BillingBalance:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.22.0"
|
||||
version = "0.22.2"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.8
|
||||
comfy-aimdo==0.4.3
|
||||
comfy-aimdo==0.3.0
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
blake3
|
||||
|
||||
@ -14,6 +14,7 @@ from tests.execution.test_execution import ComfyClient, run_warmup
|
||||
class TestAsyncNodes:
|
||||
@fixture(scope="class", autouse=True, params=[
|
||||
(False, 0),
|
||||
(True, 0),
|
||||
(True, 100),
|
||||
])
|
||||
def _server(self, args_pytest, request):
|
||||
@ -28,8 +29,6 @@ class TestAsyncNodes:
|
||||
use_lru, lru_size = request.param
|
||||
if use_lru:
|
||||
pargs += ['--cache-lru', str(lru_size)]
|
||||
else:
|
||||
pargs += ['--cache-classic']
|
||||
# Running server with args: pargs
|
||||
p = subprocess.Popen(pargs)
|
||||
yield
|
||||
|
||||
@ -183,7 +183,8 @@ class TestExecution:
|
||||
# Initialize server and client
|
||||
#
|
||||
@fixture(scope="class", autouse=True, params=[
|
||||
{ "extra_args" : ["--cache-classic"], "should_cache_results" : True },
|
||||
{ "extra_args" : [], "should_cache_results" : True },
|
||||
{ "extra_args" : ["--cache-lru", 0], "should_cache_results" : True },
|
||||
{ "extra_args" : ["--cache-lru", 100], "should_cache_results" : True },
|
||||
{ "extra_args" : ["--cache-none"], "should_cache_results" : False },
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user