mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-13 10:38:10 +08:00
Compare commits
1 Commits
cloud-open
...
feat/video
| Author | SHA1 | Date | |
|---|---|---|---|
| 5616097a0d |
@ -8,7 +8,6 @@ import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from comfy.ldm.lightricks.model import Timesteps
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
@ -18,7 +17,9 @@ def apply_rotary_emb(x, freqs_cis):
|
||||
if x.shape[1] == 0:
|
||||
return x
|
||||
|
||||
return apply_rope1(x, freqs_cis)
|
||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x.shape).to(dtype=x.dtype)
|
||||
|
||||
|
||||
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -5,7 +5,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, SPLAT, File3D
|
||||
from ._util import VideoCodec, VideoContainer, VideoBitDepth, VideoComponents, MESH, VOXEL, SPLAT, File3D
|
||||
from . import _io_public as io
|
||||
from . import _ui_public as ui
|
||||
from comfy_execution.utils import get_executing_context
|
||||
@ -140,6 +140,7 @@ class InputImpl:
|
||||
class Types:
|
||||
VideoCodec = VideoCodec
|
||||
VideoContainer = VideoContainer
|
||||
VideoBitDepth = VideoBitDepth
|
||||
VideoComponents = VideoComponents
|
||||
MESH = MESH
|
||||
VOXEL = VOXEL
|
||||
|
||||
@ -4,7 +4,7 @@ from fractions import Fraction
|
||||
from typing import Optional, Union, IO
|
||||
import io
|
||||
import av
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
from .._util import VideoContainer, VideoCodec, VideoBitDepth, VideoComponents
|
||||
|
||||
class VideoInput(ABC):
|
||||
"""
|
||||
@ -27,10 +27,14 @@ class VideoInput(ABC):
|
||||
path: Union[str, IO[bytes]],
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: VideoBitDepth = VideoBitDepth.AUTO,
|
||||
):
|
||||
"""
|
||||
Abstract method to save the video input to a file.
|
||||
|
||||
bit_depth applies when the video is (re-)encoded: AUTO preserves the
|
||||
source bit depth where one exists, otherwise 8-bit.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ import json
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
from .._util import VideoContainer, VideoCodec, VideoBitDepth, VideoComponents
|
||||
import logging
|
||||
|
||||
|
||||
@ -52,19 +52,41 @@ def get_open_write_kwargs(
|
||||
return open_kwargs
|
||||
|
||||
|
||||
def video_stream_bit_depth(stream) -> int:
    """Return the widest per-component bit count of a stream's pixel format.

    Falls back to 8 when ``stream`` is None, when the stream has no pixel
    format, or when that format lists no components.
    """
    pixel_format = None if stream is None else stream.format
    if pixel_format is None:
        return 8
    planes = pixel_format.components
    if not planes:
        return 8
    return max(plane.bits for plane in planes)
|
||||
|
||||
|
||||
class VideoFromFile(VideoInput):
|
||||
"""
|
||||
Class representing video input from a file.
|
||||
"""
|
||||
|
||||
def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
|
||||
def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0, bit_depth_cap: int | None = None):
|
||||
"""
|
||||
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||
containing the file contents.
|
||||
|
||||
bit_depth_cap limits the default bit depth of saved files (an explicit
|
||||
save_to(bit_depth=...) still wins); tensor access is unaffected.
|
||||
"""
|
||||
self.__file = file
|
||||
self.__start_time = start_time
|
||||
self.__duration = duration
|
||||
self.__bit_depth_cap = bit_depth_cap
|
||||
|
||||
def with_bit_depth_cap(self, bit_depth_cap: Optional[int]) -> "VideoFromFile":
    """Return a video over the same source whose saved files default to ``bit_depth_cap``.

    When the requested cap equals the one already set, this object itself is
    returned. Otherwise a new VideoFromFile is built from the same file,
    start time and duration, carrying the new cap. Passing None lifts the cap.
    """
    if bit_depth_cap != self.__bit_depth_cap:
        return VideoFromFile(
            self.__file,
            start_time=self.__start_time,
            duration=self.__duration,
            bit_depth_cap=bit_depth_cap,
        )
    # Nothing to change: reuse this instance instead of allocating a copy.
    return self
|
||||
|
||||
def get_stream_source(self) -> str | io.BytesIO:
|
||||
"""
|
||||
@ -377,25 +399,35 @@ class VideoFromFile(VideoInput):
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: VideoBitDepth = VideoBitDepth.AUTO,
|
||||
):
|
||||
bit_depth = VideoBitDepth(bit_depth)
|
||||
if bit_depth == VideoBitDepth.AUTO and self.__bit_depth_cap is not None and self.__bit_depth_cap < 10:
|
||||
bit_depth = VideoBitDepth.BIT_8
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
container_format = container.format.name
|
||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||
video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
|
||||
video_encoding = video_stream.codec.name if video_stream is not None else None
|
||||
source_bit_depth = video_stream_bit_depth(video_stream)
|
||||
reuse_streams = True
|
||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
if bit_depth != VideoBitDepth.AUTO and video_encoding is not None and bit_depth.bits() != source_bit_depth:
|
||||
reuse_streams = False
|
||||
if self.__start_time or self.__duration:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
if bit_depth == VideoBitDepth.AUTO:
|
||||
bit_depth = VideoBitDepth.BIT_10 if source_bit_depth >= 10 else VideoBitDepth.BIT_8
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path, format=format, codec=codec, metadata=metadata
|
||||
path, format=format, codec=codec, metadata=metadata, bit_depth=bit_depth
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
@ -440,6 +472,7 @@ class VideoFromFile(VideoInput):
|
||||
self.get_stream_source(),
|
||||
start_time=start_time + self.__start_time,
|
||||
duration=duration,
|
||||
bit_depth_cap=self.__bit_depth_cap,
|
||||
)
|
||||
if trimmed.get_duration() < duration and strict_duration:
|
||||
return None
|
||||
@ -467,12 +500,15 @@ class VideoFromComponents(VideoInput):
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: VideoBitDepth = VideoBitDepth.AUTO,
|
||||
):
|
||||
"""Save the video to a file path or BytesIO buffer."""
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
raise ValueError("Only H264 codec is supported for now")
|
||||
# AUTO is 8-bit: tensor components have no source bit depth to preserve.
|
||||
is_10bit = VideoBitDepth(bit_depth) == VideoBitDepth.BIT_10
|
||||
extra_kwargs = {}
|
||||
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
|
||||
extra_kwargs["format"] = format.value
|
||||
@ -488,10 +524,11 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||
# Create a video stream
|
||||
pix_fmt = 'yuv420p10le' if is_10bit else 'yuv420p'
|
||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||
video_stream.width = self.__components.images.shape[2]
|
||||
video_stream.height = self.__components.images.shape[1]
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
video_stream.pix_fmt = pix_fmt
|
||||
|
||||
# Create an audio stream
|
||||
audio_sample_rate = 1
|
||||
@ -505,9 +542,14 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||
if is_10bit:
|
||||
# 16-bit RGB keeps float precision through the conversion to 10-bit YUV.
|
||||
img = (frame.float() * 65535).clamp(0, 65535).cpu().numpy().astype(np.uint16) # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb48le')
|
||||
else:
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format=pix_fmt)
|
||||
packet = video_stream.encode(frame)
|
||||
output.mux(packet)
|
||||
|
||||
@ -534,3 +576,19 @@ class VideoFromComponents(VideoInput):
|
||||
return None
|
||||
#TODO Consider tracking duration and trimming at time of save?
|
||||
return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)
|
||||
|
||||
|
||||
def apply_video_input_accepts(values: list, input_info: dict | None) -> list:
|
||||
"""Apply a VIDEO input's `accepts` declaration to its bound values.
|
||||
|
||||
Inputs declaring `accepts={"depth": 10}` receive uncapped videos.
|
||||
For the rest, file-backed videos are replaced with copies that save as 8-bit by default,
|
||||
so existing nodes keep producing 8-bit files.
|
||||
VideoFromFile subclasses and other VideoInput implementations own their depth behavior and pass through unchanged.
|
||||
"""
|
||||
accepts = (input_info or {}).get("accepts") or {}
|
||||
cap = None if accepts.get("depth", 8) >= 10 else 8
|
||||
return [
|
||||
value.with_bit_depth_cap(cap) if type(value) is VideoFromFile else value
|
||||
for value in values
|
||||
]
|
||||
|
||||
@ -662,6 +662,26 @@ class Video(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
Type = VideoInput
|
||||
|
||||
class Input(Input):
    """Video input socket.

    `accepts` declares which video properties the node handles itself; only "depth" (8 or 10) is supported for now,
    e.g. `accepts={"depth": 10}`. Inputs without it receive videos whose saved files are capped to 8-bit.
    """

    def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
                 extra_dict=None, raw_link: bool=None, advanced: bool=None, accepts: dict=None):
        """Create the socket and validate `accepts`.

        Raises:
            ValueError: if `accepts` has a key other than "depth", or a
                "depth" value other than 8 or 10.
        """
        super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced)
        if accepts is not None:
            unsupported = {key for key in accepts if key != "depth"}
            if unsupported:
                raise ValueError(f"Unsupported keys in Video.Input accepts: {sorted(unsupported)}")
            if "depth" in accepts:
                if accepts["depth"] not in (8, 10):
                    raise ValueError("Video.Input accepts['depth'] must be 8 or 10")
        self.accepts = accepts

    def as_dict(self):
        """Serialize the socket, merging the pruned `accepts` entry over the base dict."""
        base = super().as_dict()
        # NOTE(review): prune_dict presumably drops the key when accepts is None - confirm.
        return base | prune_dict({"accepts": self.accepts})
|
||||
|
||||
@comfytype(io_type="SVG")
|
||||
class SVG(ComfyTypeIO):
|
||||
Type = _SVG
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
from .video_types import VideoContainer, VideoCodec, VideoBitDepth, VideoComponents
|
||||
from .geometry_types import VOXEL, MESH, SPLAT, File3D
|
||||
from .image_types import SVG
|
||||
|
||||
@ -6,6 +6,7 @@ __all__ = [
|
||||
# Utility Types
|
||||
"VideoContainer",
|
||||
"VideoCodec",
|
||||
"VideoBitDepth",
|
||||
"VideoComponents",
|
||||
"VOXEL",
|
||||
"MESH",
|
||||
|
||||
@ -15,6 +15,23 @@ class VideoCodec(str, Enum):
|
||||
"""
|
||||
return [member.value for member in cls]
|
||||
|
||||
|
||||
class VideoBitDepth(str, Enum):
    """Bit depth choices for encoded video: automatic, 8-bit or 10-bit."""

    AUTO = "auto"
    BIT_8 = "8-bit"
    BIT_10 = "10-bit"

    @classmethod
    def as_input(cls) -> list[str]:
        """List every member's string value, in declaration order, for use as node input options."""
        return [option.value for option in cls.__members__.values()]

    def bits(self) -> Optional[int]:
        """Return the number parsed from the digits before the "-" in the value; None for AUTO."""
        if self is VideoBitDepth.AUTO:
            return None
        digits, _, _ = self.value.partition("-")
        return int(digits)
|
||||
|
||||
class VideoContainer(str, Enum):
|
||||
AUTO = "auto"
|
||||
MP4 = "mp4"
|
||||
|
||||
@ -82,17 +82,26 @@ class SaveVideo(io.ComfyNode):
|
||||
essentials_category="Basics",
|
||||
description="Saves the input images to your ComfyUI output directory.",
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to save."),
|
||||
io.Video.Input("video", tooltip="The video to save.", accepts={"depth": 10}),
|
||||
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
|
||||
io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
|
||||
io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
|
||||
io.Combo.Input(
|
||||
"bit_depth",
|
||||
options=Types.VideoBitDepth.as_input(),
|
||||
default="auto",
|
||||
tooltip="Bit depth used when the video has to be re-encoded."
|
||||
" 'auto' keeps the bit depth of the source video (videos created from images are saved as 8-bit)."
|
||||
" 10-bit keeps smoother gradients with less banding, but some players may not support it.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput:
|
||||
def execute(cls, video: Input.Video, filename_prefix, format: str, codec, bit_depth: str = "auto") -> io.NodeOutput:
|
||||
width, height = video.get_dimensions()
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||
filename_prefix,
|
||||
@ -110,11 +119,14 @@ class SaveVideo(io.ComfyNode):
|
||||
if len(metadata) > 0:
|
||||
saved_metadata = metadata
|
||||
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
|
||||
bit_depth = Types.VideoBitDepth(bit_depth)
|
||||
save_kwargs = {} if bit_depth == Types.VideoBitDepth.AUTO else {"bit_depth": bit_depth}
|
||||
video.save_to(
|
||||
os.path.join(full_output_folder, file),
|
||||
format=Types.VideoContainer(format),
|
||||
codec=codec,
|
||||
metadata=saved_metadata
|
||||
metadata=saved_metadata,
|
||||
**save_kwargs,
|
||||
)
|
||||
|
||||
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
|
||||
@ -226,7 +238,7 @@ class VideoSlice(io.ComfyNode):
|
||||
category="video",
|
||||
essentials_category="Video Tools",
|
||||
inputs=[
|
||||
io.Video.Input("video"),
|
||||
io.Video.Input("video", accepts={"depth": 10}),
|
||||
io.Float.Input(
|
||||
"start_time",
|
||||
default=0.0,
|
||||
|
||||
@ -43,6 +43,7 @@ from comfy_execution.utils import CurrentNodeContext
|
||||
from comfy_execution.asset_enrichment import enrich_output_with_assets
|
||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||
from comfy_api.latest import io, _io
|
||||
from comfy_api.latest._input_impl.video_types import apply_video_input_accepts
|
||||
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
|
||||
|
||||
|
||||
@ -164,7 +165,7 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
||||
missing_keys = {}
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
||||
input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
||||
def mark_missing():
|
||||
missing_keys[x] = True
|
||||
input_data_all[x] = (None,)
|
||||
@ -182,6 +183,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
||||
mark_missing()
|
||||
continue
|
||||
obj = cached.outputs[output_index]
|
||||
if input_type == io.Video.io_type:
|
||||
obj = apply_video_input_accepts(obj, input_info)
|
||||
input_data_all[x] = obj
|
||||
elif input_category is not None or (is_v3 and class_def.ACCEPT_ALL_INPUTS):
|
||||
input_data_all[x] = [input_data]
|
||||
|
||||
14
openapi.yaml
14
openapi.yaml
@ -896,6 +896,11 @@ components:
|
||||
additionalProperties: true
|
||||
description: The workflow graph to execute
|
||||
type: object
|
||||
prompt_id:
|
||||
description: Optional client-supplied job id. Must be a UUID in canonical lowercase hyphenated form; it is echoed back in the response. Omitted or null means the server generates one.
|
||||
format: uuid
|
||||
nullable: true
|
||||
type: string
|
||||
workflow_id:
|
||||
description: UUID identifying the cloud workflow entity to associate with this job
|
||||
type: string
|
||||
@ -1795,9 +1800,7 @@ paths:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: |
|
||||
Invalid request — no fields provided, or `preview_id` is the zero UUID
|
||||
(`INVALID_PREVIEW_ID`).
|
||||
description: Invalid request (no fields provided)
|
||||
"401":
|
||||
content:
|
||||
application/json:
|
||||
@ -1809,10 +1812,7 @@ paths:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: |
|
||||
Asset not found — returned both when the asset being updated does
|
||||
not exist and when `preview_id` does not reference an asset
|
||||
accessible to the caller.
|
||||
description: Asset not found
|
||||
"500":
|
||||
content:
|
||||
application/json:
|
||||
|
||||
196
tests-unit/comfy_api_test/video_bit_depth_test.py
Normal file
196
tests-unit/comfy_api_test/video_bit_depth_test.py
Normal file
@ -0,0 +1,196 @@
|
||||
import pytest
|
||||
import torch
|
||||
import av
|
||||
import numpy as np
|
||||
from fractions import Fraction
|
||||
from comfy_api.input_impl.video_types import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest._input_impl.video_types import apply_video_input_accepts
|
||||
from comfy_api.util.video_types import VideoComponents
|
||||
from comfy_api.latest._util.video_types import VideoBitDepth
|
||||
|
||||
DECLARED = {"accepts": {"depth": 10}}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def gradient_components():
    """Three 64x64 frames holding a narrow horizontal ramp (0.25..0.30).

    The ramp is narrow enough that it needs more than 8 bits to stay smooth.
    """
    frame_count, rows, cols = 3, 64, 64
    column_values = torch.linspace(0.25, 0.30, cols)
    # Broadcast the per-column values over frames, rows and the 3 channels.
    images = column_values.view(1, 1, cols, 1).expand(frame_count, rows, cols, 3).contiguous()
    return VideoComponents(images=images, frame_rate=Fraction(30))
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def src8(gradient_components, tmp_path_factory):
    """Path to a source mp4 written from the gradient with default save settings (8-bit h264)."""
    target = str(tmp_path_factory.mktemp("video") / "src8.mp4")
    video = VideoFromComponents(gradient_components)
    video.save_to(target)
    return target
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def src10(gradient_components, tmp_path_factory):
    """Path to a source mp4 written from the gradient with an explicit 10-bit request."""
    target = str(tmp_path_factory.mktemp("video") / "src10.mp4")
    video = VideoFromComponents(gradient_components)
    video.save_to(target, bit_depth=VideoBitDepth.BIT_10)
    return target
|
||||
|
||||
|
||||
def probe(path):
    """Open ``path`` and return (codec name, pixel format name, max component bits) of its first video stream."""
    with av.open(path) as container:
        first_video = container.streams.video[0]
        pixel_format = first_video.format
        widest = max(plane.bits for plane in pixel_format.components)
        return (first_video.codec.name, pixel_format.name, widest)
|
||||
|
||||
|
||||
def decoded_levels(path):
    """Count the unique values in the first decoded frame (banding measure).

    The frame is converted to the "gbrpf32le" format and only index 0 of the
    last array axis is counted.
    """
    with av.open(path) as container:
        first_frame = next(container.decode(container.streams.video[0]))
        pixels = first_frame.to_ndarray(format="gbrpf32le")
        distinct = np.unique(pixels[..., 0])
        return len(distinct)
|
||||
|
||||
|
||||
def video_packet_bytes(path):
    """Return the payload bytes of every non-empty packet of the first video stream.

    Used by the tests to tell a true remux (payloads identical to the
    source's) from a re-encode.
    """
    payloads = []
    with av.open(path) as container:
        for packet in container.demux(container.streams.video[0]):
            # Packets with size 0 carry no payload and are skipped.
            if packet.size:
                payloads.append(bytes(packet))
    return payloads
|
||||
|
||||
|
||||
def test_components_save_bit_depths(src8, src10):
    """Default save is 8-bit h264; the 10-bit save stays h264 and more than doubles the decoded levels."""
    assert probe(src8) == ("h264", "yuv420p", 8)
    assert probe(src10) == ("h264", "yuv420p10le", 10)

    levels_10bit = decoded_levels(src10)
    levels_8bit = decoded_levels(src8)
    assert levels_10bit > 2 * levels_8bit
|
||||
|
||||
|
||||
def test_components_unsupported_codec_raises(gradient_components, tmp_path):
    """Saving component video with a codec other than h264 raises a ValueError mentioning H264."""
    video = VideoFromComponents(gradient_components)
    target = str(tmp_path / "x.mp4")
    with pytest.raises(ValueError, match="H264"):
        video.save_to(target, codec="vp9")
|
||||
|
||||
|
||||
def test_bit_depth_enum():
    """VideoBitDepth exposes its values in declaration order and maps them to numeric depths."""
    option_names = VideoBitDepth.as_input()
    assert option_names == ["auto", "8-bit", "10-bit"]

    numeric_depths = [member.bits() for member in VideoBitDepth]
    assert numeric_depths == [None, 8, 10]
|
||||
|
||||
|
||||
def test_10bit_source_remuxes_untouched(src10, tmp_path):
    """Saving a 10-bit source with auto depth, with or without a cap of 10, leaves its packets unchanged."""
    candidates = {
        "auto": VideoFromFile(src10),
        "cap10": VideoFromFile(src10).with_bit_depth_cap(10),
    }
    for label, video in candidates.items():
        output = str(tmp_path / f"{label}.mp4")
        video.save_to(output)
        assert probe(output) == ("h264", "yuv420p10le", 10)
        # Identical packet payloads mean the stream was copied, not re-encoded.
        assert video_packet_bytes(output) == video_packet_bytes(src10)
|
||||
|
||||
|
||||
def test_8bit_source_remuxes_on_8bit_request(src8, tmp_path):
    """An 8-bit source is copied, not re-encoded, for both an explicit 8-bit request and a cap of 8."""
    explicit_path = str(tmp_path / "explicit.mp4")
    VideoFromFile(src8).save_to(explicit_path, bit_depth="8-bit")
    assert video_packet_bytes(explicit_path) == video_packet_bytes(src8)

    capped_path = str(tmp_path / "capped.mp4")
    VideoFromFile(src8).with_bit_depth_cap(8).save_to(capped_path)
    assert video_packet_bytes(capped_path) == video_packet_bytes(src8)
|
||||
|
||||
|
||||
def test_trim_keeps_source_depth(src10, tmp_path):
    """Saving a trimmed 10-bit source (which forces a re-encode) still produces 10-bit output."""
    one_frame = VideoFromFile(src10).as_trimmed(start_time=0, duration=1 / 30, strict_duration=False)
    output = str(tmp_path / "trim.mp4")
    one_frame.save_to(output)
    assert probe(output) == ("h264", "yuv420p10le", 10)
|
||||
|
||||
|
||||
def test_explicit_depth_mismatch_forces_reencode(src8, src10, tmp_path):
    """Requesting a depth different from the source's yields output at the requested depth."""
    cases = [
        # (output file, source, requested depth, expected probe result)
        ("down8.mp4", src10, VideoBitDepth.BIT_8, ("h264", "yuv420p", 8)),
        ("up10.mp4", src8, VideoBitDepth.BIT_10, ("h264", "yuv420p10le", 10)),
    ]
    for filename, source, requested, expected in cases:
        output = str(tmp_path / filename)
        VideoFromFile(source).save_to(output, bit_depth=requested)
        assert probe(output) == expected
|
||||
|
||||
|
||||
def test_bit_depth_cap(src10, tmp_path):
    """Check what a cap of 8 on a 10-bit source changes and what it leaves alone.

    Default saves become 8-bit, also through as_trimmed; an explicit 10-bit
    request still wins; tensor access keeps full precision.
    """
    capped_video = VideoFromFile(src10).with_bit_depth_cap(8)

    # Default save honours the cap.
    default_out = str(tmp_path / "capped.mp4")
    capped_video.save_to(default_out)
    assert probe(default_out) == ("h264", "yuv420p", 8)

    # The cap survives trimming.
    trimmed_out = str(tmp_path / "trimmed.mp4")
    trimmed_video = capped_video.as_trimmed(0, 1 / 30, strict_duration=False)
    trimmed_video.save_to(trimmed_out)
    assert probe(trimmed_out) == ("h264", "yuv420p", 8)

    # An explicit request overrides the cap.
    explicit_out = str(tmp_path / "explicit10.mp4")
    capped_video.save_to(explicit_out, bit_depth=VideoBitDepth.BIT_10)
    assert probe(explicit_out) == ("h264", "yuv420p10le", 10)

    # Decoded tensors are not quantized by the cap.
    frames = capped_video.get_components().images
    assert frames.dtype == torch.float32
    distinct_levels = torch.unique(frames[0, :, :, 0])
    assert len(distinct_levels) > 30  # ~13 levels if quantized to 8-bit
|
||||
|
||||
|
||||
def test_accepts_binding_policy(gradient_components, src10, tmp_path):
    """Check how apply_video_input_accepts binds videos to inputs.

    Undeclared inputs get an 8-bit-capped copy of file videos, declared inputs
    get uncapped videos, and everything else passes through untouched.
    """
    from comfy_api.latest._input import VideoInput as VideoInputABC

    class SubVideo(VideoFromFile):
        pass

    class CustomVideo(VideoInputABC):
        def get_components(self):
            raise NotImplementedError

        def save_to(self, path, format=None, codec=None, metadata=None):
            raise NotImplementedError

        def as_trimmed(self, start_time=None, duration=None, strict_duration=False):
            return self

    source_video = VideoFromFile(src10)

    # undeclared input: capped copy that saves 8-bit
    [capped] = apply_video_input_accepts([source_video], {"tooltip": "x"})
    assert type(capped) is VideoFromFile and capped is not source_video
    bound_path = str(tmp_path / "bound.mp4")
    capped.save_to(bound_path)
    assert probe(bound_path) == ("h264", "yuv420p", 8)

    # declared input: original passes through; a cap from an earlier binding is lifted
    [declared] = apply_video_input_accepts([source_video], DECLARED)
    assert declared is source_video
    [lifted] = apply_video_input_accepts([capped], DECLARED)
    lifted_path = str(tmp_path / "lifted.mp4")
    lifted.save_to(lifted_path)
    assert probe(lifted_path) == ("h264", "yuv420p10le", 10)

    # declaring depth 8 is the same as not declaring
    [depth8] = apply_video_input_accepts([source_video], {"accepts": {"depth": 8}})
    assert depth8 is not source_video

    # subclasses, component videos, custom implementations, and non-videos pass through
    passthrough = [SubVideo(src10), VideoFromComponents(gradient_components), CustomVideo(), "not a video", None]
    assert apply_video_input_accepts(passthrough, None) == passthrough
|
||||
|
||||
|
||||
def test_accepts_declaration():
    """Video.Input validates and serializes accepts; SaveVideo and VideoSlice declare depth 10."""
    from comfy_api.latest import io
    import comfy_extras.nodes_video as nv
    from comfy_execution.graph import get_input_info

    # Serialization: present when declared, absent otherwise.
    declared = io.Video.Input("video", accepts={"depth": 10})
    assert declared.as_dict()["accepts"] == {"depth": 10}
    undeclared = io.Video.Input("video")
    assert "accepts" not in undeclared.as_dict()

    # Validation: unknown keys and unsupported depths are rejected.
    with pytest.raises(ValueError, match="Unsupported keys"):
        io.Video.Input("video", accepts={"codec": "h264"})
    with pytest.raises(ValueError, match="must be 8 or 10"):
        io.Video.Input("video", accepts={"depth": 12})

    # Both built-in nodes expose the declaration through their input info.
    for node in (nv.SaveVideo, nv.VideoSlice):
        _input_type, _category, info = get_input_info(node, "video", node.INPUT_TYPES())
        assert info.get("accepts") == {"depth": 10}, node
|
||||
Reference in New Issue
Block a user