diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 76faed3ad..9d88c8517 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -110,13 +110,11 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.") -CACHE_RAM_AUTO_GB = -1.0 - cache_group = parser.add_mutually_exclusive_group() +cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metavar="GB", help="Use RAM pressure caching with the specified headroom thresholds. This is the default caching mode. The first value sets the active-cache threshold; the optional second value sets the inactive-cache/pin threshold. Defaults when no values are provided: active 25%% of system RAM (min 4GB, max 32GB), inactive 75%% of system RAM (min 12GB, max 96GB).") cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") -cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") @@ -245,6 +243,9 @@ if comfy.options.args_parsing: else: args = parser.parse_args([]) +if args.cache_ram is not None and len(args.cache_ram) > 2: + parser.error("--cache-ram accepts at most two values: active GB and inactive GB") + if args.windows_standalone_build: args.auto_launch = True diff --git a/comfy/lora.py b/comfy/lora.py index f11e26ec9..c0e8b865c 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -484,16 +484,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori return weight -def prefetch_prepared_value(value, allocate_buffer, stream): +def prefetch_prepared_value(value, counter, destination, stream, copy): if isinstance(value, torch.Tensor): - dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value)) - comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) + size = comfy.memory_management.vram_aligned_size(value) + offset = counter[0] + counter[0] += size + if destination is None: + return value + + dest = destination[offset:offset + size] + if copy: + comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) return comfy.memory_management.interpret_gathered_like([value], dest)[0] elif isinstance(value, weight_adapter.WeightAdapterBase): - return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream)) + return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy)) elif isinstance(value, tuple): - return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value) + return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value) elif isinstance(value, list): - return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value] + return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value] return value diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 48e3c11da..c43f0c4a2 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -15,7 +15,7 @@ class TensorFileSlice(NamedTuple): size: int -def read_tensor_file_slice_into(tensor, destination): +def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None): if isinstance(tensor, QuantizedTensor): if not isinstance(destination, QuantizedTensor): @@ -23,12 +23,17 @@ def read_tensor_file_slice_into(tensor, destination): if tensor._layout_cls != destination._layout_cls: return False - if not read_tensor_file_slice_into(tensor._qdata, destination._qdata): + if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream, + destination2=(destination2._qdata if destination2 is not None else None)): return False dst_orig_dtype = destination._params.orig_dtype destination._params.copy_from(tensor._params, non_blocking=False) destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype) + if destination2 is not None: + dst_orig_dtype = destination2._params.orig_dtype + destination2._params.copy_from(destination._params, non_blocking=True) + destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype) return True info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None) @@ -48,6 +53,17 @@ def read_tensor_file_slice_into(tensor, destination): if info.size == 0: return True + hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None) + if hostbuf is not None: + stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0 + device_ptr = destination2.data_ptr() if destination2 is not None else 0 + hostbuf.read_file_slice(file_obj, info.offset, info.size, + offset=destination.data_ptr() - hostbuf.get_raw_address(), + stream=stream_ptr, + device_ptr=device_ptr, + device=None if destination2 is None else destination2.device.index) + return True + buf_type = ctypes.c_ubyte * info.size view = memoryview(buf_type.from_address(destination.data_ptr())) @@ -151,7 +167,7 @@ def set_ram_cache_release_state(callback, headroom): extra_ram_release_callback = callback RAM_CACHE_HEADROOM = max(0, int(headroom)) -def extra_ram_release(target): +def extra_ram_release(target, free_active=False): if extra_ram_release_callback is None: return 0 - return extra_ram_release_callback(target) + return extra_ram_release_callback(target, free_active=free_active) diff --git a/comfy/model_management.py b/comfy/model_management.py index 21738a4c7..3894dfa9c 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,6 +31,7 @@ from contextlib import nullcontext import comfy.memory_management import comfy.utils import comfy.quant_ops +import comfy_aimdo.host_buffer import comfy_aimdo.vram_buffer class VRAMState(Enum): @@ -495,6 +496,14 @@ except: current_loaded_models = [] +DIRTY_MMAPS = set() + +PIN_PRESSURE_HYSTERESIS = 256 * 1024 * 1024 + +#Freeing registerables on pressure does imply a GPU sync, so go big on +#the hysteresis so each expensive sync gives us back a good chunk. +REGISTERABLE_PIN_HYSTERESIS = 2048 * 1024 * 1024 + def module_size(module): module_mem = 0 sd = module.state_dict() @@ -503,27 +512,46 @@ def module_size(module): module_mem += t.nbytes return module_mem -def module_mmap_residency(module, free=False): - mmap_touched_mem = 0 - module_mem = 0 - bounced_mmaps = set() - sd = module.state_dict() - for k in sd: - t = sd[k] - module_mem += t.nbytes - storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage() - if not getattr(storage, "_comfy_tensor_mmap_touched", False): - continue - mmap_touched_mem += t.nbytes - if not free: - continue - storage._comfy_tensor_mmap_touched = False - mmap_obj = storage._comfy_tensor_mmap_refs[0] - if mmap_obj in bounced_mmaps: - continue - mmap_obj.bounce() - bounced_mmaps.add(mmap_obj) - return mmap_touched_mem, module_mem +def mark_mmap_dirty(storage): + mmap_refs = getattr(storage, "_comfy_tensor_mmap_refs", None) + if mmap_refs is not None: + DIRTY_MMAPS.add(mmap_refs[0]) + +def free_pins(size, evict_active=False): + freed_total = 0 + for loaded_model in reversed(current_loaded_models): + if size <= 0: + return freed_total + model = loaded_model.model + if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]): + freed = model.partially_unload_ram(size) + freed_total += freed + size -= freed + return freed_total + +def ensure_pin_budget(size, evict_active=False): + shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available + if shortfall <= 0: + return True + + to_free = shortfall + PIN_PRESSURE_HYSTERESIS + return free_pins(to_free, evict_active=evict_active) >= shortfall + +def ensure_pin_registerable(size, evict_active=False): + shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY + if MAX_PINNED_MEMORY <= 0: + return False + if shortfall <= 0: + return True + + shortfall += REGISTERABLE_PIN_HYSTERESIS + for loaded_model in reversed(current_loaded_models): + model = loaded_model.model + if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]): + shortfall -= model.unregister_inactive_pins(shortfall) + if shortfall <= 0: + return True + return shortfall <= REGISTERABLE_PIN_HYSTERESIS class LoadedModel: def __init__(self, model): @@ -553,9 +581,6 @@ class LoadedModel: def model_memory(self): return self.model.model_size() - def model_mmap_residency(self, free=False): - return self.model.model_mmap_residency(free=free) - def model_loaded_memory(self): return self.model.loaded_size() @@ -635,15 +660,9 @@ WINDOWS = any(platform.win32_ver()) EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 if WINDOWS: - import comfy.windows EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards EXTRA_RESERVED_VRAM += 100 * 1024 * 1024 - def get_free_ram(): - return comfy.windows.get_free_ram() -else: - def get_free_ram(): - return psutil.virtual_memory().available if args.reserve_vram is not None: EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024 @@ -657,7 +676,6 @@ def minimum_inference_memory(): def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0): cleanup_models_gc() - comfy.memory_management.extra_ram_release(max(pins_required, ram_required)) unloaded_model = [] can_unload = [] unloaded_models = [] @@ -673,11 +691,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins for x in can_unload_sorted: i = x[-1] memory_to_free = 1e32 - pins_to_free = 1e32 - if not DISABLE_SMART_MEMORY or device is None: + if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None): memory_to_free = 0 if device is None else memory_required - get_free_memory(device) - pins_to_free = pins_required - get_free_ram() - if current_loaded_models[i].model.is_dynamic() and for_dynamic: + if for_dynamic: #don't actually unload dynamic models for the sake of other dynamic models #as that works on-demand. memory_required -= current_loaded_models[i].model.loaded_size() @@ -685,18 +701,6 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") unloaded_model.append(i) - if pins_to_free > 0: - logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}") - current_loaded_models[i].model.partially_unload_ram(pins_to_free) - - for x in can_unload_sorted: - i = x[-1] - ram_to_free = ram_required - psutil.virtual_memory().available - if ram_to_free <= 0 and i not in unloaded_model: - continue - resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True) - if resident_memory > 0: - logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}") for i in sorted(unloaded_model, reverse=True): unloaded_models.append(current_loaded_models.pop(i)) @@ -762,29 +766,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu model_to_unload.model.detach(unpatch_all=False) model_to_unload.model_finalizer.detach() - total_memory_required = {} - total_pins_required = {} - total_ram_required = {} for loaded_model in models_to_load: device = loaded_model.device total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device) - resident_memory, model_memory = loaded_model.model_mmap_residency() - pinned_memory = loaded_model.model.pinned_memory_size() - #FIXME: This can over-free the pins as it budgets to pin the entire model. We should - #make this JIT to keep as much pinned as possible. - pins_required = model_memory - pinned_memory - ram_required = model_memory - resident_memory - total_pins_required[device] = total_pins_required.get(device, 0) + pins_required - total_ram_required[device] = total_ram_required.get(device, 0) + ram_required for device in total_memory_required: if device != torch.device("cpu"): free_memory(total_memory_required[device] * 1.1 + extra_mem, device, - for_dynamic=free_for_dynamic, - pins_required=total_pins_required[device], - ram_required=total_ram_required[device]) + for_dynamic=free_for_dynamic) for device in total_memory_required: if device != torch.device("cpu"): @@ -1180,6 +1171,7 @@ STREAM_CAST_BUFFERS = {} LARGEST_CASTED_WEIGHT = (None, 0) STREAM_AIMDO_CAST_BUFFERS = {} LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) +STREAM_PIN_BUFFERS = {} DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 @@ -1220,21 +1212,66 @@ def get_aimdo_cast_buffer(offload_stream, device): if cast_buffer is None: cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index) STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer - return cast_buffer + +def get_pin_buffer(offload_stream): + pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None) + if pin_buffer is None: + pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3)) + STREAM_PIN_BUFFERS[offload_stream] = pin_buffer + elif offload_stream is not None: + event = getattr(pin_buffer, "_comfy_event", None) + if event is not None: + event.synchronize() + delattr(pin_buffer, "_comfy_event") + return pin_buffer + +def resize_pin_buffer(pin_buffer, size): + global TOTAL_PINNED_MEMORY + old_size = pin_buffer.size + if size <= old_size: + return True + growth = size - old_size + comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) + ensure_pin_budget(growth, evict_active=True) + ensure_pin_registerable(growth, evict_active=True) + try: + pin_buffer.extend(size=size, reallocate=True) + except RuntimeError: + return False + TOTAL_PINNED_MEMORY += pin_buffer.size - old_size + return True + def reset_cast_buffers(): + global TOTAL_PINNED_MEMORY global LARGEST_CASTED_WEIGHT global LARGEST_AIMDO_CASTED_WEIGHT LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) - for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS): + for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS): if offload_stream is not None: offload_stream.synchronize() synchronize() + for mmap_obj in DIRTY_MMAPS: + mmap_obj.bounce() + DIRTY_MMAPS.clear() + + for pin_buffer in STREAM_PIN_BUFFERS.values(): + TOTAL_PINNED_MEMORY -= pin_buffer.size + TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY) + + for loaded_model in current_loaded_models: + model = loaded_model.model + if model is not None and model.is_dynamic(): + model.model.dynamic_pins[model.load_device]["active"] = False + model.partially_unload_ram(1e30, subsets=[ "patches" ]) + model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0]) + STREAM_CAST_BUFFERS.clear() STREAM_AIMDO_CAST_BUFFERS.clear() + STREAM_PIN_BUFFERS.clear() soft_empty_cache() def get_offload_stream(device): @@ -1280,7 +1317,7 @@ def sync_stream(device, stream): current_stream(device).wait_stream(stream) -def cast_to_gathered(tensors, r, non_blocking=False, stream=None): +def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None): wf_context = nullcontext() if stream is not None: wf_context = stream @@ -1288,17 +1325,20 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None): wf_context = wf_context.as_context(stream) dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) + dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None with wf_context: for tensor in tensors: dest_view = dest_views.pop(0) + dest2_view = dest2_views.pop(0) if dest2_views is not None else None if tensor is None: continue - if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view): + if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view, stream=stream, destination2=dest2_view): continue storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage() - if hasattr(storage, "_comfy_tensor_mmap_touched"): - storage._comfy_tensor_mmap_touched = True + mark_mmap_dirty(storage) dest_view.copy_(tensor, non_blocking=non_blocking) + if dest2_view is not None: + dest2_view.copy_(dest_view, non_blocking=non_blocking) def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): @@ -1339,14 +1379,18 @@ TOTAL_PINNED_MEMORY = 0 MAX_PINNED_MEMORY = -1 if not args.disable_pinned_memory: if is_nvidia() or is_amd(): + ram = get_total_memory(torch.device("cpu")) if WINDOWS: - MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.40 # Windows limit is apparently 50% + MAX_PINNED_MEMORY = ram * 0.40 # Windows limit is apparently 50% else: - MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.90 + MAX_PINNED_MEMORY = ram * 0.90 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"]) +def pinned_hostbuf_size(size): + return max(0, int(min(size, MAX_PINNED_MEMORY) * 2)) + def discard_cuda_async_error(): try: a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) @@ -1378,8 +1422,8 @@ def pin_memory(tensor): return False size = tensor.nbytes - if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: - return False + comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) + ensure_pin_registerable(size) ptr = tensor.data_ptr() if ptr == 0: @@ -1416,7 +1460,8 @@ def unpin_memory(tensor): return False if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: - TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) + size = PINNED_MEMORY.pop(ptr) + TOTAL_PINNED_MEMORY -= size return True else: logging.warning("Unpin error.") diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4f9d8403e..c8ed02e70 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -35,6 +35,7 @@ import comfy.model_management import comfy.ops import comfy.patcher_extension import comfy.utils +import comfy_aimdo.host_buffer from comfy.comfy_types import UnetWrapperFunction from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP @@ -117,6 +118,8 @@ def string_to_seed(data): return comfy.utils.string_to_seed(data) class LowVramPatch: + is_lowvram_patch = True + def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key self.patches = patches @@ -124,11 +127,21 @@ class LowVramPatch: self.set_func = set_func self.prepared_patches = None - def prepare(self, allocate_buffer, stream): - self.prepared_patches = [ - (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4]) + def memory_required(self): + counter = [0] + for patch in self.patches[self.key]: + comfy.lora.prefetch_prepared_value(patch[1], counter, None, None, False) + return counter[0] + + def prepare(self, destination, stream, copy=True, commit=True): + counter = [0] + prepared_patches = [ + (patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream, copy), patch[2], patch[3], patch[4]) for patch in self.patches[self.key] ] + if commit: + self.prepared_patches = prepared_patches + return prepared_patches def clear_prepared(self): self.prepared_patches = None @@ -341,9 +354,6 @@ class ModelPatcher: self.size = comfy.model_management.module_size(self.model) return self.size - def model_mmap_residency(self, free=False): - return comfy.model_management.module_mmap_residency(self.model, free=free) - def loaded_size(self): return self.model.model_loaded_weight_memory @@ -1118,8 +1128,12 @@ class ModelPatcher: # Pinned memory pressure tracking is only implemented for DynamicVram loading return 0 + def loaded_ram_size(self): + # Loaded RAM pressure tracking is only implemented for DynamicVram loading + return 0 + def partially_unload_ram(self, ram_to_unload): - pass + return 0 def detach(self, unpatch_all=True): self.eject_model() @@ -1550,6 +1564,16 @@ class ModelPatcherDynamic(ModelPatcher): super().__init__(model, load_device, offload_device, size, weight_inplace_update) if not hasattr(self.model, "dynamic_vbars"): self.model.dynamic_vbars = {} + if not hasattr(self.model, "dynamic_pins"): + self.model.dynamic_pins = {} + if self.load_device not in self.model.dynamic_pins: + self.model.dynamic_pins[self.load_device] = { + "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), + "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), + "hostbufs_initialized": False, + "failed": False, + "active": False, + } self.non_dynamic_delegate_model = None assert load_device is not None @@ -1611,6 +1635,14 @@ class ModelPatcherDynamic(ModelPatcher): self.unpatch_hooks() vbar = self._vbar_get(create=True) + pin_state = self.model.dynamic_pins[self.load_device] + if not pin_state["hostbufs_initialized"]: + hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size()) + pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0]) + pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0]) + pin_state["hostbufs_initialized"] = True + pin_state["failed"] = False + pin_state["active"] = True if vbar is not None: vbar.prioritize() @@ -1636,7 +1668,9 @@ class ModelPatcherDynamic(ModelPatcher): if key in self.patches: if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape: return (True, 0) - setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) + lowvram_patch = LowVramPatch(key, self.patches) + lowvram_patch._pin_state = pin_state + setattr(m, param_key + "_lowvram_function", lowvram_patch) num_patches += 1 else: setattr(m, param_key + "_lowvram_function", None) @@ -1653,6 +1687,9 @@ class ModelPatcherDynamic(ModelPatcher): def force_load_param(self, param_key, device_to): key = key_param_name_to_key(n, param_key) + weight, _, _ = get_key_weight(self.model, key) + if weight is None: + return if key in self.backup: comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) self.patch_weight_to_device(key, device_to=device_to, force_cast=True) @@ -1662,17 +1699,23 @@ class ModelPatcherDynamic(ModelPatcher): if hasattr(m, "comfy_cast_weights"): m.comfy_cast_weights = True - m.pin_failed = False m.seed_key = n + m._pin_state = pin_state set_dirty(m, dirty) - force_load, v_weight_size = setup_param(self, m, n, "weight") - force_load_bias, v_weight_bias = setup_param(self, m, n, "bias") - force_load = force_load or force_load_bias - v_weight_size += v_weight_bias + #Models that mix tiny and giant weights can causing lopsided stream buffer + #rotations and stall. force the tinys over. + if module_mem > 16 * 1024: + force_load, v_weight_size = setup_param(self, m, n, "weight") + force_load_bias, v_weight_bias = setup_param(self, m, n, "bias") + force_load = force_load or force_load_bias + v_weight_size += v_weight_bias + if force_load: + logging.info(f"Module {n} has resizing Lora - force loading") + else: + force_load=True if force_load: - logging.info(f"Module {n} has resizing Lora - force loading") force_load_param(self, "weight", device_to) force_load_param(self, "bias", device_to) else: @@ -1740,23 +1783,58 @@ class ModelPatcherDynamic(ModelPatcher): return freed - def pinned_memory_size(self): - total = 0 - loading = self._load_list(for_dynamic=True) - for x in loading: - _, _, _, _, m, _ = x - pin = comfy.pinned_memory.get_pin(m) - if pin is not None: - total += pin.numel() * pin.element_size() - return total + def loaded_ram_size(self): + return (self.model.dynamic_pins[self.load_device]["weights"][0].size + + self.model.dynamic_pins[self.load_device]["patches"][0].size) - def partially_unload_ram(self, ram_to_unload): - loading = self._load_list(for_dynamic=True, default_device=self.offload_device) - for x in loading: - *_, m, _ = x - ram_to_unload -= comfy.pinned_memory.unpin_memory(m) - if ram_to_unload <= 0: - return + def pinned_memory_size(self): + return (self.model.dynamic_pins[self.load_device]["weights"][3][0] + + self.model.dynamic_pins[self.load_device]["patches"][3][0]) + + def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]): + freed = 0 + pin_state = self.model.dynamic_pins[self.load_device] + for subset in subsets: + hostbuf, stack, stack_split, pinned_size = pin_state[subset] + split = stack_split[0] + while split >= 0: + module, offset = stack[split] + split -= 1 + stack_split[0] = split + if not module._pin_registered: + continue + size = module._pin.numel() * module._pin.element_size() + if torch.cuda.cudart().cudaHostUnregister(module._pin.data_ptr()) != 0: + comfy.model_management.discard_cuda_async_error() + continue + module._pin_registered = False + comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size) + pinned_size[0] = max(0, pinned_size[0] - size) + freed += size + ram_to_unload -= size + if ram_to_unload <= 0: + return freed + return freed + + def partially_unload_ram(self, ram_to_unload, subsets=[ "weights", "patches" ]): + freed = 0 + pin_state = self.model.dynamic_pins[self.load_device] + for subset in subsets: + hostbuf, stack, stack_split, pinned_size = pin_state[subset] + while len(stack) > 0: + module, offset = stack.pop() + size = module._pin.numel() * module._pin.element_size() + del module._pin + hostbuf.truncate(offset, do_unregister=module._pin_registered) + stack_split[0] = min(stack_split[0], len(stack) - 1) + if module._pin_registered: + comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size) + pinned_size[0] = max(0, pinned_size[0] - size) + freed += size + ram_to_unload -= size + if ram_to_unload <= 0: + return freed + return freed def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): #This isn't used by the core at all and can only be to load a model out of diff --git a/comfy/ops.py b/comfy/ops.py index eae3bd873..9bcd6c900 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -75,6 +75,8 @@ except: cast_to = comfy.model_management.cast_to #TODO: remove once no more references +STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024 + def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) @@ -91,6 +93,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin offload_stream = None cast_buffer = None cast_buffer_offset = 0 + stream_pin_hostbuf = None + stream_pin_offset = 0 + stream_pin_queue = [] def ensure_offload_stream(module, required_size, check_largest): nonlocal offload_stream @@ -124,6 +129,22 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin cast_buffer_offset += buffer_size return buffer + def get_stream_pin_buffer_offset(buffer_size): + nonlocal stream_pin_hostbuf + nonlocal stream_pin_offset + + if buffer_size == 0 or offload_stream is None: + return None + + if stream_pin_hostbuf is None: + stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream) + if stream_pin_hostbuf is None: + return None + + offset = stream_pin_offset + stream_pin_offset += buffer_size + return offset + for s in comfy_modules: signature = comfy_aimdo.model_vbar.vbar_fault(s._v) resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) @@ -162,23 +183,47 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin if xfer_dest is None: xfer_dest = get_cast_buffer(dest_size) - if signature is None and pin is None: - comfy.pinned_memory.pin_memory(s) - pin = comfy.pinned_memory.get_pin(s) - else: - pin = None + def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream): + if xfer_source is not None: + if getattr(xfer_source, "is_lowvram_patch", False): + xfer_source.prepare(xfer_dest, stream, copy=True, commit=False) + else: + comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream) - if pin is not None: - comfy.model_management.cast_to_gathered(xfer_source, pin) - xfer_source = [ pin ] - #send it over - comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + def handle_pin(m, pin, source, dest, subset="weights", size=None): + if pin is not None: + cast_maybe_lowvram_patch([pin], dest, offload_stream) + return + if signature is None: + comfy.pinned_memory.pin_memory(m, subset=subset, size=size) + pin = comfy.pinned_memory.get_pin(m, subset=subset) + if pin is not None: + if isinstance(source, list): + comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest) + else: + cast_maybe_lowvram_patch(source, pin, None) + cast_maybe_lowvram_patch([ pin ], dest, offload_stream) + return + if pin is None: + pin_offset = get_stream_pin_buffer_offset(size) + if pin_offset is not None: + stream_pin_queue.append((source, pin_offset, size, dest)) + return + cast_maybe_lowvram_patch(source, dest, offload_stream) + + handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size) for param_key in ("weight", "bias"): - lowvram_fn = getattr(s, param_key + "_lowvram_function", None) - if lowvram_fn is not None: + lowvram_source = getattr(s, param_key + "_lowvram_function", None) + if lowvram_source is not None: ensure_offload_stream(s, cast_buffer_offset, False) - lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream) + lowvram_size = lowvram_source.memory_required() + lowvram_dest = get_cast_buffer(lowvram_size) + lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True) + + pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches") + handle_pin(lowvram_source, pin, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size) + prefetch["xfer_dest"] = xfer_dest prefetch["cast_dest"] = cast_dest @@ -186,6 +231,23 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin prefetch["needs_cast"] = needs_cast s._prefetch = prefetch + if stream_pin_offset > 0: + if stream_pin_hostbuf.size < stream_pin_offset: + if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM): + for xfer_source, _, _, xfer_dest in stream_pin_queue: + cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream) + return offload_stream + stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf) + stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf + for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue: + pin = stream_pin_tensor[pin_offset:pin_offset + pin_size] + if isinstance(xfer_source, list): + comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest) + else: + cast_maybe_lowvram_patch(xfer_source, pin, None) + comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream) + stream_pin_hostbuf._comfy_event = offload_stream.record_event() + return offload_stream diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index 6d3ba367a..0e8f573ba 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -2,42 +2,62 @@ import comfy.model_management import comfy.memory_management import comfy_aimdo.host_buffer import comfy_aimdo.torch +import torch from comfy.cli_args import args -def get_pin(module): - return getattr(module, "_pin", None) +def get_pin(module, subset="weights"): + pin = getattr(module, "_pin", None) + if pin is None or module._pin_registered or args.disable_pinned_memory: + return pin -def pin_memory(module): - if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: + _, _, stack_split, pinned_size = module._pin_state[subset] + size = pin.nbytes + comfy.model_management.ensure_pin_registerable(size) + + if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0: + comfy.model_management.discard_cuda_async_error() + return pin + + module._pin_registered = True + stack_split[0] = max(stack_split[0], module._pin_stack_index) + comfy.model_management.TOTAL_PINNED_MEMORY += size + pinned_size[0] += size + return pin + +def pin_memory(module, subset="weights", size=None): + pin_state = module._pin_state + if args.disable_pinned_memory: return - size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) + pin = get_pin(module, subset) + if pin is not None or pin_state["failed"]: + return - if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY: - module.pin_failed = True + hostbuf, stack, stack_split, pinned_size = pin_state[subset] + if size is None: + size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) + offset = hostbuf.size + registerable_size = size + max(0, hostbuf.size - pinned_size[0]) + + comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) + if (not comfy.model_management.ensure_pin_budget(size) or + not comfy.model_management.ensure_pin_registerable(registerable_size)): + pin_state["failed"] = True return False try: - hostbuf = comfy_aimdo.host_buffer.HostBuffer(size) + hostbuf.extend(size=size) except RuntimeError: - module.pin_failed = True + pin_state["failed"] = True return False - module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf) - module._pin_hostbuf = hostbuf + module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size] + module._pin.untyped_storage()._comfy_hostbuf = hostbuf + stack.append((module, offset)) + module._pin_registered = True + module._pin_stack_index = len(stack) - 1 + stack_split[0] = max(stack_split[0], module._pin_stack_index) comfy.model_management.TOTAL_PINNED_MEMORY += size + pinned_size[0] += size return True - -def unpin_memory(module): - if get_pin(module) is None: - return 0 - size = module._pin.numel() * module._pin.element_size() - - comfy.model_management.TOTAL_PINNED_MEMORY -= size - if comfy.model_management.TOTAL_PINNED_MEMORY < 0: - comfy.model_management.TOTAL_PINNED_MEMORY = 0 - - del module._pin - del module._pin_hostbuf - return size diff --git a/comfy/utils.py b/comfy/utils.py index 66682690a..00e382fac 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -113,7 +113,6 @@ def load_safetensors(ckpt): "_comfy_tensor_file_slice", comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start)) setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv)) - setattr(storage, "_comfy_tensor_mmap_touched", False) sd[name] = tensor return sd, header.get("__metadata__", {}), @@ -1451,4 +1450,3 @@ def deepcopy_list_dict(obj, memo=None): memo[obj_id] = res return res - diff --git a/comfy/windows.py b/comfy/windows.py deleted file mode 100644 index 213dc481d..000000000 --- a/comfy/windows.py +++ /dev/null @@ -1,52 +0,0 @@ -import ctypes -import logging -import psutil -from ctypes import wintypes - -import comfy_aimdo.control - -psapi = ctypes.WinDLL("psapi") -kernel32 = ctypes.WinDLL("kernel32") - -class PERFORMANCE_INFORMATION(ctypes.Structure): - _fields_ = [ - ("cb", wintypes.DWORD), - ("CommitTotal", ctypes.c_size_t), - ("CommitLimit", ctypes.c_size_t), - ("CommitPeak", ctypes.c_size_t), - ("PhysicalTotal", ctypes.c_size_t), - ("PhysicalAvailable", ctypes.c_size_t), - ("SystemCache", ctypes.c_size_t), - ("KernelTotal", ctypes.c_size_t), - ("KernelPaged", ctypes.c_size_t), - ("KernelNonpaged", ctypes.c_size_t), - ("PageSize", ctypes.c_size_t), - ("HandleCount", wintypes.DWORD), - ("ProcessCount", wintypes.DWORD), - ("ThreadCount", wintypes.DWORD), - ] - -def get_free_ram(): - #Windows is way too conservative and chalks recently used uncommitted model RAM - #as "in-use". So, calculate free RAM for the sake of general use as the greater of: - # - #1: What psutil says - #2: Total Memory - (Committed Memory - VRAM in use) - # - #We have to subtract VRAM in use from the comitted memory as WDDM creates a naked - #commit charge for all VRAM used just incase it wants to page it all out. This just - #isn't realistic so "overcommit" on our calculations by just subtracting it off. - - pi = PERFORMANCE_INFORMATION() - pi.cb = ctypes.sizeof(pi) - - if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb): - logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal") - return psutil.virtual_memory().available - - committed = pi.CommitTotal * pi.PageSize - total = pi.PhysicalTotal * pi.PageSize - - return max(psutil.virtual_memory().available, - total - (committed - comfy_aimdo.control.get_total_vram_usage())) - diff --git a/comfy_extras/mediapipe/face_geometry.py b/comfy_extras/mediapipe/face_geometry.py new file mode 100644 index 000000000..04b2b0557 --- /dev/null +++ b/comfy_extras/mediapipe/face_geometry.py @@ -0,0 +1,111 @@ +"""Pure-numpy port of MediaPipe's face_geometry (FACE_LANDMARK_PIPELINE mode) ++ weighted Procrustes solver. Computes the 4x4 facial transformation matrix. +""" + +from __future__ import annotations + +import math +import numpy as np + + +def _solve_weighted_orthogonal_problem(src: np.ndarray, tgt: np.ndarray, weights: np.ndarray) -> np.ndarray: + """Weighted orthogonal Procrustes (similarity). Returns 4x4 M with + `target ≈ M @ homogeneous(source)` in the weighted LS sense. fp64 for + SVD stability. Port of procrustes_solver.cc.""" + sqrt_w = np.sqrt(weights.astype(np.float64)) + w_total = float((sqrt_w ** 2).sum()) + ws = src.astype(np.float64) * sqrt_w + wt = tgt.astype(np.float64) * sqrt_w + + c_w = (ws @ sqrt_w) / w_total + centered = ws - np.outer(c_w, sqrt_w) + U, _S, Vt = np.linalg.svd(wt @ centered.T, full_matrices=True) + # Disallow reflection: flip the least-significant axis when det(U)·det(V)<0. + post, pre = U.copy(), Vt.T.copy() + if np.linalg.det(post) * np.linalg.det(pre) < 0: + post[:, 2] *= -1.0 + R = post @ pre.T + + denom = float((centered * ws).sum()) + if denom < 1e-12: + raise ValueError("Procrustes denominator collapsed (degenerate source).") + scale = float((R @ centered * wt).sum()) / denom + translation = ((wt - scale * (R @ ws)) @ sqrt_w) / w_total + + M = np.eye(4, dtype=np.float64) + M[:3, :3] = scale * R + M[:3, 3] = translation + return M + + +def _estimate_scale(canonical: np.ndarray, runtime: np.ndarray, weights: np.ndarray) -> float: + """scale = ‖first column of M[:3]‖ per geometry_pipeline.cc::EstimateScale.""" + return float(np.linalg.norm(_solve_weighted_orthogonal_problem(canonical, runtime, weights)[:3, 0])) + + +def solve_facial_transformation_matrix( + landmarks_normalized: np.ndarray, + canonical_vertices: np.ndarray, + procrustes_indices: np.ndarray, + procrustes_weights: np.ndarray, + image_width: int, + image_height: int, + # face_geometry_calculator_options.pbtxt defaults + vertical_fov_degrees: float = 63.0, + near: float = 1.0, +) -> np.ndarray: + """4x4 facial transformation matrix via two-pass scale recovery + `landmarks_normalized` is (N, 3) in MediaPipe normalized convention: x, y + in [0,1] with TOP-LEFT origin, z in width-scaled units. + """ + + h_near = 2.0 * near * math.tan(0.5 * math.radians(vertical_fov_degrees)) + w_near = image_width * h_near / image_height + + sub = procrustes_indices.astype(np.int64) + screen = landmarks_normalized[sub].T.astype(np.float64).copy() + canon = canonical_vertices[sub].T.astype(np.float64).copy() + weights = procrustes_weights.astype(np.float64) + + # ProjectXY (TOP_LEFT y-flip, then scale all 3 axes; z uses x-scale). + screen[1] = 1.0 - screen[1] + screen[0] = screen[0] * w_near - 0.5 * w_near + screen[1] = screen[1] * h_near - 0.5 * h_near + screen[2] = screen[2] * w_near + depth_offset = float(screen[2].mean()) + + def _unproject(s: np.ndarray, scale: float) -> np.ndarray: + s = s.copy() + s[2] = (s[2] - depth_offset + near) / scale + s[0] *= s[2] / near + s[1] *= s[2] / near + s[2] *= -1.0 + return s + + first = screen.copy() + first[2] *= -1.0 + s1 = _estimate_scale(canon, first, weights) # 1st pass: Procrustes on projected XY + s2 = _estimate_scale(canon, _unproject(screen, s1), weights) # 2nd pass: rescale z by s1, un-project XY + return _solve_weighted_orthogonal_problem(canon, _unproject(screen, s1 * s2), weights).astype(np.float32) + + +def transformation_matrix_from_detection(face_dict: dict, image_width: int, image_height: int, canonical_data: dict) -> np.ndarray: + """Adapt a FaceLandmarker face dict to MP's normalized convention and solve. + FaceMesh emits (x, y, z) in 192-canonical units; MP's geometry expects + z_norm = z_canonical * scale_x / image_width""" + + lmks_xy, lmks_3d = face_dict["landmarks_xy"], face_dict["landmarks_3d"] + aug = np.concatenate([lmks_3d[:, :2].astype(np.float64), np.ones((lmks_xy.shape[0], 1))], axis=1) + M, *_ = np.linalg.lstsq(aug, lmks_xy.astype(np.float64), rcond=None) + scale_x = float(np.linalg.norm(M[0])) + z_scale = scale_x / image_width if scale_x > 1e-6 else 1.0 / image_width + + normalized = np.empty((lmks_xy.shape[0], 3), dtype=np.float32) + normalized[:, 0] = lmks_xy[:, 0] / image_width + normalized[:, 1] = lmks_xy[:, 1] / image_height + normalized[:, 2] = lmks_3d[:, 2] * z_scale + return solve_facial_transformation_matrix( + normalized, canonical_data["canonical_vertices"], + canonical_data["procrustes_indices"], canonical_data["procrustes_weights"], + image_width=image_width, image_height=image_height, + ) diff --git a/comfy_extras/mediapipe/face_landmarker.py b/comfy_extras/mediapipe/face_landmarker.py new file mode 100644 index 000000000..a792b6046 --- /dev/null +++ b/comfy_extras/mediapipe/face_landmarker.py @@ -0,0 +1,682 @@ +"""Pure-PyTorch port of MediaPipe's face_landmarker_v2_with_blendshapes.task: +BlazeFace detector → FaceMesh v2 → ARKit-52 blendshapes.""" + +from __future__ import annotations + +import math +from functools import lru_cache +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.special import expit +from torch import Tensor, nn + + +# Values below must stay verbatim with the published face_landmarker_v2 graph + +# face_blendshapes_graph.cc::kLandmarksSubsetIdxs +_BS_INPUT_INDICES: Tuple[int, ...] = ( + 0, 1, 4, 5, 6, 7, 8, 10, 13, 14, 17, 21, 33, 37, 39, 40, 46, 52, 53, 54, + 55, 58, 61, 63, 65, 66, 67, 70, 78, 80, 81, 82, 84, 87, 88, 91, 93, 95, + 103, 105, 107, 109, 127, 132, 133, 136, 144, 145, 146, 148, 149, 150, 152, + 153, 154, 155, 157, 158, 159, 160, 161, 162, 163, 168, 172, 173, 176, 178, + 181, 185, 191, 195, 197, 234, 246, 249, 251, 263, 267, 269, 270, 276, 282, + 283, 284, 285, 288, 291, 293, 295, 296, 297, 300, 308, 310, 311, 312, 314, + 317, 318, 321, 323, 324, 332, 334, 336, 338, 356, 361, 362, 365, 373, 374, + 375, 377, 378, 379, 380, 381, 382, 384, 385, 386, 387, 388, 389, 390, 397, + 398, 400, 402, 405, 409, 415, 454, 466, 468, 469, 470, 471, 472, 473, 474, + 475, 476, 477, +) + +# face_blendshapes_graph.cc::kCategoryNames +BLENDSHAPE_NAMES: Tuple[str, ...] = ( + "_neutral", "browDownLeft", "browDownRight", "browInnerUp", "browOuterUpLeft", + "browOuterUpRight", "cheekPuff", "cheekSquintLeft", "cheekSquintRight", + "eyeBlinkLeft", "eyeBlinkRight", "eyeLookDownLeft", "eyeLookDownRight", + "eyeLookInLeft", "eyeLookInRight", "eyeLookOutLeft", "eyeLookOutRight", + "eyeLookUpLeft", "eyeLookUpRight", "eyeSquintLeft", "eyeSquintRight", + "eyeWideLeft", "eyeWideRight", "jawForward", "jawLeft", "jawOpen", + "jawRight", "mouthClose", "mouthDimpleLeft", "mouthDimpleRight", + "mouthFrownLeft", "mouthFrownRight", "mouthFunnel", "mouthLeft", + "mouthLowerDownLeft", "mouthLowerDownRight", "mouthPressLeft", + "mouthPressRight", "mouthPucker", "mouthRight", "mouthRollLower", + "mouthRollUpper", "mouthShrugLower", "mouthShrugUpper", "mouthSmileLeft", + "mouthSmileRight", "mouthStretchLeft", "mouthStretchRight", + "mouthUpperUpLeft", "mouthUpperUpRight", "noseSneerLeft", "noseSneerRight", +) + +# face_detection.pbtxt — short-range BlazeFace. +_BF_NUM_LAYERS = 4 +_BF_INPUT_SIZE = 128 +_BF_STRIDES = (8, 16, 16, 16) +_BF_ANCHOR_OFFSET_X = 0.5 +_BF_ANCHOR_OFFSET_Y = 0.5 +_BF_ASPECT_RATIOS = (1.0,) +_BF_INTERP_SCALE_AR = 1.0 +_BF_BOX_SCALE = 128.0 +_BF_KP_OFFSET = 4 +_BF_SCORE_CLIP = 100.0 +_BF_MIN_SCORE = 0.5 + +# face_detection_full_range.pbtxt — 48x48 grid at stride 4, 1 anchor/cell. +_BF_FR_INPUT_SIZE = 192 +_BF_FR_GRID = 48 +_BF_FR_NUM_ANCHORS = _BF_FR_GRID * _BF_FR_GRID +_BF_FR_BOX_SCALE = 192.0 +_BF_FR_SCORE_CLIP = 100.0 + +_FM_INPUT_SIZE = 192 + +# Face ROI: 1.5xbbox rect warped anisotropically into 192x192. +_FACE_LEFT_EYE_KP = 0 +_FACE_RIGHT_EYE_KP = 1 +_FACE_ROI_SCALE_X = 1.5 +_FACE_ROI_SCALE_Y = 1.5 +_FACE_ROI_TARGET_ANGLE = 0.0 + + +def _tf_same_pad(x: Tensor, kernel: int, stride: int) -> Tensor: + """TF SAME pad (asymmetric on stride-2; PyTorch's symmetric pad undershoots by 1 px).""" + H, W = x.shape[-2], x.shape[-1] + pad_h = max(((H + stride - 1) // stride - 1) * stride + kernel - H, 0) + pad_w = max(((W + stride - 1) // stride - 1) * stride + kernel - W, 0) + if pad_h == 0 and pad_w == 0: + return x + return F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) + + +# BlazeFace short-range: stem 5x5/s2 → 16 BlazeBlocks → parallel heads at +# 16²x88 (2 anchors/cell) and 8²x96 (6/cell) = 896 anchors. (in, out, stride): +_BLAZEFACE_BLOCKS = [ + (24, 24, 1), (24, 28, 1), (28, 32, 2), (32, 36, 1), + (36, 42, 1), (42, 48, 2), (48, 56, 1), (56, 64, 1), + (64, 72, 1), (72, 80, 1), (80, 88, 1), (88, 96, 2), + (96, 96, 1), (96, 96, 1), (96, 96, 1), (96, 96, 1), +] + + +class BlazeFaceBlock(nn.Module): + """DW 3x3 + PW + residual. Residual max-pools on stride>1, channel-pads on out_ch>in_ch.""" + + def __init__(self, in_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + self.in_ch, self.out_ch, self.stride = in_ch, out_ch, stride + self.depthwise = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, device=device, dtype=dtype) + self.pointwise = ops.Conv2d(in_ch, out_ch, 1, padding=0, bias=True, device=device, dtype=dtype) + + def forward(self, x: Tensor) -> Tensor: + residual = F.max_pool2d(x, 2, 2) if self.stride > 1 else x + if self.out_ch > self.in_ch: + residual = F.pad(residual, (0, 0, 0, 0, 0, self.out_ch - self.in_ch)) + x = _tf_same_pad(x, 3, self.stride) if self.stride > 1 else F.pad(x, (1, 1, 1, 1)) + return F.relu(self.pointwise(self.depthwise(x)) + residual) + + +class BlazeFace(nn.Module): + """Short-range BlazeFace: (B, 3, 128, 128) in [-1, 1] → 896 anchors x 17.""" + + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + self.stem = ops.Conv2d(3, 24, 5, stride=2, padding=0, bias=True, **kw) + self.blocks = nn.ModuleList(BlazeFaceBlock(i, o, s, device=device, dtype=dtype, operations=operations) + for (i, o, s) in _BLAZEFACE_BLOCKS) + # 16²x2 + 8²x6 = 512 + 384 = 896 anchors. + self.cls_16 = ops.Conv2d(88, 2, 1, padding=0, bias=True, **kw) + self.cls_8 = ops.Conv2d(96, 6, 1, padding=0, bias=True, **kw) + self.reg_16 = ops.Conv2d(88, 32, 1, padding=0, bias=True, **kw) + self.reg_8 = ops.Conv2d(96, 96, 1, padding=0, bias=True, **kw) + + def forward(self, image_chw_normalized: Tensor) -> tuple[Tensor, Tensor]: + x = F.relu(self.stem(_tf_same_pad(image_chw_normalized, 5, 2))) + # 16x16 tap is block-10 output (before the 88→96 stride-2 in block 11). + for i in range(11): + x = self.blocks[i](x) + feat_16 = x + for i in range(11, 16): + x = self.blocks[i](x) + feat_8 = x + + def flat(t, a, k): # NHWC flatten → (B, H*W*A, K) + B, _, H, W = t.shape + return t.permute(0, 2, 3, 1).reshape(B, H * W * a, k) + + cls = torch.cat([flat(self.cls_16(feat_16), 2, 1), flat(self.cls_8(feat_8), 6, 1)], dim=1) + reg = torch.cat([flat(self.reg_16(feat_16), 2, 16), flat(self.reg_8(feat_8), 6, 16)], dim=1) + return reg, cls + + +# BlazeFace full-range (face_detection_full_range_sparse.tflite): MobileNetV2-ish +# backbone + top-down FPN, 192² input → 2304 anchors at the 48x48 grid. +class FRBlock(nn.Module): + """Double inverted residual: DW → PW(mid) → DW → PW(out) [+ residual]. + + Per source tflite: dw* have no fused activation, pw1 is always ReLU, pw2 + is ReLU only when no residual (else ReLU fuses into the ADD). + """ + + def __init__(self, in_ch: int, mid_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + self.has_residual = (in_ch == out_ch and stride == 1) + self.dw1 = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, **kw) + self.pw1 = ops.Conv2d(in_ch, mid_ch, 1, padding=0, bias=True, **kw) + self.dw2 = ops.Conv2d(mid_ch, mid_ch, 3, stride=1, padding=0, groups=mid_ch, bias=True, **kw) + self.pw2 = ops.Conv2d(mid_ch, out_ch, 1, padding=0, bias=True, **kw) + + def forward(self, x: Tensor) -> Tensor: + residual = x if self.has_residual else None + x = F.relu(self.pw1(self.dw1(F.pad(x, (1, 1, 1, 1))))) + x = self.pw2(self.dw2(F.pad(x, (1, 1, 1, 1)))) + return F.relu(x + residual) if residual is not None else F.relu(x) + + +# (in_ch, mid_ch, out_ch, stride). Stages downsample 96²x32 → 48²x64 → 24²x128 +# → 12²x192 → 6²x384. Lateral taps at indices 4, 7, 10 (see _FR_LATERAL_*). +_FR_BACKBONE_BLOCKS = [ + (32, 8, 32, 1), (32, 8, 32, 1), # 96²x32 + (32, 16, 64, 2), (64, 16, 64, 1), (64, 16, 64, 1), # 48²x64 — tap[0] + (64, 32, 128, 2), (128, 32, 128, 1), (128, 32, 128, 1), # 24²x128 — tap[1] + (128, 48, 192, 2), (192, 48, 192, 1), (192, 48, 192, 1), # 12²x192 — tap[2] + (192, 96, 384, 2), (384, 96, 384, 1), (384, 96, 384, 1), (384, 96, 384, 1), # 6²x384 +] +_FR_LATERAL_TAP_INDICES = (4, 7, 10) +_FR_LATERAL_CHANNELS = ((64, 48), (128, 64), (192, 96)) # (in, out) per side-conv + +# Decoder blocks per FPN level (after upsample-and-merge with the lateral). +_FR_DECODER_BLOCKS = [ + [(96, 48, 96, 1), (96, 48, 96, 1)], # 12²x96 + [(64, 32, 64, 1), (64, 32, 64, 1)], # 24²x64 + [(48, 24, 48, 1)], # 48²x48 — feeds the heads +] + + +def _dcr_depth_to_space(t: Tensor, r: int, c_out: int) -> Tensor: + """TF DEPTH_TO_SPACE in DCR layout (input channels = (i, j, c_out)). + pixel_shuffle uses CRD which permutes output channels for c_out > 1.""" + B_, _, H_, W_ = t.shape + t = t.reshape(B_, r, r, c_out, H_, W_) + t = t.permute(0, 3, 4, 1, 5, 2).contiguous() + return t.reshape(B_, c_out, H_ * r, W_ * r) + + +class BlazeFaceFullRange(nn.Module): + """Full-range face detector: (B, 3, 192, 192) in [-1, 1] → 2304 anchors x 17 values.""" + + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + mk_block = lambda i, m, o, s: FRBlock(i, m, o, s, device=device, dtype=dtype, operations=operations) + self.stem = ops.Conv2d(3, 32, 3, stride=2, padding=0, bias=True, **kw) + self.backbone = nn.ModuleList(mk_block(i, m, o, s) for (i, m, o, s) in _FR_BACKBONE_BLOCKS) + self.lateral_convs = nn.ModuleList(ops.Conv2d(i, o, 1, padding=0, bias=True, **kw) for (i, o) in _FR_LATERAL_CHANNELS) + self.top_conv = ops.Conv2d(384, 96, 1, padding=0, bias=True, **kw) + self.decoder_levels = nn.ModuleList( + nn.ModuleList(mk_block(i, m, o, s) for (i, m, o, s) in lvl) for lvl in _FR_DECODER_BLOCKS + ) + # 96→64 before 12→24, 64→48 before 24→48. + self.decoder_reduce_convs = nn.ModuleList([ + ops.Conv2d(96, 64, 1, padding=0, bias=True, **kw), + ops.Conv2d(64, 48, 1, padding=0, bias=True, **kw), + ]) + # Heads mix 2x2-cell info via DW-stride-2 + depth_to_space block_size=2. + self.cls_conv = ops.Conv2d(48, 4, 1, padding=0, bias=True, **kw) + self.cls_dw = ops.Conv2d(4, 4, 3, stride=2, padding=0, groups=4, bias=True, **kw) + self.reg_conv = ops.Conv2d(48, 64, 1, padding=0, bias=True, **kw) + self.reg_dw = ops.Conv2d(64, 64, 3, stride=2, padding=0, groups=64, bias=True, **kw) + + def forward(self, image_chw_normalized: Tensor) -> tuple[Tensor, Tensor]: + # Symmetric pad-1 throughout (full-range tflite uses explicit TF PAD, not SAME). + x = F.relu(self.stem(F.pad(image_chw_normalized, (1, 1, 1, 1)))) + tap_set = set(_FR_LATERAL_TAP_INDICES) + laterals: list[Tensor] = [] + for i, blk in enumerate(self.backbone): + x = blk(x) + if i in tap_set: + laterals.append(x) + + # top_conv / lateral_convs / decoder_reduce_convs all have fused ReLU in the tflite. + p = F.relu(self.top_conv(x)) + laterals_rev = list(reversed(laterals)) + lateral_convs_rev = list(reversed(self.lateral_convs)) + for level in range(len(self.decoder_levels)): + lateral = laterals_rev[level] + p = F.interpolate(p, size=lateral.shape[-2:], mode="bilinear", align_corners=False) + p = p + F.relu(lateral_convs_rev[level](lateral)) + for blk in self.decoder_levels[level]: + p = blk(p) + if level < len(self.decoder_reduce_convs): + p = F.relu(self.decoder_reduce_convs[level](p)) + + c = self.cls_dw(F.pad(self.cls_conv(p), (1, 1, 1, 1))) + c = _dcr_depth_to_space(c, r=2, c_out=1) + r = self.reg_dw(F.pad(self.reg_conv(p), (1, 1, 1, 1))) + r = _dcr_depth_to_space(r, r=2, c_out=16) + B = c.shape[0] + cls_out = c.permute(0, 2, 3, 1).reshape(B, _BF_FR_NUM_ANCHORS, 1) + reg_out = r.permute(0, 2, 3, 1).reshape(B, _BF_FR_NUM_ANCHORS, 16) + return reg_out, cls_out + + +@lru_cache(maxsize=1) +def _blazeface_full_range_anchors() -> np.ndarray: + """2304 anchors over 48x48; anchor_w=anchor_h=1 (fixed_anchor_size).""" + feat = _BF_FR_GRID + yy, xx = np.meshgrid(np.arange(feat, dtype=np.float32), np.arange(feat, dtype=np.float32), indexing="ij") + cx, cy, ones = (xx + 0.5) / feat, (yy + 0.5) / feat, np.ones_like(xx) + return np.stack([cx, cy, ones, ones], axis=-1).reshape(_BF_FR_NUM_ANCHORS, 4) + + +def _decode_blazeface_full_range(regressors: np.ndarray, classificators: np.ndarray, + score_thresh: float = _BF_MIN_SCORE) -> np.ndarray: + """Same decode as short-range with 2304-anchor grid and box_scale=192.""" + scores = expit(np.clip(classificators[:, 0], -_BF_FR_SCORE_CLIP, _BF_FR_SCORE_CLIP)) + keep = scores >= score_thresh + if not keep.any(): + return np.empty((0, 17), dtype=np.float32) + r = regressors[keep] / _BF_FR_BOX_SCALE + a = _blazeface_full_range_anchors()[keep] + cxs, cys, aws, ahs = a[:, 0:1], a[:, 1:2], a[:, 2:3], a[:, 3:4] + xc, yc = r[:, 0:1] * aws + cxs, r[:, 1:2] * ahs + cys + w, h = r[:, 2:3] * aws, r[:, 3:4] * ahs + out = np.empty((r.shape[0], 17), dtype=np.float32) + out[:, 0:1], out[:, 1:2], out[:, 2:3], out[:, 3:4] = xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2 + out[:, 4:16:2] = r[:, _BF_KP_OFFSET::2] * aws + cxs + out[:, 5:16:2] = r[:, _BF_KP_OFFSET + 1::2] * ahs + cys + out[:, 16] = scores[keep] + return out + + +# FaceMesh (face_landmarks_detector.tflite): PReLU variant of BlazeBlock, +# 17 blocks, heads for 478x3 landmarks + presence. +_FACEMESH_BLOCKS = [ # (in_ch, out_ch, stride) + (16, 16, 1), (16, 16, 1), (16, 32, 2), (32, 32, 1), (32, 32, 1), (32, 64, 2), + (64, 64, 1), (64, 64, 1), (64, 128, 2), (128, 128, 1), (128, 128, 1), (128, 128, 2), + (128, 128, 1), (128, 128, 1), (128, 128, 2), (128, 128, 1), (128, 128, 1), +] + + +class FaceMeshBlock(nn.Module): + """PReLU BlazeBlock: PReLU between DW and PW, and after the residual add.""" + + def __init__(self, in_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + self.in_ch, self.out_ch, self.stride = in_ch, out_ch, stride + self.depthwise = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, **kw) + self.prelu_dwise = nn.PReLU(num_parameters=in_ch, **kw) + self.pointwise = ops.Conv2d(in_ch, out_ch, 1, padding=0, bias=True, **kw) + self.prelu_out = nn.PReLU(num_parameters=out_ch, **kw) + + def forward(self, x: Tensor) -> Tensor: + residual = F.max_pool2d(x, 2, 2) if self.stride > 1 else x + if self.out_ch > self.in_ch: + residual = F.pad(residual, (0, 0, 0, 0, 0, self.out_ch - self.in_ch)) + x = _tf_same_pad(x, 3, self.stride) if self.stride > 1 else F.pad(x, (1, 1, 1, 1)) + return self.prelu_out(self.pointwise(self.prelu_dwise(self.depthwise(x))) + residual) + + +class FaceMesh(nn.Module): + NUM_LANDMARKS = 478 + + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + self.stem = ops.Conv2d(3, 16, 3, stride=2, padding=0, bias=True, **kw) + self.prelu_stem = nn.PReLU(num_parameters=16, **kw) + self.blocks = nn.ModuleList(FaceMeshBlock(i, o, s, device=device, dtype=dtype, operations=operations) + for (i, o, s) in _FACEMESH_BLOCKS) + self.head_reduce = ops.Conv2d(128, 8, 1, padding=0, bias=True, **kw) + self.prelu_head_reduce = nn.PReLU(num_parameters=8, **kw) + self.head_block = FaceMeshBlock(8, 8, 1, device=device, dtype=dtype, operations=operations) + self.head_presence = ops.Conv2d(8, 1, 3, padding=0, bias=True, **kw) + self.head_landmarks = ops.Conv2d(8, self.NUM_LANDMARKS * 3, 3, padding=0, bias=True, **kw) + + def forward(self, face_chw_normalized: Tensor) -> tuple[Tensor, Tensor]: + """(B, 3, 192, 192) in [0, 1] → ((B, 478, 3) landmarks in 192-canonical, (B,) presence).""" + x = self.prelu_stem(self.stem(_tf_same_pad(face_chw_normalized, 3, 2))) + for blk in self.blocks: + x = blk(x) + x = self.prelu_head_reduce(self.head_reduce(x)) + x = self.head_block(x) + B = x.shape[0] + presence = self.head_presence(x).reshape(B) + lmks = self.head_landmarks(x).reshape(B, self.NUM_LANDMARKS, 3) + return lmks, presence + + +# FaceBlendshapes (MLP-Mixer "GhumMarkerPoserMlpMixerGeneral"): +# 146x2 → token-reduce 146→96 → embed 2→64 → +cls token → 4x mixer → cls→52. +_BS_NUM_INPUT_LANDMARKS = 146 +_BS_NUM_TOKENS_REDUCED = 96 +_BS_NUM_TOKENS = 97 # +1 cls +_BS_TOKEN_DIM = 64 +_BS_TOKEN_MIX_HIDDEN = 384 +_BS_CHANNEL_MIX_HIDDEN = 256 +_BS_NUM_BLENDSHAPES = 52 +_BS_LN_EPS = 1e-6 + + +class MlpMixerBlock(nn.Module): + """MLP-Mixer block: token-mixing MLP (over tokens) → channel-mixing MLP (over dim). + Both pre-LN, both residual. LN has no beta (bias=False) to match MP.""" + + def __init__(self, num_tokens: int, token_dim: int, token_hidden: int, channel_hidden: int, + device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + # bias=False → no LN beta (matches MP). + self.ln1 = ops.LayerNorm(token_dim, eps=_BS_LN_EPS, bias=False, **kw) + self.ln2 = ops.LayerNorm(token_dim, eps=_BS_LN_EPS, bias=False, **kw) + self.token_mlp1 = ops.Linear(num_tokens, token_hidden, bias=True, **kw) + self.token_mlp2 = ops.Linear(token_hidden, num_tokens, bias=True, **kw) + self.channel_mlp1 = ops.Linear(token_dim, channel_hidden, bias=True, **kw) + self.channel_mlp2 = ops.Linear(channel_hidden, token_dim, bias=True, **kw) + + def forward(self, x: Tensor) -> Tensor: + y = self.ln1(x).transpose(1, 2) + x = x + self.token_mlp2(F.relu(self.token_mlp1(y))).transpose(1, 2) + return x + self.channel_mlp2(F.relu(self.channel_mlp1(self.ln2(x)))) + + +class FaceBlendshapes(nn.Module): + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + self.token_reduce = ops.Linear(_BS_NUM_INPUT_LANDMARKS, _BS_NUM_TOKENS_REDUCED, bias=True, **kw) + self.token_embed = ops.Linear(2, _BS_TOKEN_DIM, bias=True, **kw) + self.cls_token = nn.Parameter(torch.zeros(1, 1, _BS_TOKEN_DIM, **kw)) + self.blocks = nn.ModuleList( + MlpMixerBlock(_BS_NUM_TOKENS, _BS_TOKEN_DIM, _BS_TOKEN_MIX_HIDDEN, _BS_CHANNEL_MIX_HIDDEN, + device=device, dtype=dtype, operations=operations) for _ in range(4) + ) + self.head = ops.Linear(_BS_TOKEN_DIM, _BS_NUM_BLENDSHAPES, bias=True, **kw) + + @staticmethod + def _input_normalize(landmarks_2d: Tensor) -> Tensor: + # Centroid-subtract → L2 scale → x0.5. The 0.5 is baked into training. + centroid = landmarks_2d.mean(dim=1, keepdim=True) + x = landmarks_2d - centroid + mag = torch.sqrt((x * x).sum(dim=-1, keepdim=True)) + scale = mag.mean(dim=1, keepdim=True) + return (x / scale.clamp(min=1e-12)) * 0.5 + + def forward(self, landmarks_2d: Tensor) -> Tensor: + """(B, 146, 2) → (B, 52) in [0, 1]. Input units don't matter (centroid + L2 normalize).""" + x = self._input_normalize(landmarks_2d) + x = self.token_reduce(x.transpose(1, 2)).transpose(1, 2) + x = self.token_embed(x) + cls = self.cls_token.expand(x.shape[0], -1, -1) + x = torch.cat([cls, x], dim=1) + for blk in self.blocks: + x = blk(x) + return torch.sigmoid(self.head(x[:, 0])) + + +@lru_cache(maxsize=1) +def _blazeface_anchors() -> np.ndarray: + """896 anchors per SsdAnchorsCalculator (fixed_anchor_size → anchor_w=anchor_h=1).""" + per_ar = len(_BF_ASPECT_RATIOS) + (1 if _BF_INTERP_SCALE_AR > 0 else 0) + layer_anchors: List[np.ndarray] = [] + layer = 0 + while layer < _BF_NUM_LAYERS: + stride = _BF_STRIDES[layer] + last = layer + while last < _BF_NUM_LAYERS and _BF_STRIDES[last] == stride: + last += 1 + per_cell = per_ar * (last - layer) + feat = (_BF_INPUT_SIZE + stride - 1) // stride + yy, xx = np.meshgrid(np.arange(feat, dtype=np.float32), np.arange(feat, dtype=np.float32), indexing="ij") + cx, cy, ones = (xx + _BF_ANCHOR_OFFSET_X) / feat, (yy + _BF_ANCHOR_OFFSET_Y) / feat, np.ones_like(xx) + cell = np.stack([cx, cy, ones, ones], axis=-1).reshape(-1, 4) + layer_anchors.append(np.repeat(cell, per_cell, axis=0)) + layer = last + out = np.concatenate(layer_anchors, axis=0) + assert out.shape == (896, 4), out.shape + return out + + +def _decode_blazeface(regressors: np.ndarray, classificators: np.ndarray, + score_thresh: float = _BF_MIN_SCORE) -> np.ndarray: + """Decode (regs (896,16), cls (896,1)) → (N, 17) = [xyxy, kp0x..kp5y, score] in [0, 1].""" + scores = expit(np.clip(classificators[:, 0], -_BF_SCORE_CLIP, _BF_SCORE_CLIP)) + keep = scores >= score_thresh + if not keep.any(): + return np.empty((0, 17), dtype=np.float32) + r = regressors[keep] / _BF_BOX_SCALE + a = _blazeface_anchors()[keep] # (N, 4) cx, cy, 1, 1 + cxs, cys, aws, ahs = a[:, 0:1], a[:, 1:2], a[:, 2:3], a[:, 3:4] + xc, yc = r[:, 0:1] * aws + cxs, r[:, 1:2] * ahs + cys + w, h = r[:, 2:3] * aws, r[:, 3:4] * ahs + out = np.empty((r.shape[0], 17), dtype=np.float32) + out[:, 0:1], out[:, 1:2], out[:, 2:3], out[:, 3:4] = xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2 + out[:, 4:16:2] = r[:, _BF_KP_OFFSET::2] * aws + cxs + out[:, 5:16:2] = r[:, _BF_KP_OFFSET + 1::2] * ahs + cys + out[:, 16] = scores[keep] + return out + + +def _weighted_nms(detections: np.ndarray, iou_thresh: float = 0.5) -> np.ndarray: + """MP weighted NMS — kept boxes are score-weighted averages of overlapping detections.""" + if detections.shape[0] == 0: + return detections + dets = detections[np.argsort(-detections[:, 16])] + N = dets.shape[0] + areas = np.clip(dets[:, 2] - dets[:, 0], 0, None) * np.clip(dets[:, 3] - dets[:, 1], 0, None) + kept: List[np.ndarray] = [] + used = np.zeros(N, dtype=bool) + for i in range(N): + if used[i]: + continue + ax1, ay1, ax2, ay2 = dets[i, 0:4] + merge_idx = [i] + for j in range(i + 1, N): + if used[j]: + continue + bx1, by1, bx2, by2 = dets[j, 0:4] + iw = max(0.0, min(ax2, bx2) - max(ax1, bx1)) + ih = max(0.0, min(ay2, by2) - max(ay1, by1)) + inter = iw * ih + union = areas[i] + areas[j] - inter + if union > 0 and inter / union > iou_thresh: # strict > matches MP + merge_idx.append(j) + used[j] = True + used[i] = True + cluster = dets[merge_idx] + ws = cluster[:, 16:17] + ws_sum = ws.sum() + merged = np.copy(cluster[0]) + if ws_sum > 0: + merged[:16] = (cluster[:, :16] * ws).sum(axis=0) / ws_sum + kept.append(merged) + return np.stack(kept, axis=0) if kept else np.empty((0, 17), dtype=np.float32) + + +def _detection_to_face_rect(detection: np.ndarray, image_w: int, image_h: int) -> Tuple[float, float, float, float, float]: + """Detection (normalized) → rotated 1.5xbbox ROI in image pixels (anisotropic).""" + xmin, ymin, xmax, ymax = detection[0:4] + lx = detection[4 + _FACE_LEFT_EYE_KP * 2 + 0] * image_w + ly = detection[4 + _FACE_LEFT_EYE_KP * 2 + 1] * image_h + rx = detection[4 + _FACE_RIGHT_EYE_KP * 2 + 0] * image_w + ry = detection[4 + _FACE_RIGHT_EYE_KP * 2 + 1] * image_h + # Image-y-down convention: angle = target - atan2(-dy, dx). + angle = _FACE_ROI_TARGET_ANGLE - math.atan2(ly - ry, rx - lx) + return (float((xmin + xmax) * 0.5 * image_w), + float((ymin + ymax) * 0.5 * image_h), + float((xmax - xmin) * image_w * _FACE_ROI_SCALE_X), + float((ymax - ymin) * image_h * _FACE_ROI_SCALE_Y), + float(angle)) + + +def _sample_warp(image_chw: Tensor, src_x: Tensor, src_y: Tensor, padding_mode: str) -> Tensor: + """Bilinear-sample image_chw at corner-aligned (src_x, src_y).""" + H, W = int(image_chw.shape[-2]), int(image_chw.shape[-1]) + grid = torch.stack([(2.0 * src_x + 1.0) / W - 1.0, + (2.0 * src_y + 1.0) / H - 1.0], dim=-1).unsqueeze(0) + return F.grid_sample(image_chw.unsqueeze(0), grid, mode="bilinear", + align_corners=False, padding_mode=padding_mode).squeeze(0) + + +def _warp_face_crop(image_chw: Tensor, cx: float, cy: float, width: float, height: float, + angle: float, output_size: int = _FM_INPUT_SIZE) -> Tensor: + """Rotated rect → output_size² with BORDER_REPLICATE. image_chw must be in [0, 1].""" + s_x, s_y = width / output_size, height / output_size + cos_a, sin_a = math.cos(angle), math.sin(angle) + arange = torch.arange(output_size, dtype=image_chw.dtype, device=image_chw.device) - output_size * 0.5 + v_grid, u_grid = torch.meshgrid(arange, arange, indexing="ij") + src_x = cx + u_grid * s_x * cos_a - v_grid * s_y * sin_a + src_y = cy + u_grid * s_x * sin_a + v_grid * s_y * cos_a + return _sample_warp(image_chw, src_x, src_y, "border") + + +def _blazeface_input_warp(image_chw_raw: Tensor, target: int = _BF_INPUT_SIZE) -> Tuple[Tensor, float, float, float]: + """Centered max(W,H) square → target² with BORDER_ZERO + [-1, 1] norm. + + Sub-pixel grid_sample matters; integer-pad-then-resize drifts the bbox ~5%. + Returns (warped, sub_rect_cx, sub_rect_cy, sub_rect_size) — the triplet maps + tensor-normalized [0,1] detections back to image pixels. + """ + H, W = int(image_chw_raw.shape[1]), int(image_chw_raw.shape[2]) + sub_rect_size = float(max(W, H)) + sub_rect_cx, sub_rect_cy = W * 0.5, H * 0.5 + s = sub_rect_size / target + arange = torch.arange(target, dtype=image_chw_raw.dtype, device=image_chw_raw.device) - target * 0.5 + v_grid, u_grid = torch.meshgrid(arange, arange, indexing="ij") + out = _sample_warp(image_chw_raw, sub_rect_cx + u_grid * s, sub_rect_cy + v_grid * s, "zeros") + return (out / 127.5) - 1.0, sub_rect_cx, sub_rect_cy, sub_rect_size + + +class FaceLandmarker(nn.Module): + """BlazeFace → FaceMesh v2 → blendshapes. `detector_variant` selects 'short' + (128², ≤2m) or 'full' (192² FPN, ≤5m). State dict uses inner-module prefixes + `detector.*` / `mesh.*` / `blendshapes.*`; the outer FaceLandmarkerModel + wrapper rewrites `detector_{variant}.*` keys to `detector.*` before loading. + """ + + def __init__(self, device=None, dtype=None, operations=None, detector_variant: str = "short"): + super().__init__() + det_cls = {"short": BlazeFace, "full": BlazeFaceFullRange}.get(detector_variant) + + self.detector_variant = detector_variant + self.detector = det_cls(device=device, dtype=dtype, operations=operations) + self.mesh = FaceMesh(device=device, dtype=dtype, operations=operations) + self.blendshapes = FaceBlendshapes(device=device, dtype=dtype, operations=operations) + self.register_buffer("_bs_idx", torch.tensor(_BS_INPUT_INDICES, dtype=torch.long), persistent=False) + + def run_detector_batch(self, images_rgb_uint8: List[np.ndarray], + score_thresh: float = _BF_MIN_SCORE, + iou_thresh: float = 0.5): + """Batched detector pass. Returns (img_raws, sub_rects, sizes, per_frame_decoded) + where per_frame_decoded[b] is (N, 17) in tensor-normalized [0,1] coords.""" + if not images_rgb_uint8: + return [], [], [], [] + device, dtype = self.detector.stem.weight.device, self.detector.stem.weight.dtype + det_input_size, decode_fn = ((_BF_FR_INPUT_SIZE, _decode_blazeface_full_range) + if self.detector_variant == "full" + else (_BF_INPUT_SIZE, _decode_blazeface)) + + # Same-size frames: stack once and transfer once. Variable size falls back + # to per-image (only triggers for SAM3DBody's head crops). + sizes = [tuple(img.shape[:2]) for img in images_rgb_uint8] + if len(set(sizes)) == 1: + batch_chw = torch.from_numpy(np.stack(images_rgb_uint8, axis=0)).to(device, dtype).movedim(-1, -3).contiguous() + img_raws = [batch_chw[bi] for bi in range(batch_chw.shape[0])] + else: + img_raws = [torch.from_numpy(img).to(device, dtype).movedim(-1, -3).contiguous() for img in images_rgb_uint8] + + warps = [_blazeface_input_warp(img_raw, det_input_size) for img_raw in img_raws] + det_crops = [w[0] for w in warps] + sub_rects = [(w[1], w[2], w[3]) for w in warps] + + regs_b, cls_b = self.detector(torch.stack(det_crops, dim=0)) + regs_np, cls_np = regs_b.float().cpu().numpy(), cls_b.float().cpu().numpy() + per_frame = [] + for b in range(len(images_rgb_uint8)): + decoded = decode_fn(regs_np[b], cls_np[b], score_thresh=score_thresh) + per_frame.append(_weighted_nms(decoded, iou_thresh=iou_thresh) if decoded.shape[0] > 0 else decoded) + return img_raws, sub_rects, sizes, per_frame + + def detect_batch(self, images_rgb_uint8: List[np.ndarray], num_faces: int = 1, + score_thresh: float = _BF_MIN_SCORE) -> List[List[dict]]: + """Full pipeline batched across `images_rgb_uint8`. Returns one face-dict + list per image (empty if nothing detected). Face dict: + bbox_xyxy (4,) image pixels, blendshapes {52} ∈ [0,1], + landmarks_xy (478, 2) image pixels, landmarks_3d (478, 3) in + 192-canonical (pre-transformation) units, presence float (raw logit). + """ + img_raws, sub_rects, sizes, per_frame_dets = self.run_detector_batch( + images_rgb_uint8, score_thresh=score_thresh, + ) + # tensor-normalized → image-normalized [0,1] for _detection_to_face_rect. + for b, decoded in enumerate(per_frame_dets): + if decoded.shape[0] == 0: + continue + cx, cy, size = sub_rects[b] + H, W = sizes[b] + sx0, sy0 = cx - size * 0.5, cy - size * 0.5 + decoded[:, 0:16:2] = (sx0 + size * decoded[:, 0:16:2]) / W + decoded[:, 1:16:2] = (sy0 + size * decoded[:, 1:16:2]) / H + if num_faces > 0: + per_frame_dets[b] = decoded[: int(num_faces)] + + # Collect every detected face across all frames into one mesh input. + face_params: List[Tuple[int, float, float, float, float, float, float]] = [] + mesh_crops: List[Tensor] = [] + for b, dets in enumerate(per_frame_dets): + if dets.shape[0] == 0: + continue + H, W = sizes[b] + img_for_mesh = img_raws[b] / 255.0 + for det in dets: + cx, cy, w, h, angle = _detection_to_face_rect(det, W, H) + mesh_crops.append(_warp_face_crop(img_for_mesh, cx, cy, w, h, angle, _FM_INPUT_SIZE)) + face_params.append((b, float(det[16]), cx, cy, w, h, angle)) + + results: List[List[dict]] = [[] for _ in range(len(images_rgb_uint8))] + if not mesh_crops: + return results + + lmks_canon_b, presence_b = self.mesh(torch.stack(mesh_crops, dim=0)) + bs_out_b = self.blendshapes(lmks_canon_b[:, self._bs_idx, :2]) + + # Batched canonical→image affine + params_t = torch.tensor( + [(cx, cy, w, h, math.cos(a), math.sin(a)) for (_b, _s, cx, cy, w, h, a) in face_params], + device=lmks_canon_b.device, dtype=lmks_canon_b.dtype, + ) + cxs, cys, ws, hs, cos_a, sin_a = params_t.unbind(dim=1) + inv = 1.0 / _FM_INPUT_SIZE + u = lmks_canon_b[..., 0] - _FM_INPUT_SIZE * 0.5 + v = lmks_canon_b[..., 1] - _FM_INPUT_SIZE * 0.5 + lmks_xy_t = torch.stack([ + cxs[:, None] + u * (ws * inv * cos_a)[:, None] - v * (hs * inv * sin_a)[:, None], + cys[:, None] + u * (ws * inv * sin_a)[:, None] + v * (hs * inv * cos_a)[:, None], + ], dim=-1) + + lmks_xy_np = lmks_xy_t.float().cpu().numpy() + lmks_canon_np = lmks_canon_b.float().cpu().numpy() + presence_np = presence_b.float().cpu().numpy() + bs_np = bs_out_b.float().cpu().numpy() + + for i, (b, score, *_) in enumerate(face_params): + lmks_xy = lmks_xy_np[i] + mn, mx = lmks_xy.min(0), lmks_xy.max(0) + results[b].append({ + "bbox_xyxy": np.array([mn[0], mn[1], mx[0], mx[1]], dtype=np.float32), + "blendshapes": dict(zip(BLENDSHAPE_NAMES, bs_np[i].tolist())), + "landmarks_xy": lmks_xy, + "landmarks_3d": lmks_canon_np[i], + "presence": float(presence_np[i]), + "score": score, + }) + return results diff --git a/comfy_extras/nodes_mediapipe.py b/comfy_extras/nodes_mediapipe.py new file mode 100644 index 000000000..2e67ae83f --- /dev/null +++ b/comfy_extras/nodes_mediapipe.py @@ -0,0 +1,502 @@ +"""ComfyUI nodes for the pure-PyTorch MediaPipe Face Landmarker port. + +Custom IO types: + FACE_LANDMARKER — FaceLandmarkerModel wrapper (ModelPatcher inside) + FACE_LANDMARKS — {"frames": List[List[face_dict]], "image_size": (H, W), + "connection_sets": dict[str, frozenset[(int, int)]]} + face_dict: bbox_xyxy, blendshapes, landmarks_xy, + landmarks_3d, presence, score, transformation_matrix + +MediaPipeFaceLandmarker also emits the core BOUNDING_BOX type — pair with DrawBBoxes. +""" + +from __future__ import annotations + +import numpy as np +import torch +from PIL import Image, ImageColor, ImageDraw +from tqdm.auto import tqdm +from typing_extensions import override + +import comfy.model_management +import comfy.model_patcher +import comfy.utils +import folder_paths +from comfy_api.latest import ComfyExtension, io + +from comfy_extras.mediapipe.face_landmarker import FaceLandmarker +from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection + + +FaceLandmarkerType = io.Custom("FACE_LANDMARKER") +FaceLandmarksType = io.Custom("FACE_LANDMARKS") + +_CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights") +_CONTOUR_PARTS = ("face_oval", "left_eye", "right_eye", "left_eyebrow", "right_eyebrow", "lips") + + +class FaceLandmarkerModel: + """Loaded FaceLandmarker variants + ModelPatcher per variant. + + Safetensors layout: `detector_short.*` / `detector_full.*` plus shared + `mesh.*`, `blendshapes.*`, `canonical_*`, and `topology.*`. + PReLU forces plain-nn / fp32 (manual_cast strands buffers across devices). + """ + + def __init__(self, state_dict: dict): + self.load_device = comfy.model_management.text_encoder_device() + offload_device = comfy.model_management.text_encoder_offload_device() + self.dtype = torch.float32 + + # FACEMESH_* connection sets, embedded as int32 (N, 2) under topology.*. + base: dict[str, frozenset] = {} + for k in [k for k in state_dict if k.startswith("topology.")]: + base[k[len("topology."):]] = frozenset(map(tuple, state_dict.pop(k).tolist())) + base["contours"] = frozenset().union(*(base[p] for p in _CONTOUR_PARTS)) + base["all"] = base["contours"] | base["irises"] | base["nose"] + + self.connection_sets: dict[str, frozenset] = base + self.canonical_data: dict[str, np.ndarray] = {k: state_dict.pop(k).numpy() for k in _CANONICAL_KEYS} + + shared = {k: v for k, v in state_dict.items() if k.startswith(("mesh.", "blendshapes."))} + + self.models: dict[str, FaceLandmarker] = {} + self.patchers: dict[str, comfy.model_patcher.ModelPatcher] = {} + for variant in ("short", "full"): + prefix = f"detector_{variant}." + sub = dict(shared) + sub.update({f"detector.{k[len(prefix):]}": v for k, v in state_dict.items() if k.startswith(prefix)}) + fl = FaceLandmarker(device=offload_device, dtype=self.dtype, operations=None, detector_variant=variant).eval() + fl.load_state_dict(sub, strict=False) + + self.models[variant] = fl + self.patchers[variant] = comfy.model_patcher.CoreModelPatcher( + fl, load_device=self.load_device, offload_device=offload_device, + size=comfy.model_management.module_size(fl), + ) + + def detect_batch(self, images, num_faces: int, score_thresh: float, variant: str): + comfy.model_management.load_model_gpu(self.patchers[variant]) + return self.models[variant].detect_batch(images, num_faces=num_faces, score_thresh=score_thresh) + + +def _image_to_uint8(image: torch.Tensor) -> np.ndarray: + return image[..., :3].mul(255.0).add_(0.5).clamp_(0, 255).to(torch.uint8).cpu().numpy() + + +def _parse_color(color: str) -> tuple[int, int, int]: + try: + return ImageColor.getrgb(color)[:3] + except ValueError: + return (0, 255, 0) + + +def _copy_face(face: dict) -> dict: + """Shallow copy of a face_dict with array-fields cloned so callers can mutate.""" + return { + "bbox_xyxy": face["bbox_xyxy"].copy(), + "blendshapes": dict(face["blendshapes"]), + "landmarks_xy": face["landmarks_xy"].copy(), + "landmarks_3d": face["landmarks_3d"].copy(), + "presence": face["presence"], + "score": face["score"], + } + + +def _lerp_face(a: dict, b: dict, t: float) -> dict: + return { + "bbox_xyxy": (1 - t) * a["bbox_xyxy"] + t * b["bbox_xyxy"], + "blendshapes": {k: (1 - t) * a["blendshapes"][k] + t * b["blendshapes"][k] for k in a["blendshapes"]}, + "landmarks_xy": (1 - t) * a["landmarks_xy"] + t * b["landmarks_xy"], + "landmarks_3d": (1 - t) * a["landmarks_3d"] + t * b["landmarks_3d"], + "presence": (1 - t) * a["presence"] + t * b["presence"], + "score": (1 - t) * a["score"] + t * b["score"], + } + + +def _match_faces(a: list[dict], b: list[dict]) -> list[tuple[int, int]]: + """Greedy nearest-neighbour pairing of faces between two frames by bbox + centre distance. Unmatched (when counts differ) are dropped.""" + if not a or not b: + return [] + centers_a = np.array([(0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]), + 0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3])) for f in a]) + centers_b = np.array([(0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]), + 0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3])) for f in b]) + dists = np.linalg.norm(centers_a[:, None] - centers_b[None], axis=-1) + pairs: list[tuple[int, int]] = [] + used_a: set[int] = set() + used_b: set[int] = set() + candidates = sorted((dists[ia, ib], ia, ib) for ia in range(len(a)) for ib in range(len(b))) + for _, ia, ib in candidates: + if ia in used_a or ib in used_b: + continue + pairs.append((ia, ib)) + used_a.add(ia) + used_b.add(ib) + return pairs + + +def _fill_missing_frames(frames: list[list[dict]], mode: str) -> None: + """In-place fill empty frame slots from neighbouring detections. Multi-face + aware: pairs faces across bracketing frames by greedy bbox-centre NN. + When counts differ, unmatched faces are dropped from the synthesised frame.""" + if mode == "empty": + return + valid = [i for i, fr in enumerate(frames) if fr] + if not valid: + return # nothing to fill from + if mode == "previous": + last: list[dict] = [] + for i, fr in enumerate(frames): + if fr: + last = fr + elif last: + frames[i] = [_copy_face(f) for f in last] + return + # interpolate: lerp between bracketing valid frames; clamp at ends. + for i in range(len(frames)): + if frames[i]: + continue + prev_i = max((v for v in valid if v < i), default=None) + next_i = min((v for v in valid if v > i), default=None) + if prev_i is None: + frames[i] = [_copy_face(f) for f in frames[next_i]] + elif next_i is None: + frames[i] = [_copy_face(f) for f in frames[prev_i]] + else: + t = (i - prev_i) / (next_i - prev_i) + pairs = _match_faces(frames[prev_i], frames[next_i]) + frames[i] = [_lerp_face(frames[prev_i][a], frames[next_i][b], t) for a, b in pairs] + + +def _ordered_rings(edges: frozenset[tuple[int, int]]) -> list[list[int]]: + """Walk an unordered edge set into one or more closed-loop vertex rings + (handles multi-loop sets like FACEMESH_LIPS: outer + inner).""" + adj: dict[int, set[int]] = {} + for a, b in edges: + adj.setdefault(a, set()).add(b) + adj.setdefault(b, set()).add(a) + visited: set[int] = set() + rings: list[list[int]] = [] + for start in adj: + if start in visited: + continue + ring = [start] + visited.add(start) + prev, cur = -1, start + while True: + nxt = next((v for v in adj[cur] if v != prev), None) + if nxt is None or nxt == start: + break + ring.append(nxt) + visited.add(nxt) + prev, cur = cur, nxt + rings.append(ring) + return rings + + +class LoadMediaPipeFaceLandmarker(io.ComfyNode): + """Load MediaPipe Face Landmarker v2 weights. Contains both detector variants + (short / full), shared mesh, blendshapes, and canonical geometry.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadMediaPipeFaceLandmarker", + display_name="Load MediaPipe Face Landmarker", + category="loaders", + inputs=[ + io.Combo.Input("model_name", options=folder_paths.get_filename_list("mediapipe"), + tooltip="Face Landmarker safetensors from models/mediapipe/."), + ], + outputs=[FaceLandmarkerType.Output()], + ) + + @classmethod + def execute(cls, model_name) -> io.NodeOutput: + sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("mediapipe", model_name), safe_load=True) + wrapper = FaceLandmarkerModel(sd) + return io.NodeOutput(wrapper) + + +# Per-frame fallback modes for detection failures in a batch. +_FALLBACK_MODES = ("empty", "previous", "interpolate") + + +class MediaPipeFaceLandmarker(io.ComfyNode): + """BlazeFace → FaceMesh v2 → ARKit-52 blendshapes, batched across the + input. Also emits a BOUNDING_BOX list (landmark-extent bbox per face) — + pair with DrawBBoxes for detector-only viz or MediaPipeFaceMeshVisualize + for the mesh overlay.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MediaPipeFaceLandmarker", + display_name="MediaPipe Face Landmarker", + category="image/detection", + inputs=[ + FaceLandmarkerType.Input("face_landmarker"), + io.Image.Input("image"), + io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short", + tooltip="Face detector range. 'short' is tuned for close-up faces " + "(within ~2 m of the camera); 'full' covers farther / smaller " + "faces (up to ~5 m) but is slower. 'both' runs both detectors and " + "keeps whichever found more faces per frame (~2× detection cost)."), + io.Int.Input("num_faces", default=1, min=0, max=16, step=1, + tooltip="Maximum faces to return per frame. 0 = no cap (return all detected)."), + io.Float.Input("min_confidence", default=0.5, min=0.0, max=1.0, step=0.01, advanced=True, + tooltip="BlazeFace score threshold. Lower to catch small/occluded faces."), + io.Combo.Input("missing_frame_fallback", options=list(_FALLBACK_MODES), default="empty", advanced=True, + tooltip="Per-frame behaviour when detection fails in a batch. " + "'empty' leaves the frame faceless. 'previous' copies the most recent successful " + "detection. 'interpolate' lerps landmarks/bbox/blendshapes between bracketing " + "successful frames. Multi-face: pairs faces across frames by greedy bbox-centre NN."), + ], + outputs=[ + FaceLandmarksType.Output(display_name="face_landmarks"), + io.BoundingBox.Output("bboxes"), + ], + ) + + @classmethod + def execute(cls, face_landmarker, image, detector_variant, num_faces, min_confidence, + missing_frame_fallback) -> io.NodeOutput: + canonical = face_landmarker.canonical_data + img_np = _image_to_uint8(image) + B, H, W = img_np.shape[:3] + chunk = 16 + is_both = detector_variant == "both" + total_work = 2 * B if is_both else B + pbar = comfy.utils.ProgressBar(total_work) + + def _run(variant: str) -> list[list[dict]]: + res: list[list[dict]] = [] + with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq: + for i in range(0, B, chunk): + end = min(i + chunk, B) + res.extend(face_landmarker.detect_batch( + [img_np[bi] for bi in range(i, end)], + num_faces=int(num_faces), + score_thresh=float(min_confidence), + variant=variant, + )) + pbar.update_absolute(min(pbar.current + (end - i), total_work)) + tq.update(end - i) + return res + + if is_both: + short_res = _run("short") + full_res = _run("full") + # Per-frame keep whichever found more faces (tie → short). + frames: list[list[dict]] = [ + short_res[bi] if len(short_res[bi]) >= len(full_res[bi]) else full_res[bi] + for bi in range(B) + ] + else: + frames = _run(detector_variant) + _fill_missing_frames(frames, missing_frame_fallback) + bboxes = [] + for per_frame in frames: + per_bb = [] + for f in per_frame: + f["transformation_matrix"] = transformation_matrix_from_detection(f, W, H, canonical) + x1, y1, x2, y2 = (float(v) for v in f["bbox_xyxy"]) + per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])}) + bboxes.append(per_bb) + return io.NodeOutput({"frames": frames, "image_size": (H, W), + "connection_sets": face_landmarker.connection_sets}, bboxes) + + +# Topology keys unioned by the 'all' connections preset (contour parts + irises + nose). +_ALL_CONNECTION_PARTS: tuple[str, ...] = (*_CONTOUR_PARTS, "irises", "nose") +_CUSTOM_FEATURES: tuple[tuple[str, bool], ...] = ( + ("face_oval", True), + ("lips", True), + ("left_eye", True), + ("right_eye", True), + ("left_eyebrow", True), + ("right_eyebrow", True), + ("irises", True), + ("nose", True), + ("tesselation", False), +) + + +class MediaPipeFaceMeshVisualize(io.ComfyNode): + """Draw a FACEMESH_* subset over an image. Topology travels with the + FACE_LANDMARKS payload (set at detection time).""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MediaPipeFaceMeshVisualize", + display_name="MediaPipe Face Mesh Visualize", + category="image/detection", + inputs=[ + FaceLandmarksType.Input("face_landmarks"), + io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."), + io.DynamicCombo.Input( + "connections", + tooltip="'all' = oval+eyes+brows+lips+irises+nose. 'fill' = solid face_oval polygon (silhouette mask). 'custom' = toggle each feature individually (including 'tesselation', the full 2547-edge wireframe).", + options=[ + io.DynamicCombo.Option("all", []), + io.DynamicCombo.Option("fill", []), + io.DynamicCombo.Option("custom", [ + io.Boolean.Input(feat, default=default, + tooltip=f"Draw the '{feat}' connection set.") + for feat, default in _CUSTOM_FEATURES + ]), + ], + ), + io.Color.Input("color", default="#00ff00"), + io.Int.Input("thickness", default=1, min=0, max=8, step=1, + tooltip="Edge line thickness in pixels. 0 disables edge drawing."), + io.Int.Input("point_size", default=2, min=0, max=16, step=1, + tooltip="Landmark dot radius in pixels. 0 disables point drawing."), + ], + outputs=[io.Image.Output()], + ) + + @classmethod + def execute(cls, face_landmarks, connections, color, thickness, point_size, image=None) -> io.NodeOutput: + sets = face_landmarks["connection_sets"] + sel = connections["connections"] + fill_rings: list[list[int]] | None = None + if sel == "fill": + fill_rings = _ordered_rings(sets["face_oval"]) + edges = frozenset() + elif sel == "custom": + parts = [feat for feat, _ in _CUSTOM_FEATURES if connections.get(feat, False)] + edges = frozenset().union(*(sets[p] for p in parts)) + else: # "all" + edges = frozenset().union(*(sets[p] for p in _ALL_CONNECTION_PARTS)) + rgb, thick, psize = _parse_color(color), int(thickness), int(point_size) + frames = face_landmarks["frames"] + if image is None: + H, W = face_landmarks["image_size"] + img_np = np.zeros((len(frames), H, W, 3), dtype=np.uint8) + else: + img_np = _image_to_uint8(image) + B = img_np.shape[0] + n_frames = len(frames) + pbar = comfy.utils.ProgressBar(B) + out = np.empty_like(img_np) + for bi in range(B): + faces = frames[bi] if bi < n_frames else [] + out[bi] = _draw_mesh(img_np[bi], faces, edges, rgb, thick, psize, fill_rings) + pbar.update_absolute(bi + 1) + return io.NodeOutput(torch.from_numpy(out).to( + device=comfy.model_management.intermediate_device(), + dtype=comfy.model_management.intermediate_dtype(), + ).div_(255.0)) + + +def _draw_mesh(image_rgb: np.ndarray, faces: list, edges, + rgb: tuple[int, int, int], thickness: int, + point_size: int, fill_rings: list[list[int]] | None = None) -> np.ndarray: + draw_edges = thickness > 0 and edges + if not faces or (fill_rings is None and not draw_edges and point_size <= 0): + return image_rgb.copy() + pil = Image.fromarray(image_rgb) + draw = ImageDraw.Draw(pil) + r = point_size * 0.5 + if fill_rings is not None: + for f in faces: + lmks = f["landmarks_xy"] + for ring in fill_rings: + draw.polygon([(float(lmks[i, 0]), float(lmks[i, 1])) for i in ring], fill=rgb) + return np.asarray(pil) + for f in faces: + lmks = f["landmarks_xy"] + n = lmks.shape[0] + if draw_edges: + for a, b in edges: + if a < n and b < n: + draw.line([(float(lmks[a, 0]), float(lmks[a, 1])), + (float(lmks[b, 0]), float(lmks[b, 1]))], fill=rgb, width=thickness) + if point_size == 1: + draw.point(lmks.flatten().tolist(), fill=rgb) + elif point_size > 1: + for x, y in lmks: + draw.ellipse((float(x) - r, float(y) - r, float(x) + r, float(y) + r), fill=rgb) + return np.asarray(pil) + + +# Mask region presets — closed-loop topologies only. +_MASK_REGIONS: tuple[str, ...] = ("face_oval", "lips", "left_eye", "right_eye", "irises") +_MASK_CUSTOM_FEATURES: tuple[tuple[str, bool], ...] = ( + ("face_oval", True), + ("lips", False), + ("left_eye", False), + ("right_eye", False), + ("irises", False), +) + + +class MediaPipeFaceMask(io.ComfyNode): + """Binary mask from face landmarks, filled polygon per face. One mask per + frame in the batch; faces in the same frame composite (union).""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MediaPipeFaceMask", + display_name="MediaPipe Face Mask", + category="image/detection", + inputs=[ + FaceLandmarksType.Input("face_landmarks"), + io.DynamicCombo.Input( + "regions", + tooltip="'all' = union of face_oval+lips+eyes+irises (which collapses to face_oval since it encloses the rest). 'custom' = toggle each region individually for combos like lips+eyes.", + options=[ + io.DynamicCombo.Option("all", []), + io.DynamicCombo.Option("custom", [ + io.Boolean.Input(reg, default=default, + tooltip=f"Include the '{reg}' region in the mask.") + for reg, default in _MASK_CUSTOM_FEATURES + ]), + ], + ), + ], + outputs=[io.Mask.Output()], + ) + + @classmethod + def execute(cls, face_landmarks, regions) -> io.NodeOutput: + sets = face_landmarks["connection_sets"] + sel = regions["regions"] + if sel == "custom": + picked = [reg for reg, _ in _MASK_CUSTOM_FEATURES if regions.get(reg, False)] + else: + picked = list(_MASK_REGIONS) + rings = [r for reg in picked for r in _ordered_rings(sets[reg])] + frames = face_landmarks["frames"] + H, W = face_landmarks["image_size"] + masks = np.zeros((len(frames), H, W), dtype=np.uint8) + pbar = comfy.utils.ProgressBar(len(frames)) + for bi, per_frame in enumerate(frames): + if per_frame: + pil = Image.new("L", (W, H), 0) + draw = ImageDraw.Draw(pil) + for f in per_frame: + lmks = f["landmarks_xy"] + for ring in rings: + draw.polygon([(float(lmks[i, 0]), float(lmks[i, 1])) for i in ring], fill=255) + masks[bi] = np.asarray(pil) + pbar.update_absolute(bi + 1) + return io.NodeOutput(torch.from_numpy(masks).to( + device=comfy.model_management.intermediate_device(), + dtype=comfy.model_management.intermediate_dtype(), + ).div_(255.0)) + + +class MediaPipeFaceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [LoadMediaPipeFaceLandmarker, MediaPipeFaceLandmarker, MediaPipeFaceMeshVisualize, MediaPipeFaceMask] + + +async def comfy_entrypoint() -> MediaPipeFaceExtension: + return MediaPipeFaceExtension() diff --git a/execution.py b/execution.py index 73a454751..f6934ea6a 100644 --- a/execution.py +++ b/execution.py @@ -2,6 +2,7 @@ import copy import heapq import inspect import logging +import psutil import sys import threading import time @@ -732,6 +733,7 @@ class PromptExecutor: self._notify_prompt_lifecycle("start", prompt_id) ram_headroom = int(self.cache_args["ram"] * (1024 ** 3)) + ram_inactive_headroom = int(self.cache_args["ram_inactive"] * (1024 ** 3)) ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom) @@ -785,8 +787,14 @@ class PromptExecutor: execution_list.complete_node_execution() if self.cache_type == CacheType.RAM_PRESSURE: - comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom) - ram_release_callback(ram_headroom, free_active=True) + ram_release_callback(ram_inactive_headroom) + ram_shortfall = ram_headroom - psutil.virtual_memory().available + freed = comfy.model_management.free_pins(ram_shortfall + 512 * (1024 ** 2)) + if freed < ram_shortfall: + if freed > 64 * (1024 ** 2): + # AIMDO MEM_DECOMMIT can outrun psutil.available catching up. + time.sleep(0.05) + ram_release_callback(ram_headroom, free_active=True) else: # Only execute when the while-loop ends without break # Send cached UI for intermediate output nodes that weren't executed diff --git a/folder_paths.py b/folder_paths.py index ad7f0f4fc..ce152eb37 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -60,6 +60,8 @@ folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geom folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions) +folder_names_and_paths["mediapipe"] = ([os.path.join(models_dir, "mediapipe")], supported_pt_extensions) + output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input") diff --git a/main.py b/main.py index a6fdaf43c..1e47cab84 100644 --- a/main.py +++ b/main.py @@ -283,19 +283,25 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]: def prompt_worker(q, server_instance): current_time: float = 0.0 - cache_ram = args.cache_ram - if cache_ram < 0: + cache_ram = 0 + cache_ram_inactive = 0 + if not args.cache_classic and not args.cache_none and args.cache_lru <= 0: cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0)) + cache_ram_inactive = min(96.0, max(12.0, comfy.model_management.total_ram * 0.75 / 1024.0)) + if len(args.cache_ram) > 0: + cache_ram = args.cache_ram[0] + if len(args.cache_ram) > 1: + cache_ram_inactive = args.cache_ram[1] - cache_type = execution.CacheType.CLASSIC - if args.cache_lru > 0: + cache_type = execution.CacheType.RAM_PRESSURE + if args.cache_classic: + cache_type = execution.CacheType.CLASSIC + elif args.cache_lru > 0: cache_type = execution.CacheType.LRU - elif cache_ram > 0: - cache_type = execution.CacheType.RAM_PRESSURE elif args.cache_none: cache_type = execution.CacheType.NONE - e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } ) + e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram, "ram_inactive" : cache_ram_inactive } ) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 diff --git a/models/mediapipe/put_mediapipe_models_here b/models/mediapipe/put_mediapipe_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index fdd6eeb5f..13e46ac8a 100644 --- a/nodes.py +++ b/nodes.py @@ -2444,6 +2444,7 @@ async def init_builtin_extra_nodes(): "nodes_hidream_o1.py", "nodes_save_3d.py", "nodes_moge.py", + "nodes_mediapipe.py", ] import_failed = [] diff --git a/requirements.txt b/requirements.txt index 1c87690da..d2986eda8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0 filelock av>=14.2.0 comfy-kitchen>=0.2.8 -comfy-aimdo==0.3.0 +comfy-aimdo==0.4.3 requests simpleeval>=1.0.0 blake3 diff --git a/tests/execution/test_async_nodes.py b/tests/execution/test_async_nodes.py index c771b4b36..54660c112 100644 --- a/tests/execution/test_async_nodes.py +++ b/tests/execution/test_async_nodes.py @@ -14,7 +14,6 @@ from tests.execution.test_execution import ComfyClient, run_warmup class TestAsyncNodes: @fixture(scope="class", autouse=True, params=[ (False, 0), - (True, 0), (True, 100), ]) def _server(self, args_pytest, request): @@ -29,6 +28,8 @@ class TestAsyncNodes: use_lru, lru_size = request.param if use_lru: pargs += ['--cache-lru', str(lru_size)] + else: + pargs += ['--cache-classic'] # Running server with args: pargs p = subprocess.Popen(pargs) yield diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index f73ca7e3c..15e2304fc 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -183,8 +183,7 @@ class TestExecution: # Initialize server and client # @fixture(scope="class", autouse=True, params=[ - { "extra_args" : [], "should_cache_results" : True }, - { "extra_args" : ["--cache-lru", 0], "should_cache_results" : True }, + { "extra_args" : ["--cache-classic"], "should_cache_results" : True }, { "extra_args" : ["--cache-lru", 100], "should_cache_results" : True }, { "extra_args" : ["--cache-none"], "should_cache_results" : False }, ])