mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 11:53:04 +08:00
Compare commits
3 Commits
jk/optiona
...
prs/dynami
| Author | SHA1 | Date | |
|---|---|---|---|
| 8423394577 | |||
| f7aebddcf6 | |||
| 89dc4a8df8 |
@ -1213,8 +1213,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
|
||||
if signature is not None:
|
||||
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, weight._v_tensor)[0]
|
||||
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
||||
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
||||
v_tensor = weight._v_tensor
|
||||
else:
|
||||
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
|
||||
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
|
||||
weight._v_tensor = v_tensor
|
||||
weight._v_signature = signature
|
||||
#Send it over
|
||||
v_tensor.copy_(weight, non_blocking=non_blocking)
|
||||
|
||||
@ -1542,7 +1542,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
if vbar is not None and not hasattr(m, "_v"):
|
||||
m._v = vbar.alloc(v_weight_size)
|
||||
m._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(m._v, device_to)
|
||||
allocated_size += v_weight_size
|
||||
|
||||
else:
|
||||
@ -1557,7 +1556,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
weight_size = geometry.numel() * geometry.element_size()
|
||||
if vbar is not None and not hasattr(weight, "_v"):
|
||||
weight._v = vbar.alloc(weight_size)
|
||||
weight._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device_to)
|
||||
weight._model_dtype = model_dtype
|
||||
allocated_size += weight_size
|
||||
vbar.set_watermark_limit(allocated_size)
|
||||
|
||||
21
comfy/ops.py
21
comfy/ops.py
@ -83,14 +83,18 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
|
||||
offload_stream = None
|
||||
xfer_dest = None
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
if signature is not None:
|
||||
xfer_dest = s._v_tensor
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
if signature is not None:
|
||||
if resident:
|
||||
weight = s._v_weight
|
||||
bias = s._v_bias
|
||||
else:
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
||||
|
||||
if not resident:
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||
cast_dest = None
|
||||
|
||||
xfer_source = [ s.weight, s.bias ]
|
||||
@ -140,9 +144,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
post_cast.copy_(pre_cast)
|
||||
xfer_dest = cast_dest
|
||||
|
||||
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
||||
weight = params[0]
|
||||
bias = params[1]
|
||||
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
||||
weight = params[0]
|
||||
bias = params[1]
|
||||
if signature is not None:
|
||||
s._v_weight = weight
|
||||
s._v_bias = bias
|
||||
s._v_signature=signature
|
||||
|
||||
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
@ -182,7 +190,6 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
|
||||
if s.bias is not None:
|
||||
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
|
||||
s._v_signature=signature
|
||||
|
||||
#FIXME: weird offload return protocol
|
||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
||||
|
||||
@ -30,46 +30,6 @@ from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL, SVG as _SVG, File3D
|
||||
|
||||
|
||||
class EmptyInputSentinel:
|
||||
"""
|
||||
Sentinel class indicating an empty/missing input.
|
||||
|
||||
Use the class itself (not an instance) as the sentinel.
|
||||
Compare using 'is' or 'is not' only.
|
||||
"""
|
||||
|
||||
def __new__(cls):
|
||||
raise TypeError("EmptyInputSentinel cannot be instantiated, use the class itself")
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
raise TypeError("EmptyInputSentinel cannot be subclassed")
|
||||
|
||||
@classmethod
|
||||
def __class_getitem__(cls, item):
|
||||
raise TypeError("EmptyInputSentinel cannot be subscripted")
|
||||
|
||||
def __repr__(self):
|
||||
return "<EmptyInput>"
|
||||
|
||||
def __bool__(self):
|
||||
raise TypeError("EmptyInputSentinel cannot be used in boolean context")
|
||||
|
||||
def __eq__(self, other):
|
||||
raise TypeError("EmptyInputSentinel cannot be compared with ==, use 'is' instead")
|
||||
|
||||
def __ne__(self, other):
|
||||
raise TypeError("EmptyInputSentinel cannot be compared with !=, use 'is not' instead")
|
||||
|
||||
def __hash__(self):
|
||||
raise TypeError("EmptyInputSentinel cannot be hashed")
|
||||
|
||||
def __iter__(self):
|
||||
raise TypeError("EmptyInputSentinel cannot be iterated")
|
||||
|
||||
def __len__(self):
|
||||
raise TypeError("EmptyInputSentinel has no length")
|
||||
|
||||
|
||||
class FolderType(str, Enum):
|
||||
input = "input"
|
||||
output = "output"
|
||||
@ -2150,7 +2110,6 @@ __all__ = [
|
||||
"DynamicCombo",
|
||||
"Autogrow",
|
||||
# Other classes
|
||||
"EmptyInputSentinel",
|
||||
"HiddenHolder",
|
||||
"Hidden",
|
||||
"NodeInfoV1",
|
||||
|
||||
@ -91,41 +91,6 @@ class SoftSwitchNode(io.ComfyNode):
|
||||
return io.NodeOutput(on_true if switch else on_false)
|
||||
|
||||
|
||||
class OptionalSwitchNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
template = io.MatchType.Template("switch")
|
||||
return io.Schema(
|
||||
node_id="ComfyOptionalSwitchNode",
|
||||
display_name="Optional Switch",
|
||||
category="logic",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Boolean.Input("switch"),
|
||||
io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
|
||||
io.MatchType.Input("on_true", template=template, lazy=True, optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.MatchType.Output(template=template, display_name="output"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def check_lazy_status(cls, switch, on_false=MISSING, on_true=MISSING):
|
||||
# Only evaluate the input that corresponds to the switch value
|
||||
if switch and on_true is None:
|
||||
return ["on_true"]
|
||||
if not switch and on_false is None:
|
||||
return ["on_false"]
|
||||
|
||||
@classmethod
|
||||
def execute(cls, switch, on_true=MISSING, on_false=MISSING) -> io.NodeOutput:
|
||||
selected = on_true if switch else on_false
|
||||
if selected is MISSING:
|
||||
return io.NodeOutput(io.EmptyInputSentinel)
|
||||
return io.NodeOutput(selected)
|
||||
|
||||
|
||||
class CustomComboNode(io.ComfyNode):
|
||||
"""
|
||||
Frontend node that allows user to write their own options for a combo.
|
||||
@ -295,7 +260,6 @@ class LogicExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
SwitchNode,
|
||||
OptionalSwitchNode,
|
||||
CustomComboNode,
|
||||
# SoftSwitchNode,
|
||||
# ConvertStringToComboNode,
|
||||
|
||||
@ -980,10 +980,6 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
input_filtered[x] = input_data_all[x]
|
||||
if 'input_types' in validate_function_inputs:
|
||||
input_filtered['input_types'] = [received_types]
|
||||
for x in list(input_filtered.keys()):
|
||||
if input_filtered[x] is io.EmptyInputSentinel:
|
||||
del input_filtered[x]
|
||||
|
||||
|
||||
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data)
|
||||
ret = await resolve_map_node_over_list_results(ret)
|
||||
|
||||
Reference in New Issue
Block a user