mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-31 21:25:59 +08:00
* memory_management: Add direct to read GPU mode Make destination optional (or make it optionally GPU) and use aimdo to file_read direct to GPU. * ops: Remove stream pin buffers and use aimdo reads This consumed too much RAM and it's better to just take the hit on the CPU syncing back the stream on a short ring buffer. Aimdo implements this so just rip the stream pin buffer from comfy. * model_management: all active pin registration movement It's better to just let the active model load past the pin limit as pins and let the pins move around. This saves the HDD and SATA people disk traffic while only costing a few GPU syncs. * utils: use aimdo file handle This opens on Windows with more favourable flags * mp: only count the model proper for loaded_ram and vram Exclude live loras from the numbers to avoid the case where the reported loaded memory exceeds the size of the model. This caused me confusion in the Kijai visualizer when it looked fully loaded but was hitting disk due to this accounting discrepancy. * utils: add bit reverse utility useful for max scattering something ordered. * pinned_memory: Implement offload balancing Use a max scatter algorithm to prioritize pins of the same size such that when doing a little bit of offloading it gets scattered, allowing the prefetcher to more evenly swallow the offload. * comfy-aimdo 0.4.7 Aimdo 0.4.7 implements VRAM buffer exhaustion prediction to avoid early speculative load of weights that definitely won't fit once the inference gets further in. * model-prefetch: consolidate pin ensures on the sync point This could happen mid prefetch block, cause a sync of the entire block and lose overlap. Get ahead of the problem with a free down at the natural compute stream sync point. * mm: Put a 2GB min on the pin ceiling This is reasonably bad if it starts causing swap pressure, moreso than during normal ram-cache proceedings. Clamp it. * add --fast-disk
1502 lines
66 KiB
Python
1502 lines
66 KiB
Python
"""
|
|
This file is part of ComfyUI.
|
|
Copyright (C) 2024 Stability AI
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, either version 3 of the License, or
|
|
(at your option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
"""
|
|
|
|
import torch
|
|
import logging
|
|
import contextlib
|
|
import comfy.model_management
|
|
from comfy.cli_args import args, PerformanceFeature
|
|
import comfy.float
|
|
import json
|
|
import comfy.memory_management
|
|
import comfy.pinned_memory
|
|
import comfy.utils
|
|
|
|
import comfy_aimdo.model_vbar
|
|
import comfy_aimdo.torch
|
|
|
|
def run_every_op():
    """Hook called at the start of every op's ``forward``.

    Raises whatever ``throw_exception_if_processing_interrupted`` raises,
    unless torch is currently compiling, in which case the check is skipped.
    """
    if not torch.compiler.is_compiling():
        comfy.model_management.throw_exception_if_processing_interrupted()
|
|
|
|
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
    """Default attention entry point: forwards everything to torch's SDPA."""
    sdpa = torch.nn.functional.scaled_dot_product_attention
    return sdpa(q, k, v, *args, **kwargs)
|
|
|
|
|
|
# On CUDA + Windows, override scaled_dot_product_attention with a variant that
# asks torch to try the cuDNN backend first for large inputs.
try:
    if torch.cuda.is_available() and comfy.model_management.WINDOWS:
        from torch.nn.attention import SDPBackend, sdpa_kernel
        import inspect
        if "set_priority" not in inspect.signature(sdpa_kernel).parameters:
            logging.warning("Torch version too old to set sdpa backend priority.")
        else:
            SDPA_BACKEND_PRIORITY = [
                SDPBackend.FLASH_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.MATH,
            ]

            # cuDNN goes to the front of the priority list.
            SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)

            def scaled_dot_product_attention(q, k, v, *args, **kwargs):
                """SDPA that prefers the cuDNN backend, except for small inputs."""
                sdpa = torch.nn.functional.scaled_dot_product_attention
                # arbitrary number, for small inputs cudnn attention seems slower
                if q.nelement() >= 1024 * 128:
                    with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
                        return sdpa(q, k, v, *args, **kwargs)
                return sdpa(q, k, v, *args, **kwargs)
except (ModuleNotFoundError, TypeError):
    logging.warning("Could not set sdpa backend priority.")
|
|
|
|
# Set to True when the installed cuDNN / torch combination falls in the version
# window checked below; Conv3d._conv_forward reads this flag.
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
try:
    if comfy.model_management.is_nvidia():
        cudnn_version = torch.backends.cudnn.version()
        if (cudnn_version >= 91002 and cudnn_version < 91500) and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
            #TODO: change upper bound version once it's fixed
            NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
            logging.info("working around nvidia conv3d memory bug.")
except Exception:
    # Best-effort probe: any failure just leaves the workaround disabled.
    # `except Exception` (not a bare `except`) so KeyboardInterrupt and
    # SystemExit still propagate.
    logging.debug("Could not probe cudnn version for the conv3d workaround.", exc_info=True)
|
|
|
|
#TODO: remove once no more references
cast_to = comfy.model_management.cast_to


def cast_to_input(weight, input, non_blocking=False, copy=True):
    """Cast ``weight`` to the dtype and device of ``input``.

    Delegates to ``comfy.model_management.cast_to`` and returns its result.
    """
    target_dtype = input.dtype
    target_device = input.device
    return comfy.model_management.cast_to(
        weight,
        target_dtype,
        target_device,
        non_blocking=non_blocking,
        copy=copy,
    )
|
|
|
|
|
|
def materialize_meta_param(s, param_keys):
    """Replace meta-device parameters of ``s`` with real zero-filled ones.

    For each attribute name in ``param_keys``, an attribute that exists, is not
    None and reports ``is_meta`` is swapped for a ``torch.nn.Parameter`` of
    zeros with the same shape, dtype and ``requires_grad``. Anything else is
    left untouched.
    """
    for name in param_keys:
        current = getattr(s, name, None)
        if current is None:
            continue
        if not getattr(current, "is_meta", False):
            continue
        zeros = torch.zeros(current.shape, dtype=current.dtype)
        replacement = torch.nn.Parameter(zeros, requires_grad=current.requires_grad)
        setattr(s, name, replacement)
|
|
|
|
|
|
# FIXME: add n=1 cache hit fast path
|
|
def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
    """Stage the weight/bias transfer for each vbar-backed module.

    For every module in ``comfy_modules`` this faults its vbar (``s._v``),
    compares the returned signature with ``s._v_signature`` and, when the
    module is not resident, issues the copies of its weight/bias (or of its
    pin, when one exists) towards ``device``. What was staged is recorded in
    ``s._prefetch`` for ``resolve_cast_module_with_vbar`` to consume.

    ``dtype`` and ``bias_dtype`` are accepted but not read in this function.

    Returns the offload stream the copies were issued on; this stays None when
    every module was resident or no stream could be obtained.
    """
    offload_stream = None
    cast_buffer = None
    # Running byte offset into the aimdo cast buffer, shared by all modules
    # handled in this call.
    cast_buffer_offset = 0

    def ensure_offload_stream(module, required_size, check_largest):
        # Lazily fetches the offload stream. The remaining bookkeeping only
        # runs for single-module calls with check_largest set.
        nonlocal offload_stream
        nonlocal cast_buffer

        if offload_stream is None:
            offload_stream = comfy.model_management.get_offload_stream(device)
        if offload_stream is None or not check_largest or len(comfy_modules) != 1:
            return

        current_size = 0 if cast_buffer is None else cast_buffer.size()
        # NOTE(review): when the recorded largest module needs more than the
        # current buffer holds, a new stream is requested and the buffer
        # dropped -- presumably to rotate to a stream with a larger buffer;
        # confirm against get_offload_stream.
        if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
            offload_stream = comfy.model_management.get_offload_stream(device)
            cast_buffer = None
        if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
            comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)

    def get_cast_buffer(buffer_size):
        # Returns a uint8 staging tensor of buffer_size bytes on `device`, or
        # None for a zero-size request.
        nonlocal offload_stream
        nonlocal cast_buffer
        nonlocal cast_buffer_offset

        if buffer_size == 0:
            return None

        if offload_stream is None:
            # No offload stream: fall back to an ordinary allocation.
            return torch.empty((buffer_size,), dtype=torch.uint8, device=device)

        cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
        buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
        cast_buffer_offset += buffer_size
        return buffer

    for s in comfy_modules:
        signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
        resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
        prefetch = {
            "signature": signature,
            "resident": resident,
        }

        if resident:
            # Nothing to transfer; the resolve step reads s._v_weight/_v_bias.
            s._prefetch = prefetch
            continue

        materialize_meta_param(s, ["weight", "bias"])
        # With a signature the vbar itself is the transfer destination.
        xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
        cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
        cast_dest = None
        needs_cast = False

        xfer_source = [ s.weight, s.bias ]

        pin = comfy.pinned_memory.get_pin(s)
        if pin is not None:
            xfer_source = [ pin ]

        for data, geometry in zip([ s.weight, s.bias ], cast_geometry):
            if data is None:
                continue
            if data.dtype != geometry.dtype:
                # Stored dtype differs from the target geometry: the vbar (if
                # any) becomes the destination of the later cast, and the raw
                # transfer goes through a staging buffer instead.
                needs_cast = True
                cast_dest = xfer_dest
                xfer_dest = None
                break

        dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
        ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
        if xfer_dest is None:
            xfer_dest = get_cast_buffer(dest_size)

        def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream, xfer_dest2=None):
            # Copies xfer_source to xfer_dest (and optionally xfer_dest2).
            # Sources flagged is_lowvram_patch are written via their own
            # prepare() method instead of cast_to_gathered.
            if xfer_source is not None:
                if getattr(xfer_source, "is_lowvram_patch", False):
                    if xfer_dest is not None:
                        xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
                        # The prepared buffer becomes the source for the
                        # onward copy to the second destination.
                        xfer_source = [ xfer_dest ]
                        xfer_dest = xfer_dest2
                        xfer_dest2 = None
                    elif xfer_dest2 is not None:
                        xfer_source.prepare(xfer_dest2, stream, copy=True, commit=False)
                        return
                comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream, r2=xfer_dest2)

        def handle_pin(m, pin, source, dest, subset="weights", size=None):
            # With an existing pin, copy pin -> dest. Otherwise copy
            # source -> pin with dest as the second destination, pinning first
            # when the vbar fault returned no signature.
            if pin is not None:
                cast_maybe_lowvram_patch([pin], dest, offload_stream)
                return
            if signature is None:
                comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
                pin = comfy.pinned_memory.get_pin(m, subset=subset)
            cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest)

        handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size)

        for param_key in ("weight", "bias"):
            lowvram_source = getattr(s, param_key + "_lowvram_function", None)
            if lowvram_source is not None:
                ensure_offload_stream(s, cast_buffer_offset, False)
                lowvram_size = lowvram_source.memory_required()
                lowvram_dest = get_cast_buffer(lowvram_size)
                lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True)

                pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches")
                handle_pin(lowvram_source, pin, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size)

        # Record what was staged for resolve_cast_module_with_vbar.
        prefetch["xfer_dest"] = xfer_dest
        prefetch["cast_dest"] = cast_dest
        prefetch["cast_geometry"] = cast_geometry
        prefetch["needs_cast"] = needs_cast
        s._prefetch = prefetch

    return offload_stream
|
|
|
|
|
|
def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
    """Turn the staged transfer in ``s._prefetch`` into usable tensors.

    Reads the ``_prefetch`` dict written by ``cast_modules_with_vbar``. For a
    resident module the cached ``s._v_weight``/``s._v_bias`` are used.
    Otherwise the staged buffer is optionally cast and then interpreted as
    weight and bias. Both tensors then go through ``post_cast``, which applies
    the dtype conversion, any low-vram function and the
    ``weight_function``/``bias_function`` lists.

    Returns ``(weight, bias)``; ``bias`` may be None.
    """

    # NOTE(review): the None default would fail on the subscripts below; the
    # caller in this file always stages _prefetch before calling.
    prefetch = getattr(s, "_prefetch", None)

    if prefetch["resident"]:
        weight = s._v_weight
        bias = s._v_bias
    else:
        xfer_dest = prefetch["xfer_dest"]
        if prefetch["needs_cast"]:
            # Use the destination reserved at staging time, else allocate one.
            cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
            # `post_cast` here is a loop variable; the name is rebound to the
            # nested function defined further down before that is called.
            for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
                                           comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
                if post_cast is not None:
                    post_cast.copy_(pre_cast)
            xfer_dest = cast_dest

        params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
        weight = params[0]
        bias = params[1]
        if prefetch["signature"] is not None:
            # Cache the views and signature so a later fault with a matching
            # signature is treated as resident.
            s._v_weight = weight
            s._v_bias = bias
            s._v_signature = prefetch["signature"]

    def post_cast(s, param_key, x, dtype, resident, update_weight):
        # Applies dtype conversion, the optional low-vram function and the
        # per-parameter function list to x. Returns None when x is None.
        lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
        fns = getattr(s, param_key + "_function", [])

        if x is None:
            return None

        orig = x

        def to_dequant(tensor, dtype):
            tensor = tensor.to(dtype=dtype)
            if isinstance(tensor, QuantizedTensor):
                tensor = tensor.dequantize()
            return tensor

        if orig.dtype != dtype or len(fns) > 0:
            x = to_dequant(x, dtype)
        if not resident and lowvram_fn is not None:
            x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
            x = lowvram_fn(x)
            if (want_requant and len(fns) == 0 or update_weight):
                # Round back to the original storage format, seeded from the
                # module's seed_key.
                seed = comfy.utils.string_to_seed(s.seed_key)
                if isinstance(orig, QuantizedTensor):
                    y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
                else:
                    y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
                if want_requant and len(fns) == 0:
                    x = y
                if update_weight:
                    # Write the result back into the tensor passed in.
                    orig.copy_(y)
        for f in fns:
            x = f(x)
        return x

    update_weight = prefetch["signature"] is not None
    weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
    if bias is not None:
        bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)

    if prefetch["signature"] is not None:
        prefetch["resident"] = True

    return weight, bias
|
|
|
|
|
|
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
    """Return the weight and bias of module ``s`` cast for computation.

    ``dtype``, ``bias_dtype`` and ``device`` default to those of ``input`` when
    it is given; for a QuantizedTensor input, ``input.params.orig_dtype`` is
    used as the dtype.

    Returns ``(weight, bias)``, or ``(weight, bias, offload_stream)`` when
    ``offloadable`` is True. The third element is a 3-tuple meant to be handed
    to ``uncast_bias_weight``.
    """
    # NOTE: offloadable=False is a legacy mode and if you are a custom node author reading this please pass
    # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
    # will add async-offload support to your cast and improve performance.
    if input is not None:
        if dtype is None:
            if isinstance(input, QuantizedTensor):
                dtype = input.params.orig_dtype
            else:
                dtype = input.dtype
        if bias_dtype is None:
            bias_dtype = dtype
        if device is None:
            device = input.device

    def format_return(result, offloadable):
        # Drop the offload tuple for legacy (non-offloadable) callers.
        weight, bias, offload_stream = result
        return (weight, bias, offload_stream) if offloadable else (weight, bias)

    non_blocking = comfy.model_management.device_supports_non_blocking(device)

    # Modules carrying a `_v` attribute take the vbar path.
    if hasattr(s, "_v"):

        #vbar doesn't support CPU weights, but some custom nodes have weird paths
        #that might switch the layer to the CPU and expect it to work. We have to take
        #a clone conservatively as we are mmapped and some SFT files are packed misaligned
        #If you are a custom node author reading this, please move your layer to the GPU
        #or declare your ModelPatcher as CPU in the first place.
        if comfy.model_management.is_device_cpu(device):
            materialize_meta_param(s, ["weight", "bias"])
            weight = s.weight.to(dtype=dtype, copy=True)
            if isinstance(weight, QuantizedTensor):
                weight = weight.dequantize()
            bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
            return format_return((weight, bias, (None, None, None)), offloadable)

        # If the module already has a _prefetch entry, only resolve it here and
        # leave it in place; otherwise stage, sync, resolve and clean up.
        prefetched = hasattr(s, "_prefetch")
        offload_stream = None
        offload_device = None
        if not prefetched:
            offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
            comfy.model_management.sync_stream(device, offload_stream)

        weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)

        if not prefetched:
            if getattr(s, "_prefetch")["signature"] is not None:
                # A device in the second slot makes uncast_bias_weight call
                # vbar_unpin on this module.
                offload_device = device
            for param_key in ("weight", "bias"):
                lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
                if lowvram_fn is not None:
                    lowvram_fn.clear_prepared()
            delattr(s, "_prefetch")
        return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)

    # Non-vbar path: an offload stream is only requested when the caller
    # opted in and a parameter lives on a different device.
    if offloadable and (device != s.weight.device or
                        (s.bias is not None and device != s.bias.device)):
        offload_stream = comfy.model_management.get_offload_stream(device)
    else:
        offload_stream = None

    bias = None
    weight = None

    if offload_stream is not None and not args.cuda_malloc:
        cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ])
        cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
        #The streams can be uneven in buffer capability and reject us. Retry to get the other stream
        if cast_buffer is None:
            offload_stream = comfy.model_management.get_offload_stream(device)
            cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
        # weight/bias are passed as `r=` to the cast_to calls below.
        params = comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], cast_buffer)
        weight = params[0]
        bias = params[1]

    weight_has_function = len(s.weight_function) > 0
    bias_has_function = len(s.bias_function) > 0

    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight)

    if s.bias is not None:
        bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias)

    comfy.model_management.sync_stream(device, offload_stream)

    # Keep the tensors as they were right after the device copy; they are
    # returned in the offload tuple for uncast_bias_weight.
    bias_a = bias
    weight_a = weight

    if s.bias is not None:
        bias = bias.to(dtype=bias_dtype)
        for f in s.bias_function:
            bias = f(bias)

    if weight_has_function or weight.dtype != dtype:
        weight = weight.to(dtype=dtype)
        if isinstance(weight, QuantizedTensor):
            weight = weight.dequantize()
        for f in s.weight_function:
            weight = f(weight)

    return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
|
|
|
|
|
|
def uncast_bias_weight(s, weight, bias, offload_stream):
    """Release what ``cast_bias_weight(..., offloadable=True)`` handed out.

    ``offload_stream`` is the 3-tuple ``(stream, weight_a, bias_a)`` from
    ``cast_bias_weight`` (or None, in which case nothing happens). When the
    second slot holds a non-tensor it is treated as a device and the module's
    vbar is unpinned. If a stream is present it is made to wait on the current
    stream of the resolved device.
    """
    if offload_stream is None:
        return

    stream, weight_a, bias_a = offload_stream
    device = None

    #FIXME: This is really bad RTTI
    second_slot_is_device = weight_a is not None and not isinstance(weight_a, torch.Tensor)
    if second_slot_is_device:
        comfy_aimdo.model_vbar.vbar_unpin(s._v)
        device = weight_a

    if stream is None:
        return

    if device is None:
        # Take the device from whichever cast tensor exists.
        if weight_a is not None:
            device = weight_a.device
        elif bias_a is not None:
            device = bias_a.device
        else:
            return

    stream.wait_stream(comfy.model_management.current_stream(device))
|
|
|
|
|
|
class CastWeightBiasOp:
    """Mixin giving op classes their cast-related class-level defaults."""

    # Lists of callables applied to the cast weight / bias (see cast_bias_weight).
    weight_function = list()
    bias_function = list()
    # When True, forward() takes the forward_comfy_cast_weights path.
    comfy_cast_weights = False
|
|
|
|
class disable_weight_init:
|
|
@staticmethod
|
|
def _zero_init_parameter(module, name):
|
|
param = getattr(module, name)
|
|
device = None if getattr(param, "is_meta", False) else param.device
|
|
setattr(module, name, torch.nn.Parameter(torch.zeros(param.shape, device=device, dtype=param.dtype), requires_grad=False))
|
|
|
|
@staticmethod
|
|
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
|
|
missing_keys, unexpected_keys, weight_shape,
|
|
bias_shape=None):
|
|
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
|
prefix_len = len(prefix)
|
|
for k, v in state_dict.items():
|
|
key = k[prefix_len:]
|
|
if key == "weight":
|
|
if not assign_to_params_buffers:
|
|
v = v.clone()
|
|
module.weight = torch.nn.Parameter(v, requires_grad=False)
|
|
elif bias_shape is not None and key == "bias" and v is not None:
|
|
if not assign_to_params_buffers:
|
|
v = v.clone()
|
|
module.bias = torch.nn.Parameter(v, requires_grad=False)
|
|
else:
|
|
unexpected_keys.append(k)
|
|
|
|
if module.weight is None:
|
|
module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
|
|
missing_keys.append(prefix + "weight")
|
|
|
|
if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
|
|
module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
|
|
missing_keys.append(prefix + "bias")
|
|
|
|
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
|
|
|
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
|
# don't trust subclasses that BYO state dict loader to call us.
|
|
if (not comfy.model_management.WINDOWS
|
|
or not comfy.memory_management.aimdo_enabled
|
|
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
|
|
super().__init__(in_features, out_features, bias, device, dtype)
|
|
return
|
|
|
|
# Issue is with `torch.empty` still reserving the full memory for the layer.
|
|
# Windows doesn't over-commit memory so without this, We are momentarily commit
|
|
# charged for the weight even though we might zero-copy it when we load the
|
|
# state dict. If the commit charge exceeds the ceiling we can destabilize the
|
|
# system.
|
|
torch.nn.Module.__init__(self)
|
|
self.in_features = in_features
|
|
self.out_features = out_features
|
|
self.weight = None
|
|
self.bias = None
|
|
self.comfy_need_lazy_init_bias=bias
|
|
self.weight_comfy_model_dtype = dtype
|
|
self.bias_comfy_model_dtype = dtype
|
|
|
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
|
strict, missing_keys, unexpected_keys, error_msgs):
|
|
|
|
if (not comfy.model_management.WINDOWS
|
|
or not comfy.memory_management.aimdo_enabled
|
|
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
|
|
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
|
missing_keys, unexpected_keys, error_msgs)
|
|
disable_weight_init._lazy_load_from_state_dict(
|
|
self,
|
|
state_dict,
|
|
prefix,
|
|
local_metadata,
|
|
missing_keys,
|
|
unexpected_keys,
|
|
weight_shape=(self.in_features, self.out_features),
|
|
bias_shape=(self.out_features,),
|
|
)
|
|
|
|
|
|
def reset_parameters(self):
|
|
return None
|
|
|
|
def forward_comfy_cast_weights(self, input):
|
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
|
x = torch.nn.functional.linear(input, weight, bias)
|
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
|
return x
|
|
|
|
def forward(self, *args, **kwargs):
|
|
run_every_op()
|
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
|
else:
|
|
return super().forward(*args, **kwargs)
|
|
|
|
class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
|
|
def reset_parameters(self):
|
|
return None
|
|
|
|
def forward_comfy_cast_weights(self, input):
|
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
|
x = self._conv_forward(input, weight, bias)
|
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
|
return x
|
|
|
|
def forward(self, *args, **kwargs):
|
|
run_every_op()
|
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
|
else:
|
|
return super().forward(*args, **kwargs)
|
|
|
|
class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
|
|
def reset_parameters(self):
|
|
return None
|
|
|
|
def forward_comfy_cast_weights(self, input):
|
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
|
x = self._conv_forward(input, weight, bias)
|
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
|
return x
|
|
|
|
def forward(self, *args, **kwargs):
|
|
run_every_op()
|
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
|
else:
|
|
return super().forward(*args, **kwargs)
|
|
|
|
class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
|
|
def reset_parameters(self):
|
|
return None
|
|
|
|
def _conv_forward(self, input, weight, bias, autopad=None, *args, **kwargs):
|
|
if autopad == "causal_zero":
|
|
weight = weight[:, :, -input.shape[2]:, :, :]
|
|
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
|
|
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
|
|
if bias is not None:
|
|
out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
|
|
return out
|
|
else:
|
|
return super()._conv_forward(input, weight, bias, *args, **kwargs)
|
|
|
|
def forward_comfy_cast_weights(self, input, autopad=None):
|
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
|
x = self._conv_forward(input, weight, bias, autopad=autopad)
|
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
|
return x
|
|
|
|
def forward(self, *args, **kwargs):
|
|
run_every_op()
|
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or "autopad" in kwargs:
|
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
|
else:
|
|
return super().forward(*args, **kwargs)
|
|
|
|
class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
|
|
def reset_parameters(self):
|
|
return None
|
|
|
|
def forward_comfy_cast_weights(self, input):
|
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
|
x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
|
return x
|
|
|
|
def forward(self, *args, **kwargs):
|
|
run_every_op()
|
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
|
else:
|
|
return super().forward(*args, **kwargs)
|
|
|
|
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
|
|
def reset_parameters(self):
|
|
return None
|
|
|
|
def forward_comfy_cast_weights(self, input):
|
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
|
running_mean = self.running_mean.to(device=input.device, dtype=weight.dtype) if self.running_mean is not None else None
|
|
running_var = self.running_var.to(device=input.device, dtype=weight.dtype) if self.running_var is not None else None
|
|
x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, self.training, self.momentum, self.eps)
|
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
|
return x
|
|
|
|
def forward(self, *args, **kwargs):
|
|
run_every_op()
|
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
|
else:
|
|
return super().forward(*args, **kwargs)
|
|
|
|
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
|
|
def reset_parameters(self):
|
|
return None
|
|
|
|
def forward_comfy_cast_weights(self, input):
|
|
if self.weight is not None:
|
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
|
else:
|
|
weight = None
|
|
bias = None
|
|
offload_stream = None
|
|
x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
|
return x
|
|
|
|
def forward(self, *args, **kwargs):
|
|
run_every_op()
|
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
|
else:
|
|
return super().forward(*args, **kwargs)
|
|
|
|
class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
|
|
def reset_parameters(self):
|
|
self.bias = None
|
|
return None
|
|
|
|
def forward_comfy_cast_weights(self, input):
|
|
if self.weight is not None:
|
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
|
else:
|
|
weight = None
|
|
bias = None
|
|
offload_stream = None
|
|
x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
|
return x
|
|
|
|
def forward(self, *args, **kwargs):
|
|
run_every_op()
|
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
|
else:
|
|
return super().forward(*args, **kwargs)
|
|
|
|
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
|
|
def reset_parameters(self):
|
|
return None
|
|
|
|
def forward_comfy_cast_weights(self, input, output_size=None):
|
|
num_spatial_dims = 2
|
|
output_padding = self._output_padding(
|
|
input, output_size, self.stride, self.padding, self.kernel_size,
|
|
num_spatial_dims, self.dilation)
|
|
|
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
|
x = torch.nn.functional.conv_transpose2d(
|
|
input, weight, bias, self.stride, self.padding,
|
|
output_padding, self.groups, self.dilation)
|
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
|
return x
|
|
|
|
def forward(self, *args, **kwargs):
|
|
run_every_op()
|
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
|
else:
|
|
return super().forward(*args, **kwargs)
|
|
|
|
class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
|
|
def reset_parameters(self):
|
|
return None
|
|
|
|
def forward_comfy_cast_weights(self, input, output_size=None):
|
|
num_spatial_dims = 1
|
|
output_padding = self._output_padding(
|
|
input, output_size, self.stride, self.padding, self.kernel_size,
|
|
num_spatial_dims, self.dilation)
|
|
|
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
|
x = torch.nn.functional.conv_transpose1d(
|
|
input, weight, bias, self.stride, self.padding,
|
|
output_padding, self.groups, self.dilation)
|
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
|
return x
|
|
|
|
def forward(self, *args, **kwargs):
|
|
run_every_op()
|
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
|
else:
|
|
return super().forward(*args, **kwargs)
|
|
|
|
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
|
|
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
|
|
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
|
|
_freeze=False, device=None, dtype=None):
|
|
# don't trust subclasses that BYO state dict loader to call us.
|
|
if (not comfy.model_management.WINDOWS
|
|
or not comfy.memory_management.aimdo_enabled
|
|
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
|
|
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
|
|
norm_type, scale_grad_by_freq, sparse, _weight,
|
|
_freeze, device, dtype)
|
|
return
|
|
|
|
torch.nn.Module.__init__(self)
|
|
self.num_embeddings = num_embeddings
|
|
self.embedding_dim = embedding_dim
|
|
self.padding_idx = padding_idx
|
|
self.max_norm = max_norm
|
|
self.norm_type = norm_type
|
|
self.scale_grad_by_freq = scale_grad_by_freq
|
|
self.sparse = sparse
|
|
# Keep shape/dtype visible for module introspection without reserving storage.
|
|
embedding_dtype = dtype if dtype is not None else torch.get_default_dtype()
|
|
self.weight = torch.nn.Parameter(
|
|
torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype),
|
|
requires_grad=False,
|
|
)
|
|
self.bias = None
|
|
self.weight_comfy_model_dtype = dtype
|
|
|
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
|
strict, missing_keys, unexpected_keys, error_msgs):
|
|
|
|
if (not comfy.model_management.WINDOWS
|
|
or not comfy.memory_management.aimdo_enabled
|
|
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
|
|
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
|
missing_keys, unexpected_keys, error_msgs)
|
|
disable_weight_init._lazy_load_from_state_dict(
|
|
self,
|
|
state_dict,
|
|
prefix,
|
|
local_metadata,
|
|
missing_keys,
|
|
unexpected_keys,
|
|
weight_shape=(self.num_embeddings, self.embedding_dim),
|
|
)
|
|
|
|
def reset_parameters(self):
|
|
self.bias = None
|
|
return None
|
|
|
|
def forward_comfy_cast_weights(self, input, out_dtype=None):
|
|
output_dtype = out_dtype
|
|
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
|
|
out_dtype = None
|
|
weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
|
|
x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
|
return x
|
|
|
|
|
|
def forward(self, *args, **kwargs):
    """Dispatch to the cast-weights path or to the stock torch forward.

    The cast path is taken when ``comfy_cast_weights`` is set or when any
    weight/bias function is attached. Otherwise ``out_dtype`` is stripped
    from ``kwargs`` before delegating to the parent ``forward``.
    """
    run_every_op()
    needs_cast = (
        self.comfy_cast_weights
        or len(self.weight_function) > 0
        or len(self.bias_function) > 0
    )
    if needs_cast:
        return self.forward_comfy_cast_weights(*args, **kwargs)

    kwargs.pop("out_dtype", None)
    return super().forward(*args, **kwargs)
|
|
|
|
@classmethod
def conv_nd(s, dims, *args, **kwargs):
    """Instantiate this namespace's 2D or 3D convolution.

    Args:
        dims: ``2`` selects ``s.Conv2d``, ``3`` selects ``s.Conv3d``.
        *args, **kwargs: Forwarded unchanged to the chosen constructor.

    Raises:
        ValueError: If ``dims`` is neither 2 nor 3.
    """
    if dims not in (2, 3):
        raise ValueError(f"unsupported dimensions: {dims}")
    conv_cls = s.Conv2d if dims == 2 else s.Conv3d
    return conv_cls(*args, **kwargs)
|
|
|
|
|
|
class manual_cast(disable_weight_init):
    """Ops namespace whose layers all set ``comfy_cast_weights = True``.

    Each nested class subclasses the matching ``disable_weight_init`` layer
    and only flips the class-level flag; layers whose ``forward`` checks
    ``self.comfy_cast_weights`` then take their ``forward_comfy_cast_weights``
    path.
    """

    class Linear(disable_weight_init.Linear):
        comfy_cast_weights = True

    class Conv1d(disable_weight_init.Conv1d):
        comfy_cast_weights = True

    class Conv2d(disable_weight_init.Conv2d):
        comfy_cast_weights = True

    class Conv3d(disable_weight_init.Conv3d):
        comfy_cast_weights = True

    class BatchNorm2d(disable_weight_init.BatchNorm2d):
        comfy_cast_weights = True

    class GroupNorm(disable_weight_init.GroupNorm):
        comfy_cast_weights = True

    class LayerNorm(disable_weight_init.LayerNorm):
        comfy_cast_weights = True

    class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
        comfy_cast_weights = True

    class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
        comfy_cast_weights = True

    class RMSNorm(disable_weight_init.RMSNorm):
        comfy_cast_weights = True

    class Embedding(disable_weight_init.Embedding):
        comfy_cast_weights = True
|
|
|
|
|
|
def fp8_linear(self, input):
    """
    Legacy FP8 linear function for backward compatibility.
    Uses QuantizedTensor subclass for dispatch.

    Args:
        self: The linear module whose ``weight`` (and optional bias) is used.
        input: Activation tensor. Only 2D and 3D inputs are supported; a 3D
            input is flattened to 2D for the matmul and restored afterwards.
            NOTE: the clamp below writes into ``input`` in place.

    Returns:
        The linear output, or ``None`` when this fast path does not apply
        (weight is not ``float8_e4m3fn`` or the input is not 2D/3D) so the
        caller can fall back to the regular path.
    """
    dtype = self.weight.dtype
    if dtype not in [torch.float8_e4m3fn]:
        return None

    input_dtype = input.dtype
    input_shape = input.shape
    tensor_3d = input.ndim == 3

    if tensor_3d:
        input = input.reshape(-1, input_shape[2])

    if input.ndim != 2:
        return None

    lora_compute_dtype = comfy.model_management.lora_compute_dtype(input.device)
    w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True, compute_dtype=lora_compute_dtype, want_requant=True)
    # Callers (fp8_ops.Linear) catch exceptions from this function and then
    # cast the weights again, so the cast must always be released, even when
    # quantisation or the matmul raises.
    try:
        # Unit scales: the values are used as-is, without rescaling.
        scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
        scale_input = torch.ones((), device=input.device, dtype=torch.float32)

        # Clamp to +-448 (the finite range of float8_e4m3fn) before the cast.
        input = torch.clamp(input, min=-448, max=448, out=input)
        input_fp8 = input.to(dtype).contiguous()
        layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
        quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input)

        # Wrap weight in QuantizedTensor - this enables unified dispatch
        # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
        layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
        quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
    finally:
        uncast_bias_weight(self, w, bias, offload_stream)

    if tensor_3d:
        o = o.reshape((input_shape[0], input_shape[1], w.shape[0]))

    return o
|
|
|
|
class fp8_ops(manual_cast):
    """``manual_cast`` namespace whose ``Linear`` first tries ``fp8_linear``."""

    class Linear(manual_cast.Linear):
        def reset_parameters(self):
            """Clear both scale attributes instead of initialising weights."""
            self.scale_weight = None
            self.scale_input = None
            return None

        def forward_comfy_cast_weights(self, input):
            """Run the FP8 fast path when possible, else a regular cast + linear.

            The fast path is attempted only when no weight/bias functions
            are attached. If it raises, the exception is logged and the
            regular path runs; if it returns ``None`` the regular path runs
            as well.
            """
            unpatched = len(self.weight_function) == 0 and len(self.bias_function) == 0
            if unpatched:
                try:
                    fast_out = fp8_linear(self, input)
                except Exception as e:
                    logging.info("Exception during fp8 op: {}".format(e))
                else:
                    if fast_out is not None:
                        return fast_out

            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
            result = torch.nn.functional.linear(input, weight, bias)
            uncast_bias_weight(self, weight, bias, offload_stream)
            return result
|
|
|
|
# Probe for the optional ``cublas_ops`` package; the flag stays False when
# the import is not possible.
CUBLAS_IS_AVAILABLE = False
try:
    from cublas_ops import CublasLinear, cublas_half_matmul
except ImportError:
    pass
else:
    CUBLAS_IS_AVAILABLE = True
|
|
|
|
if CUBLAS_IS_AVAILABLE:
    class cublas_ops(manual_cast):
        """``manual_cast`` namespace whose ``Linear`` uses ``cublas_half_matmul``."""

        class Linear(CublasLinear, manual_cast.Linear):
            def reset_parameters(self):
                """Do nothing: weight initialisation is skipped."""
                return None

            def forward_comfy_cast_weights(self, input):
                """Cast weight/bias for ``input`` and run the cuBLAS half matmul."""
                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
                result = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias)
                uncast_bias_weight(self, weight, bias, offload_stream)
                return result

            def forward(self, *args, **kwargs):
                """Use the cast path when casting or patch functions are active."""
                run_every_op()
                needs_cast = (
                    self.comfy_cast_weights
                    or len(self.weight_function) > 0
                    or len(self.bias_function) > 0
                )
                if needs_cast:
                    return self.forward_comfy_cast_weights(*args, **kwargs)
                return super().forward(*args, **kwargs)
|
|
|
|
# ==============================================================================
|
|
# Mixed Precision Operations
|
|
# ==============================================================================
|
|
from .quant_ops import (
|
|
QuantizedTensor,
|
|
QUANT_ALGOS,
|
|
TensorCoreFP8Layout,
|
|
get_layout_class,
|
|
)
|
|
|
|
|
|
class QuantLinearFunc(torch.autograd.Function):
    """Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.

    When training_fp8_bwd is enabled:
      - Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
      - Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
      - Cached input is FP8 (half the memory of bf16)

    When training_fp8_bwd is disabled:
      - Forward: quantize input per layout, use quantized matmul
      - Backward: dequantize weight to compute_dtype, use standard matmul
    """

    @staticmethod
    def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
        """Compute ``linear(input, weight, bias)`` with an optionally quantized input.

        Args:
            ctx: Autograd context; receives the state ``backward`` needs.
            input_float: Activation tensor; all leading dims are flattened
                into one for the matmul.
            weight: Weight tensor (plain or ``QuantizedTensor``).
            bias: Optional bias tensor.
            layout_type: Layout name used to quantize the input, or ``None``
                to leave the input unquantized.
            input_scale: Scale passed to ``QuantizedTensor.from_float``.
            compute_dtype: Dtype that ``backward`` casts its operands to.

        Returns:
            The output with the leading dims of ``input_float`` restored.
        """
        input_shape = input_float.shape
        inp = input_float.detach().flatten(0, -2)  # zero-cost view to 2D

        # Quantize input for forward (same layout as weight)
        if layout_type is not None:
            q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
        else:
            q_input = inp

        # Detach only when needed so the matmul is not recorded by autograd.
        w = weight.detach() if weight.requires_grad else weight
        b = bias.detach() if bias is not None and bias.requires_grad else bias

        output = torch.nn.functional.linear(q_input, w, b)

        # Unflatten output to match original input shape
        if len(input_shape) > 2:
            output = output.unflatten(0, input_shape[:-1])

        # Save for backward
        ctx.input_shape = input_shape
        ctx.has_bias = bias is not None
        ctx.compute_dtype = compute_dtype
        ctx.weight_requires_grad = weight.requires_grad
        # The flag is captured at forward time so backward follows the mode
        # that was active when the tensors were saved.
        ctx.fp8_bwd = comfy.model_management.training_fp8_bwd

        if ctx.fp8_bwd:
            # Cache FP8 quantized input — half the memory of bf16
            if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
                ctx.q_input = q_input  # already FP8, reuse
            else:
                # NVFP4 or other layout — quantize input to FP8 for backward
                ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
            ctx.save_for_backward(weight)
        else:
            ctx.q_input = None
            ctx.save_for_backward(input_float, weight)

        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        """Return gradients for ``(input_float, weight, bias)``.

        The three trailing ``None`` values correspond to the non-tensor
        forward arguments ``layout_type``, ``input_scale`` and
        ``compute_dtype``. ``grad_weight`` is ``None`` unless the weight
        required grad in forward; ``grad_bias`` is ``None`` when no bias
        was given.
        """
        compute_dtype = ctx.compute_dtype
        grad_2d = grad_output.flatten(0, -2).to(compute_dtype)

        # Value casting — only difference between fp8 and non-fp8 paths
        if ctx.fp8_bwd:
            weight, = ctx.saved_tensors
            # Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
            grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
            if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
                weight_mm = weight
            elif isinstance(weight, QuantizedTensor):
                # Non-FP8 quantized weight: dequantize, then requantize to FP8.
                weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
            else:
                weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
            input_mm = ctx.q_input
        else:
            input_float, weight = ctx.saved_tensors
            # Standard tensors → torch.mm does regular matmul
            grad_mm = grad_2d
            if isinstance(weight, QuantizedTensor):
                weight_mm = weight.dequantize().to(compute_dtype)
            else:
                weight_mm = weight.to(compute_dtype)
            # The flattened input is only needed for the weight gradient.
            input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None

        # Computation — same for both paths, dispatch handles the rest
        grad_input = torch.mm(grad_mm, weight_mm)
        if len(ctx.input_shape) > 2:
            grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])

        grad_weight = None
        if ctx.weight_requires_grad:
            grad_weight = torch.mm(grad_mm.t(), input_mm)

        grad_bias = None
        if ctx.has_bias:
            # Uses the unquantized gradient, summed over the flattened rows.
            grad_bias = grad_2d.sum(dim=0)

        return grad_input, grad_weight, grad_bias, None, None, None
|
|
|
|
# Quantized-weight module helpers
|
|
|
|
def _quantized_apply(module, fn, recurse=True):
|
|
"""Re-wrap Parameters after fn so .to()/.cuda() propagate through QuantizedTensor weights."""
|
|
if recurse:
|
|
for child in module.children():
|
|
child._apply(fn)
|
|
for key, param in module._parameters.items():
|
|
if param is None:
|
|
continue
|
|
p = fn(param)
|
|
if (not torch.is_inference_mode_enabled()) and p.is_inference():
|
|
p = p.clone()
|
|
module.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
|
for key, buf in module._buffers.items():
|
|
if buf is not None:
|
|
module._buffers[key] = fn(buf)
|
|
return module
|
|
|
|
|
|
def _load_quantized_module(module, super_load, state_dict, prefix, local_metadata, strict,
                           missing_keys, unexpected_keys, error_msgs, load_extra_params=False):
    """Shared _load_from_state_dict body for quantized-weight modules.

    Pops weight (+ scales, +/- extras), populates module.weight as a Parameter
    or Parameter-wrapped QuantizedTensor, then calls super_load and strips
    consumed keys from missing_keys. Reads compute_dtype from factory_kwargs
    and disabled formats from module._disabled_formats.

    Args:
        module: Module being loaded; must provide ``factory_kwargs``,
            ``_disabled_formats``, ``_full_precision_mm`` and ``_orig_shape``.
        super_load: The parent ``_load_from_state_dict`` to call for the
            remaining keys.
        state_dict, prefix, local_metadata, strict, missing_keys,
        unexpected_keys, error_msgs: Forwarded to ``super_load``;
            ``state_dict`` and ``missing_keys`` are mutated in place.
        load_extra_params: When true, every other entry of the format's
            ``qconfig["parameters"]`` present in ``state_dict`` is registered
            as a parameter on ``module``.

    Raises:
        ValueError: For a ``comfy_quant`` entry without a format, an
            unsupported format, or missing MXFP8/NVFP4 scales.
    """
    device = module.factory_kwargs["device"]
    compute_dtype = module.factory_kwargs["dtype"]
    disabled_formats = module._disabled_formats
    layer_name = prefix.rstrip('.')

    weight = state_dict.pop(f"{prefix}weight", None)
    if weight is None:
        # Nothing to load: leave the module without a weight and skip
        # super_load entirely.
        logging.warning(f"Missing weight for layer {layer_name}")
        module.weight = None
        return
    manually_loaded_keys = [f"{prefix}weight"]

    def pop_scale(name, dtype=None):
        """Pop ``prefix + name``, move it to ``device`` and optionally reinterpret its dtype."""
        key = f"{prefix}{name}"
        v = state_dict.pop(key, None)
        if v is not None:
            v = v.to(device=device)
            if dtype is not None:
                # ``view`` reinterprets the stored bytes; it does not convert values.
                v = v.view(dtype=dtype)
            manually_loaded_keys.append(key)
        return v

    # ``comfy_quant`` is a tensor holding the bytes of a JSON document.
    layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
    if layer_conf is not None:
        layer_conf = json.loads(layer_conf.numpy().tobytes())

    if layer_conf is None:
        # No quantization marker: load as a plain compute_dtype parameter.
        module.weight = torch.nn.Parameter(weight.to(device=device, dtype=compute_dtype), requires_grad=False)
    else:
        module.quant_format = layer_conf.get("format", None)
        module._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
        # A flag already set on the module is kept; otherwise the checkpoint decides.
        if not module._full_precision_mm:
            module._full_precision_mm = module._full_precision_mm_config
        if module.quant_format in disabled_formats:
            module._full_precision_mm = True
        if module.quant_format is None:
            raise ValueError(f"Unknown quantization format for layer {layer_name}")

        qconfig = QUANT_ALGOS[module.quant_format]
        module.layout_type = qconfig["comfy_tensor_layout"]
        layout_cls = get_layout_class(module.layout_type)

        # Per-format scales; fp8 dtype views handle both legacy uint8-on-disk and native fp8.
        if module.quant_format in ("float8_e4m3fn", "float8_e5m2"):
            scales = {"scale": pop_scale("weight_scale")}
        elif module.quant_format == "mxfp8":
            bs = pop_scale("weight_scale", torch.float8_e8m0fnu)
            if bs is None:
                raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
            scales = {"scale": bs}
        elif module.quant_format == "nvfp4":
            ts = pop_scale("weight_scale_2")
            bs = pop_scale("weight_scale", torch.float8_e4m3fn)
            if ts is None or bs is None:
                raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
            scales = {"scale": ts, "block_scale": bs}
        else:
            raise ValueError(f"Unsupported quantization format: {module.quant_format}")

        params = layout_cls.Params(**scales, orig_dtype=compute_dtype, orig_shape=module._orig_shape)
        module.weight = torch.nn.Parameter(
            QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), module.layout_type, params),
            requires_grad=False,
        )

        if load_extra_params:
            for param_name in qconfig["parameters"]:
                # The weight scales were already consumed by pop_scale above.
                if param_name in {"weight_scale", "weight_scale_2"}:
                    continue
                param_key = f"{prefix}{param_name}"
                _v = state_dict.pop(param_key, None)
                if _v is None:
                    continue
                module.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
                manually_loaded_keys.append(param_key)

    super_load(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
    # Keys handled here were popped before super_load ran, so drop them from
    # missing_keys if the parent loader reported them.
    for key in manually_loaded_keys:
        if key in missing_keys:
            missing_keys.remove(key)
|
|
|
|
|
|
def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extra_quant_params=()):
|
|
"""Shared state_dict body. extra_quant_conf merges into the comfy_quant JSON;
|
|
extra_quant_params names attributes written as additional top-level keys."""
|
|
if not hasattr(module, 'weight'):
|
|
logging.warning(f"Warning: state dict on uninitialized op {prefix}")
|
|
return sd
|
|
bias = getattr(module, 'bias', None)
|
|
if bias is not None:
|
|
sd[f"{prefix}bias"] = bias
|
|
if module.weight is None:
|
|
return sd
|
|
if isinstance(module.weight, QuantizedTensor):
|
|
sd.update(module.weight.state_dict(f"{prefix}weight"))
|
|
quant_conf = {"format": module.quant_format}
|
|
if getattr(module, '_full_precision_mm_config', False):
|
|
quant_conf["full_precision_matrix_mult"] = True
|
|
if extra_quant_conf:
|
|
quant_conf.update(extra_quant_conf)
|
|
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
|
|
for name in extra_quant_params:
|
|
value = getattr(module, name, None)
|
|
if value is not None:
|
|
sd[f"{prefix}{name}"] = value
|
|
else:
|
|
sd[f"{prefix}weight"] = module.weight
|
|
return sd
|
|
|
|
|
|
def mixed_precision_ops(quant_config=None, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=None):
    """Build an ops namespace whose layers can hold per-layer quantized weights.

    Args:
        quant_config: Stored on the returned class as ``_quant_config``;
            defaults to an empty dict.
        compute_dtype: Dtype recorded in each layer's ``factory_kwargs`` and
            used as ``orig_dtype`` for quantized weights.
        full_precision_mm: Initial value of each layer's
            ``_full_precision_mm`` flag; when set, layers skip the quantized
            matmul path.
        disabled: Collection of quant format names; a layer whose checkpoint
            format is in it has ``_full_precision_mm`` forced on at load
            time. Defaults to an empty list.

    Returns:
        The freshly created ``MixedPrecisionOps`` class.
    """
    # None sentinels instead of mutable default arguments: the values are
    # stored on the returned classes, so a shared default object would be
    # shared by every namespace ever created.
    if quant_config is None:
        quant_config = {}
    if disabled is None:
        disabled = []

    class MixedPrecisionOps(manual_cast):
        _quant_config = quant_config
        _compute_dtype = compute_dtype
        _full_precision_mm = full_precision_mm
        _disabled = disabled

        class Linear(torch.nn.Module, CastWeightBiasOp):
            """Linear layer whose weight is either plain or a QuantizedTensor."""

            _disabled_formats = disabled

            def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
                """Create the layer; ``dtype`` is ignored in favour of the namespace compute dtype."""
                super().__init__()

                self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}

                self.in_features = in_features
                self.out_features = out_features
                self._orig_shape = (out_features, in_features)
                if bias:
                    self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
                else:
                    self.register_parameter("bias", None)

                self.tensor_class = None
                self._full_precision_mm = MixedPrecisionOps._full_precision_mm
                self._full_precision_mm_config = False

            def reset_parameters(self):
                """Do nothing: weights come from the state dict."""
                return None

            def _load_from_state_dict(self, *args):
                """Load via the shared quantized loader, including extra quant params."""
                _load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=True)

            def state_dict(self, *args, destination=None, prefix="", **kwargs):
                """Serialise weight, bias, quant config and ``input_scale`` (if set)."""
                sd = destination if destination is not None else {}
                return _quantized_weight_state_dict(self, sd, prefix, extra_quant_params=("input_scale",))

            def _forward(self, input, weight, bias):
                """Plain ``F.linear``; a quantized operand is handled by tensor dispatch."""
                return torch.nn.functional.linear(input, weight, bias)

            def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
                """Cast weight/bias for ``input``, run the linear, release the cast."""
                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
                x = self._forward(input, weight, bias)
                uncast_bias_weight(self, weight, bias, offload_stream)
                return x

            def forward(self, input, *args, **kwargs):
                """Run the layer, quantizing the input when the fast path applies.

                The quantized path is used only when the layer has a layout,
                the input is not already quantized, full-precision matmul is
                off, weight casting is not forced, and no weight/bias
                functions are attached. Inputs that require grad go through
                ``QuantLinearFunc``.
                """
                run_every_op()

                input_shape = input.shape
                reshaped_3d = False
                #If cast needs to apply lora, it should be done in the compute dtype
                compute_dtype = input.dtype

                _use_quantized = (
                    getattr(self, 'layout_type', None) is not None and
                    not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
                    not getattr(self, 'comfy_force_cast_weights', False) and
                    len(self.weight_function) == 0 and len(self.bias_function) == 0
                )

                # Training path: quantized forward with compute_dtype backward via autograd function
                if (input.requires_grad and _use_quantized):

                    weight, bias, offload_stream = cast_bias_weight(
                        self,
                        input,
                        offloadable=True,
                        compute_dtype=compute_dtype,
                        want_requant=True
                    )

                    scale = getattr(self, 'input_scale', None)
                    if scale is not None:
                        scale = comfy.model_management.cast_to_device(scale, input.device, None)

                    output = QuantLinearFunc.apply(
                        input, weight, bias, self.layout_type, scale, compute_dtype
                    )

                    uncast_bias_weight(self, weight, bias, offload_stream)
                    return output

                # Inference path (unchanged)
                if _use_quantized:

                    # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
                    input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input

                    # Fall back to non-quantized for non-2D tensors
                    if input_reshaped.ndim == 2:
                        reshaped_3d = input.ndim == 3
                        # dtype is now implicit in the layout class
                        scale = getattr(self, 'input_scale', None)
                        if scale is not None:
                            scale = comfy.model_management.cast_to_device(scale, input.device, None)
                        input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)

                output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))

                # Reshape output back to 3D if input was 3D
                if reshaped_3d:
                    output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))

                return output

            def convert_weight(self, weight, inplace=False, **kwargs):
                """Return ``weight`` dequantized when it is a QuantizedTensor, else unchanged."""
                if isinstance(weight, QuantizedTensor):
                    return weight.dequantize()
                else:
                    return weight

            def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
                """Requantize (or cast) ``weight`` and install or return it.

                With ``return_weight`` the converted tensor is returned and
                the module is left untouched; otherwise it replaces
                ``self.weight``. ``inplace_update`` must be ``False``.
                """
                if getattr(self, 'layout_type', None) is not None:
                    # dtype is now implicit in the layout class
                    weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
                else:
                    weight = weight.to(self.weight.dtype)
                if return_weight:
                    return weight

                assert inplace_update is False  # TODO: eventually remove the inplace_update stuff
                self.weight = torch.nn.Parameter(weight, requires_grad=False)

            def _apply(self, fn, recurse=True):  # This is to get torch.compile + moving weights to another device working
                return _quantized_apply(self, fn, recurse)

        class MoEExperts(torch.nn.Module, CastWeightBiasOp):
            """Container for E quantized expert weights, indexed via expert_weight(i).

            The bank lives on self.weight as a single 3D tensor — either a
            compute_dtype Parameter or a Parameter wrapping a QuantizedTensor
            with leading expert dim.

            State-dict layout matches mixed_precision_ops.Linear with a leading
            expert dim:
              {prefix}.weight          quant data (storage_t), leading dim = E
              {prefix}.weight_scale    block / per-tensor scale
              {prefix}.weight_scale_2  [E] or scalar              NVFP4 only
              {prefix}.bias            [E, out_features]          optional, compute_dtype
              {prefix}.comfy_quant     json -> {"format": "...", "num_experts": E}

            Without comfy_quant the weight loads as a plain compute_dtype 3D Parameter [E, out, in].
            """

            _disabled_formats = disabled

            def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
                """Create an empty bank; ``dtype`` is ignored in favour of the namespace compute dtype."""
                super().__init__()
                self.num_experts = num_experts
                self.in_features = in_features
                self.out_features = out_features
                self._orig_shape = (num_experts, out_features, in_features)
                self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
                if bias:
                    self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs))
                else:
                    self.register_parameter("bias", None)

                # Populated by _load_from_state_dict:
                self.weight = None
                self.quant_format = None
                self.layout_type = None
                self._full_precision_mm = MixedPrecisionOps._full_precision_mm
                self._full_precision_mm_config = False
                # (weight, bias) pair while inside bank_resident(), else None.
                self._resident_bank = None

            def reset_parameters(self):
                """Do nothing: weights come from the state dict."""
                return None

            def _apply(self, fn, recurse=True):
                return _quantized_apply(self, fn, recurse)

            def _load_from_state_dict(self, *args):
                """Load via the shared quantized loader, without extra quant params."""
                _load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=False)

            def expert_weight(self, i: int):
                """Expert i's weight (Tensor or per-expert QuantizedTensor view)."""
                if isinstance(self.weight, QuantizedTensor):
                    return self._expert_qt_from(self.weight, i)
                return self.weight[i]

            @contextlib.contextmanager
            def bank_resident(self, input):
                """Cast the whole bank once; expert_linear inside reuses the cast.
                Not re-entrant — do not nest calls on the same instance.
                """
                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
                self._resident_bank = (weight, bias)
                try:
                    yield self
                finally:
                    self._resident_bank = None
                    uncast_bias_weight(self, weight, bias, offload_stream)

            def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
                """Linear against expert i's weight (with optional bias)."""
                resident = getattr(self, "_resident_bank", None)
                if resident is not None:
                    weight, bias = resident
                    return self._expert_linear_impl(input, weight, bias, i)
                # No resident bank: cast for this call only and always release.
                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
                try:
                    return self._expert_linear_impl(input, weight, bias, i)
                finally:
                    uncast_bias_weight(self, weight, bias, offload_stream)

            def _expert_linear_impl(self, input, weight, bias, i):
                """Apply expert ``i`` of an already-cast bank to ``input``.

                A quantized expert uses the quantized matmul only when
                full-precision matmul is off, the layout supports it and
                ``input`` is 2D; otherwise the weight is dequantized.
                """
                if isinstance(weight, QuantizedTensor):
                    qw = self._expert_qt_from(weight, i)
                else:
                    qw = weight[i]
                b = cast_to_input(bias[i], input, copy=False) if bias is not None else None

                if isinstance(qw, QuantizedTensor):
                    use_fast = (
                        not self._full_precision_mm
                        and qw.layout_cls.supports_fast_matmul()
                        and input.dim() == 2
                    )
                    if use_fast:
                        qin = QuantizedTensor.from_float(input, self.layout_type)
                        return torch.nn.functional.linear(qin, qw, b)
                    out = input @ qw.dequantize().t()
                    return out + b if b is not None else out
                return torch.nn.functional.linear(input, qw, b)

            def _expert_qt_from(self, weight: QuantizedTensor, i: int) -> QuantizedTensor:
                """Build a per-expert QuantizedTensor by indexing into a resident bank."""
                params = weight._params
                kwargs = {
                    # A 0-dim scale is shared by all experts; otherwise index it.
                    "scale": params.scale[i] if params.scale.dim() else params.scale,
                    "orig_dtype": params.orig_dtype,
                    "orig_shape": (self.out_features, self.in_features),
                }
                if hasattr(params, "block_scale"):  # NVFP4
                    kwargs["block_scale"] = params.block_scale[i]
                return QuantizedTensor(weight._qdata[i], weight._layout_cls, type(params)(**kwargs))

            def state_dict(self, *args, destination=None, prefix="", **kwargs):
                """Serialise the bank; ``num_experts`` is added to the quant config."""
                sd = destination if destination is not None else {}
                return _quantized_weight_state_dict(self, sd, prefix, extra_quant_conf={"num_experts": self.num_experts})

        class Embedding(manual_cast.Embedding):
            """Embedding that can keep an fp8 QuantizedTensor weight."""

            def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
                """Load an fp8-quantized weight directly; defer everything else to the parent."""
                weight_key = f"{prefix}weight"
                layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
                if layer_conf is not None:
                    layer_conf = json.loads(layer_conf.numpy().tobytes())

                # Only fp8 makes sense for embeddings (per-row dequant via index select).
                # Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
                quant_format = layer_conf.get("format") if layer_conf is not None else None
                manually_loaded_keys = []

                if quant_format in ("float8_e4m3fn", "float8_e5m2") and weight_key in state_dict:
                    self.quant_format = quant_format
                    qconfig = QUANT_ALGOS[quant_format]
                    self.layout_type = qconfig["comfy_tensor_layout"]
                    layout_cls = get_layout_class(self.layout_type)
                    weight = state_dict.pop(weight_key)
                    manually_loaded_keys.append(weight_key)

                    scale_key = f"{prefix}weight_scale"
                    scale = state_dict.pop(scale_key, None)
                    if scale is not None:
                        scale = scale.float()
                        manually_loaded_keys.append(scale_key)

                    params = layout_cls.Params(
                        # A missing scale is treated as 1.0.
                        scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
                        orig_dtype=MixedPrecisionOps._compute_dtype,
                        orig_shape=(self.num_embeddings, self.embedding_dim),
                    )
                    self.weight = torch.nn.Parameter(
                        QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
                        requires_grad=False)
                elif layer_conf is not None:
                    # Unsupported format — restore the marker so it round-trips; fall through to default load.
                    state_dict[f"{prefix}comfy_quant"] = torch.tensor(
                        list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)

                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
                for k in manually_loaded_keys:
                    if k in missing_keys:
                        missing_keys.remove(k)

            def state_dict(self, *args, destination=None, prefix="", **kwargs):
                """Serialise weight (and quant config when quantized) into the destination."""
                sd = destination if destination is not None else {}
                return _quantized_weight_state_dict(self, sd, prefix)

            def forward_comfy_cast_weights(self, input, out_dtype=None):
                """Embedding lookup; quantized weights are looked up first and scaled after.

                Returns the selected rows in ``out_dtype``, or in the weight's
                original dtype when ``out_dtype`` is ``None``.
                """
                weight = self.weight

                # Optimized path: lookup in fp8, dequantize only the selected rows.
                if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
                    qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
                    if isinstance(qdata, QuantizedTensor):
                        scale = qdata._params.scale
                        qdata = qdata._qdata
                    else:
                        scale = None

                    x = torch.nn.functional.embedding(
                        input, qdata, self.padding_idx, self.max_norm,
                        self.norm_type, self.scale_grad_by_freq, self.sparse)
                    uncast_bias_weight(self, qdata, None, offload_stream)
                    target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
                    x = x.to(dtype=target_dtype)
                    # Multiplying by a scale of exactly 1.0 is skipped.
                    if scale is not None and scale != 1.0:
                        x = x * scale.to(dtype=target_dtype)
                    return x

                # Fallback for non-quantized or weight_function (LoRA) case
                return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)

    return MixedPrecisionOps
|
|
|
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
    """Choose the ops namespace for a model.

    Checked in order:
      1. ``model_config.quant_config`` is set -> ``mixed_precision_ops(...)``,
         with formats the device cannot compute natively passed as disabled.
      2. FP8 compute is supported and requested (and not disabled) -> ``fp8_ops``.
      3. cuBLAS ops requested and available with fp16 weights and an fp16
         (or unspecified) compute dtype -> ``cublas_ops``.
      4. No compute dtype, or it equals the weight dtype -> ``disable_weight_init``.
      5. Otherwise -> ``manual_cast``.
    """
    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)  # TODO: if we support more ops this needs to be more granular
    nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
    mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)

    if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
        logging.info("Using mixed precision operations")
        # Formats without native device support are reported as emulated.
        emulated = set()
        format_support = (
            (("nvfp4",), nvfp4_compute),
            (("mxfp8",), mxfp8_compute),
            (("float8_e4m3fn", "float8_e5m2"), fp8_compute),
        )
        for format_names, supported in format_support:
            if not supported:
                emulated.update(format_names)

        native_part = ", ".join(QUANT_ALGOS.keys() - emulated)
        emulated_part = ", emulated ops: {}".format(", ".join(emulated)) if len(emulated) > 0 else ""
        logging.info("Native ops: {} {}".format(native_part, emulated_part))
        return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=emulated)

    want_fast_fp8 = (
        fp8_compute and
        (fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and
        not disable_fast_fp8
    )
    if want_fast_fp8:
        return fp8_ops

    want_cublas = (
        PerformanceFeature.CublasOps in args.fast and
        CUBLAS_IS_AVAILABLE and
        weight_dtype == torch.float16 and
        (compute_dtype == torch.float16 or compute_dtype is None)
    )
    if want_cublas:
        logging.info("Using cublas ops")
        return cublas_ops

    if compute_dtype is None or weight_dtype == compute_dtype:
        return disable_weight_init

    return manual_cast
|