Compare commits

..

1 Commits

Author SHA1 Message Date
63dc90e6c0 chore(openapi): sync shared API contract from cloud@773d43b 2026-05-29 22:40:17 +00:00
13 changed files with 10072 additions and 11888 deletions

View File

@ -149,7 +149,6 @@ parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=Non
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.") parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.") parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
parser.add_argument("--fast-disk", action="store_true", help="Prefer disk-backed dynamic loading and offload over unpinned RAM. Can be faster for users with fast NVME disks.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")

View File

@ -15,6 +15,15 @@ import comfy.patcher_extension
from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.attention import optimized_attention
import comfy.ldm.common_dit import comfy.ldm.common_dit
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
) -> torch.Tensor:
    """Apply a rotary position embedding to ``t`` using split-half pairing.

    The last dimension of ``t`` is split into two halves, and element ``i`` of
    the first half is paired with element ``i`` of the second half.  Each pair
    is multiplied by the 2x2 matrix found in the last two dimensions of
    ``freqs``.

    Args:
        t: Tensor whose last dimension (even size ``D``) is rotated.
        freqs: Tensor indexed as ``freqs[..., 0]`` and ``freqs[..., 1]``; each
            slice must broadcast against ``(*t.shape[:-1], D // 2, 2)``.
            Presumably shaped ``(..., D // 2, 2, 2)`` -- confirm with callers.

    Returns:
        A tensor with the same shape and dtype as ``t``.
    """
    # (..., D) -> (..., 2, D/2) -> (..., D/2, 2) -> (..., D/2, 1, 2); computed in float32.
    t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
    # 2x2 matrix-vector product written as a sum of scaled matrix columns.
    t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
    # Undo the pairing layout and cast back to the caller's dtype.
    t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
    return t_out
# ---------------------- Feed Forward Network ----------------------- # ---------------------- Feed Forward Network -----------------------
class GPT2FeedForward(nn.Module): class GPT2FeedForward(nn.Module):
@ -164,7 +173,8 @@ class Attention(nn.Module):
k = self.k_norm(k) k = self.k_norm(k)
v = self.v_norm(v) v = self.v_norm(v)
if self.is_selfattn and rope_emb is not None: # only apply to self-attention! if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
q, k = comfy.quant_ops.ck.apply_rope_split_half(q, k, rope_emb) q = apply_rotary_pos_emb(q, rope_emb)
k = apply_rotary_pos_emb(k, rope_emb)
return q, k, v return q, k, v
q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb) q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)

View File

@ -51,6 +51,15 @@ class FeedForward(nn.Module):
return hidden_states return hidden_states
def apply_rotary_emb(x, freqs_cis):
    """Apply a rotary embedding to ``x`` using adjacent (interleaved) pairs.

    Consecutive elements ``(x[..., 2i], x[..., 2i + 1])`` of the last
    dimension form a pair; each pair is multiplied by the 2x2 matrix found in
    the last two dimensions of ``freqs_cis``.

    Args:
        x: Tensor with at least two dimensions and an even-sized last
            dimension.
        freqs_cis: Tensor indexed as ``freqs_cis[..., 0]`` and
            ``freqs_cis[..., 1]``; each slice must broadcast against
            ``(*x.shape[:-1], D // 2, 2)``.

    Returns:
        ``x`` itself when ``x.shape[1] == 0``; otherwise a tensor with the
        shape of ``x``.  No cast back to ``x``'s dtype is performed, so the
        result dtype follows normal type promotion with ``freqs_cis``.
    """
    # Nothing to rotate; hand the input back untouched.
    if x.shape[1] == 0:
        return x

    # (..., D) -> (..., D/2, 1, 2) so each pair broadcasts against its matrix.
    t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
    # 2x2 matrix-vector product written as a sum of scaled matrix columns.
    t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
    return t_out.reshape(*x.shape)
class QwenTimestepProjEmbeddings(nn.Module): class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None): def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
super().__init__() super().__init__()

View File

@ -4,7 +4,6 @@ import dataclasses
import torch import torch
from typing import NamedTuple from typing import NamedTuple
import comfy_aimdo.host_buffer
from comfy.quant_ops import QuantizedTensor from comfy.quant_ops import QuantizedTensor
@ -18,18 +17,21 @@ class TensorFileSlice(NamedTuple):
def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None): def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None):
if isinstance(tensor, QuantizedTensor): if isinstance(tensor, QuantizedTensor):
if not read_tensor_file_slice_into(tensor._qdata, if not isinstance(destination, QuantizedTensor):
destination._qdata if destination is not None else None, stream=stream, return False
if tensor._layout_cls != destination._layout_cls:
return False
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream,
destination2=(destination2._qdata if destination2 is not None else None)): destination2=(destination2._qdata if destination2 is not None else None)):
return False return False
if destination is not None:
dst_orig_dtype = destination._params.orig_dtype dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False) destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype) destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
if destination2 is not None: if destination2 is not None:
dst_orig_dtype = destination2._params.orig_dtype dst_orig_dtype = destination2._params.orig_dtype
destination2._params.copy_from(destination._params if destination is not None else tensor._params, non_blocking=True) destination2._params.copy_from(destination._params, non_blocking=True)
destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype) destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype)
return True return True
@ -37,15 +39,10 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
if info is None: if info is None:
return False return False
if destination is not None and destination.device.type != "cpu" and destination2 is None:
destination2 = destination
destination = None
file_obj = info.file_ref file_obj = info.file_ref
if (file_obj is None if (destination.device.type != "cpu"
or (destination is None and destination2 is None) or file_obj is None
or (destination is not None and (destination.device.type != "cpu" or destination.numel() * destination.element_size() < info.size)) or destination.numel() * destination.element_size() < info.size
or (destination2 is not None and (destination2.device.type == "cpu" or destination2.numel() * destination2.element_size() < info.size))
or tensor.numel() * tensor.element_size() != info.size or tensor.numel() * tensor.element_size() != info.size
or tensor.storage_offset() != 0 or tensor.storage_offset() != 0
or not tensor.is_contiguous()): or not tensor.is_contiguous()):
@ -54,14 +51,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
if info.size == 0: if info.size == 0:
return True return True
if destination is None:
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
comfy_aimdo.host_buffer.read_file_to_device(file_obj, info.offset, info.size,
stream_ptr, destination2.data_ptr(),
destination2.device.index,
mark_cold=False)
return True
hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None) hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
if hostbuf is not None: if hostbuf is not None:
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0 stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
@ -74,9 +63,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
device=None if destination2 is None else destination2.device.index) device=None if destination2 is None else destination2.device.index)
return True return True
if not hasattr(file_obj, "seek") or not hasattr(file_obj, "readinto"):
return False
buf_type = ctypes.c_ubyte * info.size buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr())) view = memoryview(buf_type.from_address(destination.data_ptr()))

View File

@ -641,17 +641,14 @@ def free_pins(size, evict_active=False):
return freed_total return freed_total
def ensure_pin_budget(size, evict_active=False): def ensure_pin_budget(size, evict_active=False):
if args.fast_disk: shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
else:
shortfall = size + max(comfy.memory_management.RAM_CACHE_HEADROOM / 2, 2048 * 1024 ** 2) - psutil.virtual_memory().available
if shortfall <= 0: if shortfall <= 0:
return True return True
to_free = shortfall + PIN_PRESSURE_HYSTERESIS to_free = shortfall + PIN_PRESSURE_HYSTERESIS
return free_pins(to_free, evict_active=evict_active) >= shortfall return free_pins(to_free, evict_active=evict_active) >= shortfall
def ensure_pin_registerable(size, evict_active=True): def ensure_pin_registerable(size, evict_active=False):
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0: if MAX_PINNED_MEMORY <= 0:
return False return False
@ -661,14 +658,7 @@ def ensure_pin_registerable(size, evict_active=True):
shortfall += REGISTERABLE_PIN_HYSTERESIS shortfall += REGISTERABLE_PIN_HYSTERESIS
for loaded_model in reversed(current_loaded_models): for loaded_model in reversed(current_loaded_models):
model = loaded_model.model model = loaded_model.model
if model is not None and model.is_dynamic() and not model.model.dynamic_pins[model.load_device]["active"]: if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
shortfall -= model.unregister_inactive_pins(shortfall)
if shortfall <= 0:
return True
if evict_active:
for loaded_model in current_loaded_models:
model = loaded_model.model
if model is not None and model.is_dynamic() and model.model.dynamic_pins[model.load_device]["active"]:
shortfall -= model.unregister_inactive_pins(shortfall) shortfall -= model.unregister_inactive_pins(shortfall)
if shortfall <= 0: if shortfall <= 0:
return True return True
@ -813,9 +803,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for x in can_unload_sorted: for x in can_unload_sorted:
i = x[-1] i = x[-1]
memory_to_free = 1e32 memory_to_free = 1e32
if not DISABLE_SMART_MEMORY or device is None: if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None):
memory_to_free = 0 if device is None else memory_required - get_free_memory(device) memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
if current_loaded_models[i].model.is_dynamic() and for_dynamic: if for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models #don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand. #as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size() memory_required -= current_loaded_models[i].model.loaded_size()
@ -827,10 +817,6 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for i in sorted(unloaded_model, reverse=True): for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i)) unloaded_models.append(current_loaded_models.pop(i))
if not for_dynamic and pins_required > 0:
ensure_pin_budget(pins_required)
ensure_pin_registerable(pins_required)
if len(unloaded_model) > 0: if len(unloaded_model) > 0:
soft_empty_cache() soft_empty_cache()
elif device is not None: elif device is not None:
@ -893,19 +879,15 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
model_to_unload.model_finalizer.detach() model_to_unload.model_finalizer.detach()
total_memory_required = {} total_memory_required = {}
total_pins_required = {}
for loaded_model in models_to_load: for loaded_model in models_to_load:
device = loaded_model.device device = loaded_model.device
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device) total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
if not loaded_model.model.is_dynamic():
total_pins_required[device] = total_pins_required.get(device, 0) + loaded_model.model_memory()
for device in total_memory_required: for device in total_memory_required:
if device != torch.device("cpu"): if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem, free_memory(total_memory_required[device] * 1.1 + extra_mem,
device, device,
for_dynamic=free_for_dynamic, for_dynamic=free_for_dynamic)
pins_required=total_pins_required.get(device, 0))
for device in total_memory_required: for device in total_memory_required:
if device != torch.device("cpu"): if device != torch.device("cpu"):
@ -1301,6 +1283,7 @@ STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {} STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
STREAM_PIN_BUFFERS = {}
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
@ -1343,13 +1326,42 @@ def get_aimdo_cast_buffer(offload_stream, device):
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer return cast_buffer
def get_pin_buffer(offload_stream):
    """Return the host buffer cached for ``offload_stream``, creating it on first use.

    Buffers are kept in the module-level ``STREAM_PIN_BUFFERS`` dict, keyed by
    the stream object (``None`` is a valid key).

    Args:
        offload_stream: Stream the buffer is associated with, or ``None``.

    Returns:
        The ``comfy_aimdo.host_buffer.HostBuffer`` for this stream.
    """
    pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
    if pin_buffer is None:
        # First use for this stream: create the buffer and remember it.
        pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3), mark_cold=False)
        STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
    elif offload_stream is not None:
        # Reusing a buffer: wait on the event stored on it (if any) before
        # handing it out again, then drop the event.
        event = getattr(pin_buffer, "_comfy_event", None)
        if event is not None:
            event.synchronize()
            # Kept inside this branch: when the attribute was never set,
            # delattr() would raise AttributeError.
            delattr(pin_buffer, "_comfy_event")
    return pin_buffer
def resize_pin_buffer(pin_buffer, size):
    """Grow ``pin_buffer`` to at least ``size`` bytes, updating pinned-memory accounting.

    Args:
        pin_buffer: Host buffer exposing ``size`` and ``extend(size=..., reallocate=...)``.
        size: Requested total size; values not larger than the current size
            are a no-op.

    Returns:
        ``True`` when the buffer is already large enough or was extended,
        ``False`` when ``extend`` raised ``RuntimeError``.
    """
    global TOTAL_PINNED_MEMORY
    old_size = pin_buffer.size
    if size <= old_size:
        return True

    growth = size - old_size
    comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
    # Best effort: the results of both calls are ignored and the extension is
    # attempted regardless.
    ensure_pin_budget(growth, evict_active=True)
    ensure_pin_registerable(growth, evict_active=True)
    try:
        pin_buffer.extend(size=size, reallocate=True)
    except RuntimeError:
        return False

    # Account for the actual growth reported by the buffer, not the requested one.
    TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
    return True
def reset_cast_buffers(): def reset_cast_buffers():
global TOTAL_PINNED_MEMORY
global LARGEST_CASTED_WEIGHT global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT global LARGEST_AIMDO_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_CASTED_WEIGHT = (None, 0)
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS): for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS):
if offload_stream is not None: if offload_stream is not None:
offload_stream.synchronize() offload_stream.synchronize()
synchronize() synchronize()
@ -1358,24 +1370,20 @@ def reset_cast_buffers():
mmap_obj.bounce() mmap_obj.bounce()
DIRTY_MMAPS.clear() DIRTY_MMAPS.clear()
for pin_buffer in STREAM_PIN_BUFFERS.values():
TOTAL_PINNED_MEMORY -= pin_buffer.size
TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY)
for loaded_model in current_loaded_models: for loaded_model in current_loaded_models:
model = loaded_model.model model = loaded_model.model
if model is not None and model.is_dynamic(): if model is not None and model.is_dynamic():
pin_state = model.model.dynamic_pins[model.load_device] model.model.dynamic_pins[model.load_device]["active"] = False
if pin_state["active"]:
*_, buckets = pin_state["weights"]
for size, bucket in list(buckets.items()):
bucket[:] = [ entry for entry in bucket if entry[-1] is not None ]
if not bucket:
del buckets[size]
pin_state["active"] = False
model.partially_unload_ram(1e30, subsets=[ "patches" ]) model.partially_unload_ram(1e30, subsets=[ "patches" ])
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0], [0], {}) model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0])
STREAM_CAST_BUFFERS.clear() STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear() STREAM_AIMDO_CAST_BUFFERS.clear()
STREAM_PIN_BUFFERS.clear()
soft_empty_cache() soft_empty_cache()
def get_offload_stream(device): def get_offload_stream(device):
@ -1428,7 +1436,7 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
if hasattr(wf_context, "as_context"): if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream) wf_context = wf_context.as_context(stream)
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) if r is not None else [None] * len(tensors) dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
with wf_context: with wf_context:
for tensor in tensors: for tensor in tensors:
@ -1440,10 +1448,9 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
continue continue
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage() storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
mark_mmap_dirty(storage) mark_mmap_dirty(storage)
if dest_view is not None:
dest_view.copy_(tensor, non_blocking=non_blocking) dest_view.copy_(tensor, non_blocking=non_blocking)
if dest2_view is not None: if dest2_view is not None:
dest2_view.copy_(tensor if dest_view is None else dest_view, non_blocking=non_blocking) dest2_view.copy_(dest_view, non_blocking=non_blocking)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):

View File

@ -1721,8 +1721,8 @@ class ModelPatcherDynamic(ModelPatcher):
""" """
if device not in self.model.dynamic_pins: if device not in self.model.dynamic_pins:
self.model.dynamic_pins[device] = { self.model.dynamic_pins[device] = {
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}), "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}), "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"hostbufs_initialized": False, "hostbufs_initialized": False,
"failed": False, "failed": False,
"active": False, "active": False,
@ -1799,8 +1799,8 @@ class ModelPatcherDynamic(ModelPatcher):
pin_state = self.model.dynamic_pins[self.load_device] pin_state = self.model.dynamic_pins[self.load_device]
if not pin_state["hostbufs_initialized"]: if not pin_state["hostbufs_initialized"]:
hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size()) hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size())
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {}) pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0])
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {}) pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0])
pin_state["hostbufs_initialized"] = True pin_state["hostbufs_initialized"] = True
pin_state["failed"] = False pin_state["failed"] = False
pin_state["active"] = True pin_state["active"] = True
@ -1942,16 +1942,18 @@ class ModelPatcherDynamic(ModelPatcher):
return freed return freed
def loaded_ram_size(self): def loaded_ram_size(self):
return (self.model.dynamic_pins[self.load_device]["weights"][0].size) return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
self.model.dynamic_pins[self.load_device]["patches"][0].size)
def pinned_memory_size(self): def pinned_memory_size(self):
return (self.model.dynamic_pins[self.load_device]["weights"][3][0]) return (self.model.dynamic_pins[self.load_device]["weights"][3][0] +
self.model.dynamic_pins[self.load_device]["patches"][3][0])
def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]): def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]):
freed = 0 freed = 0
pin_state = self.model.dynamic_pins[self.load_device] pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets: for subset in subsets:
hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset] hostbuf, stack, stack_split, pinned_size = pin_state[subset]
split = stack_split[0] split = stack_split[0]
while split >= 0: while split >= 0:
module, offset = stack[split] module, offset = stack[split]
@ -1976,12 +1978,10 @@ class ModelPatcherDynamic(ModelPatcher):
freed = 0 freed = 0
pin_state = self.model.dynamic_pins[self.load_device] pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets: for subset in subsets:
hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset] hostbuf, stack, stack_split, pinned_size = pin_state[subset]
while len(stack) > 0: while len(stack) > 0:
module, offset = stack.pop() module, offset = stack.pop()
size = module._pin.numel() * module._pin.element_size() size = module._pin.numel() * module._pin.element_size()
module._pin_balancer_entry[-1] = None
del module._pin_balancer_entry
del module._pin del module._pin
hostbuf.truncate(offset, do_unregister=module._pin_registered) hostbuf.truncate(offset, do_unregister=module._pin_registered)
stack_split[0] = min(stack_split[0], len(stack) - 1) stack_split[0] = min(stack_split[0], len(stack) - 1)

View File

@ -1,5 +1,4 @@
import comfy_aimdo.model_vbar import comfy_aimdo.model_vbar
import comfy.memory_management
import comfy.model_management import comfy.model_management
import comfy.ops import comfy.ops
@ -51,17 +50,7 @@ def prefetch_queue_pop(queue, device, module):
if hasattr(s, "_v"): if hasattr(s, "_v"):
comfy_modules.append(s) comfy_modules.append(s)
registerable_size = 0
for s in comfy_modules:
registerable_size += comfy.memory_management.vram_aligned_size([s.weight, s.bias])
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
registerable_size += lowvram_fn.memory_required()
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True) offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
if not comfy.model_management.args.fast_disk:
comfy.model_management.ensure_pin_registerable(registerable_size)
comfy.model_management.sync_stream(device, offload_stream) comfy.model_management.sync_stream(device, offload_stream)
queue[0] = (offload_stream, (prefetch, comfy_modules)) queue[0] = (offload_stream, (prefetch, comfy_modules))

View File

@ -76,6 +76,8 @@ except:
cast_to = comfy.model_management.cast_to #TODO: remove once no more references cast_to = comfy.model_management.cast_to #TODO: remove once no more references
STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024
def cast_to_input(weight, input, non_blocking=False, copy=True): def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
@ -92,6 +94,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
offload_stream = None offload_stream = None
cast_buffer = None cast_buffer = None
cast_buffer_offset = 0 cast_buffer_offset = 0
stream_pin_hostbuf = None
stream_pin_offset = 0
stream_pin_queue = []
def ensure_offload_stream(module, required_size, check_largest): def ensure_offload_stream(module, required_size, check_largest):
nonlocal offload_stream nonlocal offload_stream
@ -125,6 +130,22 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
cast_buffer_offset += buffer_size cast_buffer_offset += buffer_size
return buffer return buffer
def get_stream_pin_buffer_offset(buffer_size):
nonlocal stream_pin_hostbuf
nonlocal stream_pin_offset
if buffer_size == 0 or offload_stream is None:
return None
if stream_pin_hostbuf is None:
stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
if stream_pin_hostbuf is None:
return None
offset = stream_pin_offset
stream_pin_offset += buffer_size
return offset
for s in comfy_modules: for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v) signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
@ -163,18 +184,12 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
if xfer_dest is None: if xfer_dest is None:
xfer_dest = get_cast_buffer(dest_size) xfer_dest = get_cast_buffer(dest_size)
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream, xfer_dest2=None): def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
if xfer_source is not None: if xfer_source is not None:
if getattr(xfer_source, "is_lowvram_patch", False): if getattr(xfer_source, "is_lowvram_patch", False):
if xfer_dest is not None:
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False) xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
xfer_source = [ xfer_dest ] else:
xfer_dest = xfer_dest2 comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
xfer_dest2 = None
elif xfer_dest2 is not None:
xfer_source.prepare(xfer_dest2, stream, copy=True, commit=False)
return
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream, r2=xfer_dest2)
def handle_pin(m, pin, source, dest, subset="weights", size=None): def handle_pin(m, pin, source, dest, subset="weights", size=None):
if pin is not None: if pin is not None:
@ -183,7 +198,19 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
if signature is None: if signature is None:
comfy.pinned_memory.pin_memory(m, subset=subset, size=size) comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
pin = comfy.pinned_memory.get_pin(m, subset=subset) pin = comfy.pinned_memory.get_pin(m, subset=subset)
cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest) if pin is not None:
if isinstance(source, list):
comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest)
else:
cast_maybe_lowvram_patch(source, pin, None)
cast_maybe_lowvram_patch([ pin ], dest, offload_stream)
return
if pin is None:
pin_offset = get_stream_pin_buffer_offset(size)
if pin_offset is not None:
stream_pin_queue.append((source, pin_offset, size, dest))
return
cast_maybe_lowvram_patch(source, dest, offload_stream)
handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size) handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size)
@ -205,6 +232,23 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
prefetch["needs_cast"] = needs_cast prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch s._prefetch = prefetch
if stream_pin_offset > 0:
if stream_pin_hostbuf.size < stream_pin_offset:
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
for xfer_source, _, _, xfer_dest in stream_pin_queue:
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
return offload_stream
stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf)
stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
if isinstance(xfer_source, list):
comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest)
else:
cast_maybe_lowvram_patch(xfer_source, pin, None)
comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
stream_pin_hostbuf._comfy_event = offload_stream.record_event()
return offload_stream return offload_stream

View File

@ -1,55 +1,17 @@
import bisect
import comfy.model_management import comfy.model_management
import comfy.memory_management import comfy.memory_management
import comfy.utils
import comfy_aimdo.host_buffer import comfy_aimdo.host_buffer
import comfy_aimdo.torch import comfy_aimdo.torch
import torch import torch
from comfy.cli_args import args from comfy.cli_args import args
def _add_to_bucket(module, buckets, size, priority):
    """Register ``module`` in the sorted bucket for pins of ``size`` bytes.

    Args:
        module: Object that owns the pin; receives a ``_pin_balancer_entry``
            attribute pointing at its bucket entry.
        buckets: Dict mapping pin size to a sorted list of entries; the list
            is created on first use.
        size: Bucket key.
        priority: Priority of the module; stored negated so that higher
            priorities sort towards the front of the bucket.
    """
    bucket = buckets.setdefault(size, [])
    # Entry layout: [-priority, tie-break, module].
    entry = [-priority, 0, module]
    # id(entry) is unique per live entry, so comparing two entries is always
    # decided by the first two slots and never reaches the module object.
    entry[1] = id(entry)
    bisect.insort(bucket, entry)
    module._pin_balancer_entry = entry
def _steal_pin(module, stack, buckets, size, priority):
    """Try to take over the pin of the last entry in the ``size`` bucket.

    Args:
        module: Module that wants a pin.
        stack: List of ``(module, offset)`` tuples; the victim's slot is
            rewritten to reference ``module``.
        buckets: Dict mapping pin size to a sorted list of
            ``[-priority, tie-break, module]`` entries.
        size: Bucket key to steal from.
        priority: Priority of ``module``; it must be strictly greater than the
            victim's priority for the steal to happen.

    Returns:
        ``True`` when a pin was transferred to ``module``, ``False`` otherwise.
    """
    bucket = buckets.get(size)
    if bucket is None:
        return False

    # Drop trailing entries whose module slot has been cleared.
    while bucket and bucket[-1][-1] is None:
        bucket.pop()
    if not bucket:
        del buckets[size]
        return False

    # The last entry holds the lowest priority (priorities are stored negated).
    if priority <= -bucket[-1][0]:
        return False

    *_, victim = bucket.pop()
    # Hand the victim's pin bookkeeping over to the new owner.
    module._pin = victim._pin
    module._pin_registered = victim._pin_registered
    module._pin_stack_index = victim._pin_stack_index
    # Keep the stack slot's offset, swap only the owning module.
    stack[module._pin_stack_index] = (module, stack[module._pin_stack_index][1])

    victim._pin_registered = False
    del victim._pin
    del victim._pin_stack_index
    del victim._pin_balancer_entry

    _add_to_bucket(module, buckets, size, priority)
    return True
def get_pin(module, subset="weights"): def get_pin(module, subset="weights"):
pin = getattr(module, "_pin", None) pin = getattr(module, "_pin", None)
if pin is None or module._pin_registered or args.disable_pinned_memory: if pin is None or module._pin_registered or args.disable_pinned_memory:
return pin return pin
_, _, stack_split, pinned_size, *_ = module._pin_state[subset] _, _, stack_split, pinned_size = module._pin_state[subset]
size = pin.nbytes size = pin.nbytes
comfy.model_management.ensure_pin_registerable(size) comfy.model_management.ensure_pin_registerable(size)
@ -69,30 +31,26 @@ def pin_memory(module, subset="weights", size=None):
return return
pin = get_pin(module, subset) pin = get_pin(module, subset)
if pin is not None: if pin is not None or pin_state["failed"]:
return return
hostbuf, stack, stack_split, pinned_size, counter, buckets = pin_state[subset] hostbuf, stack, stack_split, pinned_size = pin_state[subset]
if size is None: if size is None:
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
offset = hostbuf.size offset = hostbuf.size
registerable_size = size registerable_size = size + max(0, hostbuf.size - pinned_size[0])
priority = getattr(module, "_pin_balancer_priority", None)
if priority is None:
priority = comfy.utils.bit_reverse_range(counter[0], 16)
counter[0] += 1
module._pin_balancer_priority = priority
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
if (not comfy.model_management.ensure_pin_budget(size) or if (not comfy.model_management.ensure_pin_budget(size) or
not comfy.model_management.ensure_pin_registerable(registerable_size)): not comfy.model_management.ensure_pin_registerable(registerable_size)):
return _steal_pin(module, stack, buckets, size, priority) pin_state["failed"] = True
return False
try: try:
hostbuf.extend(size=size) hostbuf.extend(size=size)
except RuntimeError: except RuntimeError:
return _steal_pin(module, stack, buckets, size, priority) pin_state["failed"] = True
return False
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size] module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
module._pin.untyped_storage()._comfy_hostbuf = hostbuf module._pin.untyped_storage()._comfy_hostbuf = hostbuf
@ -102,5 +60,4 @@ def pin_memory(module, subset="weights", size=None):
stack_split[0] = max(stack_split[0], module._pin_stack_index) stack_split[0] = max(stack_split[0], module._pin_stack_index)
comfy.model_management.TOTAL_PINNED_MEMORY += size comfy.model_management.TOTAL_PINNED_MEMORY += size
pinned_size[0] += size pinned_size[0] += size
_add_to_bucket(module, buckets, size, priority)
return True return True

View File

@ -85,9 +85,9 @@ _TYPES = {
def load_safetensors(ckpt): def load_safetensors(ckpt):
import comfy_aimdo.model_mmap import comfy_aimdo.model_mmap
f = open(ckpt, "rb", buffering=0)
file_lock = threading.Lock() file_lock = threading.Lock()
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt) model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
f = model_mmap.get_file_handle()
file_size = os.path.getsize(ckpt) file_size = os.path.getsize(ckpt)
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get())) mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
@ -1452,10 +1452,3 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res memo[obj_id] = res
return res return res
def bit_reverse_range(index: int, bits: int) -> int:
    """Return the lowest ``bits`` bits of ``index`` in reversed order.

    Bit 0 of ``index`` becomes bit ``bits - 1`` of the result and so on; bits
    of ``index`` above position ``bits - 1`` are ignored.

    >>> bit_reverse_range(1, 16)
    32768
    >>> bit_reverse_range(0b1011, 4)
    13
    """
    result = 0
    for _ in range(bits):
        # Shift the accumulated bits up and append the current lowest bit.
        result = (result << 1) | (index & 1)
        index >>= 1
    return result

View File

@ -727,30 +727,6 @@ class File3DUSDZ(ComfyTypeIO):
Type = File3D Type = File3D
@comfytype(io_type="FILE_3D_PLY")
class File3DPLY(ComfyTypeIO):
    """PLY format 3D file - point cloud or Gaussian splat."""
    # Values of this IO type are File3D objects.
    Type = File3D
@comfytype(io_type="FILE_3D_SPLAT")
class File3DSPLAT(ComfyTypeIO):
    """SPLAT format 3D file - 3D Gaussian splat."""
    # Values of this IO type are File3D objects.
    Type = File3D
@comfytype(io_type="FILE_3D_SPZ")
class File3DSPZ(ComfyTypeIO):
    """SPZ format 3D file - compressed 3D Gaussian splat."""
    # Values of this IO type are File3D objects.
    Type = File3D
@comfytype(io_type="FILE_3D_KSPLAT")
class File3DKSPLAT(ComfyTypeIO):
    """KSPLAT format 3D file - 3D Gaussian splat."""
    # Values of this IO type are File3D objects.
    Type = File3D
@comfytype(io_type="HOOKS") @comfytype(io_type="HOOKS")
class Hooks(ComfyTypeIO): class Hooks(ComfyTypeIO):
if TYPE_CHECKING: if TYPE_CHECKING:
@ -2327,10 +2303,6 @@ __all__ = [
"File3DOBJ", "File3DOBJ",
"File3DSTL", "File3DSTL",
"File3DUSDZ", "File3DUSDZ",
"File3DPLY",
"File3DSPLAT",
"File3DSPZ",
"File3DKSPLAT",
"Hooks", "Hooks",
"HookKeyframes", "HookKeyframes",
"TimestepsRange", "TimestepsRange",

21018
openapi.yaml

File diff suppressed because it is too large Load Diff

View File

@ -22,8 +22,8 @@ alembic
SQLAlchemy>=2.0.0 SQLAlchemy>=2.0.0
filelock filelock
av>=16.0.0 av>=16.0.0
comfy-kitchen==0.2.10 comfy-kitchen==0.2.9
comfy-aimdo==0.4.7 comfy-aimdo==0.4.5
requests requests
simpleeval>=1.0.0 simpleeval>=1.0.0
blake3 blake3