diff --git a/comfy/ops.py b/comfy/ops.py index 8275dd0a5..3e19cd1b6 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -660,23 +660,29 @@ class fp8_ops(manual_cast): CUBLAS_IS_AVAILABLE = False try: - from cublas_ops import CublasLinear + from cublas_ops import CublasLinear, cublas_half_matmul CUBLAS_IS_AVAILABLE = True except ImportError: pass if CUBLAS_IS_AVAILABLE: - class cublas_ops(disable_weight_init): - class Linear(CublasLinear, disable_weight_init.Linear): + class cublas_ops(manual_cast): + class Linear(CublasLinear, manual_cast.Linear): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): - return super().forward(input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): - return super().forward(*args, **kwargs) - + run_every_op() + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) # ============================================================================== # Mixed Precision Operations