Disable dynamic_vram when using torch compiler (#12612)

* mp: attach re-construction arguments to model patcher

When making a model-patcher from a unet or ckpt, attach a callable
function that can be called to replay the model construction. This
can be used to deep clone model patcher WRT the actual model.

Originally written by Kosinkadink
f4b99bc623

* mp: Add disable_dynamic clone argument

Add a clone argument that lets a caller clone a ModelPatcher but disable
dynamic to demote the clone to regular MP. This is useful for legacy
features where dynamic_vram support is missing or TBD.

* torch_compile: disable dynamic_vram

This is a bigger feature. Disable for the interim to preserve
functionality.
This commit is contained in:
rattus
2026-02-24 16:13:46 -08:00
committed by GitHub
parent befa83d434
commit 3ebe1ac22e
3 changed files with 34 additions and 11 deletions

View File

@ -271,6 +271,7 @@ class ModelPatcher:
self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
self.cached_patcher_init: tuple[Callable, tuple] | None = None
if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0
@ -307,8 +308,15 @@ class ModelPatcher:
def get_free_memory(self, device):
return comfy.model_management.get_free_memory(device)
def clone(self):
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
def clone(self, disable_dynamic=False):
class_ = self.__class__
model = self.model
if self.is_dynamic() and disable_dynamic:
class_ = ModelPatcher
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
model = temp_model_patcher.model
n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
@ -362,6 +370,8 @@ class ModelPatcher:
n.is_clip = self.is_clip
n.hook_mode = self.hook_mode
n.cached_patcher_init = self.cached_patcher_init
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n)
return n