From 3f51595a2811dd2add586aa293156b6b4e342bd7 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Tue, 19 May 2026 20:47:28 +0200 Subject: [PATCH] Fix no clamping when normalization is raw and refactor reusable code to _apply_sky_clip and _depth_to_image --- comfy_extras/nodes_depth_anything_3.py | 68 ++++++++++++++------------ 1 file changed, 37 insertions(+), 31 deletions(-) diff --git a/comfy_extras/nodes_depth_anything_3.py b/comfy_extras/nodes_depth_anything_3.py index 999613276..9f41022a2 100644 --- a/comfy_extras/nodes_depth_anything_3.py +++ b/comfy_extras/nodes_depth_anything_3.py @@ -296,29 +296,48 @@ class DepthAnything3(io.ComfyNode): ref_view_strategy, pose_method, ) - @classmethod - def _execute_mono(cls, model, image, process_res, resize_method, - normalization, apply_sky_clip) -> io.NodeOutput: - depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method) + @staticmethod + def _apply_sky_clip(depth: torch.Tensor, sky: torch.Tensor) -> torch.Tensor: + return torch.stack([ + da3_preprocess.apply_sky_aware_clip(depth[i], sky[i]) + for i in range(depth.shape[0]) + ], dim=0) - if apply_sky_clip and sky is not None: - depth = torch.stack([ - da3_preprocess.apply_sky_aware_clip(depth[i], sky[i]) - for i in range(depth.shape[0]) - ], dim=0) + @staticmethod + def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None, + normalization: str) -> torch.Tensor: + """Normalise depth and pack as an (N,H,W,3) image tensor. + Preserves metric units when normalization is 'raw' (no clamping). + """ + N = depth.shape[0] if normalization == "v2_style": norm = torch.stack([ da3_preprocess.normalize_depth_v2_style( - depth[i], sky[i] if sky is not None else None) - for i in range(depth.shape[0]) + depth[i], sky_for_norm[i] if sky_for_norm is not None else None) + for i in range(N) ], dim=0) elif normalization == "min_max": norm = da3_preprocess.normalize_depth_min_max(depth) else: norm = depth - out_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous() + # Preserve metric units when normalization is raw. + out = norm.unsqueeze(-1).repeat(1, 1, 1, 3) + if normalization != "raw": + out = out.clamp(0.0, 1.0) + return out.contiguous() + + @classmethod + def _execute_mono(cls, model, image, process_res, resize_method, + normalization, apply_sky_clip) -> io.NodeOutput: + depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method) + + if apply_sky_clip and sky is not None: + depth = cls._apply_sky_clip(depth, sky) + + out_image = cls._depth_to_image(depth, sky, normalization) + sky_mask = sky if sky is not None else torch.zeros_like(depth) conf_mask = (_normalize_confidence(confidence) if confidence is not None else torch.zeros_like(depth)) @@ -367,18 +386,15 @@ class DepthAnything3(io.ComfyNode): conf_mask = _normalize_confidence(conf_raw) if conf_raw.any() else conf_raw - sky = torch.zeros_like(depth) + sky = None if "sky" in out: sky = torch.nn.functional.interpolate( out["sky"].unsqueeze(1).float(), size=(H, W), mode="bilinear", align_corners=False, ).squeeze(1).cpu() - if apply_sky_clip and "sky" in out: - depth = torch.stack([ - da3_preprocess.apply_sky_aware_clip(depth[i], sky[i]) - for i in range(S) - ], dim=0) + if apply_sky_clip and sky is not None: + depth = cls._apply_sky_clip(depth, sky) if "extrinsics" in out and "intrinsics" in out: extrinsics = out["extrinsics"].float().cpu() @@ -388,19 +404,9 @@ class DepthAnything3(io.ComfyNode): intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone() sky_for_norm = sky if diffusion.has_sky else None - if normalization == "v2_style": - norm = torch.stack([ - da3_preprocess.normalize_depth_v2_style( - depth[i], sky_for_norm[i] if sky_for_norm is not None else None) - for i in range(S) - ], dim=0) - elif normalization == "min_max": - norm = da3_preprocess.normalize_depth_min_max(depth) - else: - norm = depth - - depth_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous() + depth_image = cls._depth_to_image(depth, sky_for_norm, normalization) + sky_mask = sky if sky is not None else torch.zeros_like(depth) camera_latent = { "samples": depth.unsqueeze(0).unsqueeze(2).contiguous(), # (1, S, 1, H, W) "type": "da3_multiview", @@ -411,7 +417,7 @@ class DepthAnything3(io.ComfyNode): } return io.NodeOutput( depth_image, - sky.contiguous(), + sky_mask.contiguous(), conf_mask.contiguous(), camera_latent, )