Compare commits

...

1 Commits

Author SHA1 Message Date
e4f3d335dc feat: Add VideoSlice node with lazy operations on VideoInput
- Add VideoOp base class and SliceOp in _input/video_types.py
- Add sliced() method to VideoInput that returns a copy with operation appended
- Each subclass applies operations in get_components() and get_frame_count()
- After materialization, VideoFromFile delegates to internal VideoFromComponents
- Add VideoSlice node that uses video.sliced(start_frame, frame_count)
- Add tests for SliceOp, sliced() behavior, and materialization
2026-01-23 20:52:15 -08:00
6 changed files with 263 additions and 13 deletions

View File

@ -1,10 +1,12 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .video_types import VideoInput
from .video_types import VideoInput, VideoOp, SliceOp
__all__ = [
"ImageInput",
"AudioInput",
"VideoInput",
"VideoOp",
"SliceOp",
"MaskInput",
"LatentInput",
]

View File

@ -1,11 +1,48 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from fractions import Fraction
from typing import Optional, Union, IO
import copy
import io
import av
from .._util import VideoContainer, VideoCodec, VideoComponents
class VideoOp(ABC):
"""Base class for lazy video operations."""
@abstractmethod
def apply(self, components: VideoComponents) -> VideoComponents:
pass
@abstractmethod
def compute_frame_count(self, input_frame_count: int) -> int:
pass
@dataclass(frozen=True)
class SliceOp(VideoOp):
"""Extract a range of frames from the video."""
start_frame: int
frame_count: int
def apply(self, components: VideoComponents) -> VideoComponents:
total = components.images.shape[0]
start = max(0, min(self.start_frame, total))
end = min(start + self.frame_count, total)
return VideoComponents(
images=components.images[start:end],
audio=components.audio,
frame_rate=components.frame_rate,
metadata=getattr(components, 'metadata', None),
)
def compute_frame_count(self, input_frame_count: int) -> int:
start = max(0, min(self.start_frame, input_frame_count))
return min(self.frame_count, input_frame_count - start)
class VideoInput(ABC):
"""
Abstract base class for video input types.
@ -21,6 +58,12 @@ class VideoInput(ABC):
"""
pass
def sliced(self, start_frame: int, frame_count: int) -> "VideoInput":
"""Return a copy of this video with a slice operation appended."""
new = copy.copy(self)
new._operations = getattr(self, '_operations', []) + [SliceOp(start_frame, frame_count)]
return new
@abstractmethod
def save_to(
self,

View File

@ -1,7 +1,8 @@
from .video_types import VideoFromFile, VideoFromComponents
from .._input import SliceOp
__all__ = [
# Implementations
"VideoFromFile",
"VideoFromComponents",
"SliceOp",
]

View File

@ -3,7 +3,7 @@ from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from .._input import AudioInput, VideoInput
from .._input import AudioInput, VideoInput, VideoOp
import av
import io
import json
@ -63,6 +63,8 @@ class VideoFromFile(VideoInput):
containing the file contents.
"""
self.__file = file
self._operations: list[VideoOp] = []
self.__materialized: Optional[VideoFromComponents] = None
def get_stream_source(self) -> str | io.BytesIO:
"""
@ -161,6 +163,10 @@ class VideoFromFile(VideoInput):
if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
# Apply operations to get final frame count
for op in self._operations:
frame_count = op.compute_frame_count(frame_count)
return frame_count
def get_frame_rate(self) -> Fraction:
@ -239,10 +245,18 @@ class VideoFromFile(VideoInput):
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents:
if self.__materialized is not None:
return self.__materialized.get_components()
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
return self.get_components_internal(container)
components = self.get_components_internal(container)
for op in self._operations:
components = op.apply(components)
self.__materialized = VideoFromComponents(components)
self._operations = []
return components
raise ValueError(f"No video stream found in file '{self.__file}'")
def save_to(
@ -317,14 +331,27 @@ class VideoFromComponents(VideoInput):
def __init__(self, components: VideoComponents):
self.__components = components
self._operations: list[VideoOp] = []
def get_components(self) -> VideoComponents:
if self._operations:
components = self.__components
for op in self._operations:
components = op.apply(components)
self.__components = components
self._operations = []
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate
)
def get_frame_count(self) -> int:
count = int(self.__components.images.shape[0])
for op in self._operations:
count = op.compute_frame_count(count)
return count
def save_to(
self,
path: str,
@ -332,6 +359,9 @@ class VideoFromComponents(VideoInput):
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
# Materialize ops before saving
components = self.get_components()
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
@ -345,22 +375,22 @@ class VideoFromComponents(VideoInput):
for key, value in metadata.items():
output.metadata[key] = json.dumps(value)
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
frame_rate = Fraction(round(components.frame_rate * 1000), 1000)
# Create a video stream
video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = self.__components.images.shape[2]
video_stream.height = self.__components.images.shape[1]
video_stream.width = components.images.shape[2]
video_stream.height = components.images.shape[1]
video_stream.pix_fmt = 'yuv420p'
# Create an audio stream
audio_sample_rate = 1
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
if components.audio:
audio_sample_rate = int(components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
# Encode video
for i, frame in enumerate(self.__components.images):
for i, frame in enumerate(components.images):
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
@ -371,9 +401,9 @@ class VideoFromComponents(VideoInput):
packet = video_stream.encode(None)
output.mux(packet)
if audio_stream and self.__components.audio:
waveform = self.__components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
if audio_stream and components.audio:
waveform = components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * components.images.shape[0])]
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
frame.sample_rate = audio_sample_rate
frame.pts = 0

View File

@ -159,6 +159,29 @@ class GetVideoComponents(io.ComfyNode):
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
class VideoSlice(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VideoSlice",
display_name="Video Slice",
category="image/video",
description="Extract a range of frames from a video.",
inputs=[
io.Video.Input("video", tooltip="The video to slice."),
io.Int.Input("start_frame", default=0, min=0, tooltip="The frame index to start from (0-indexed)."),
io.Int.Input("frame_count", default=1, min=1, tooltip="Number of frames to extract."),
],
outputs=[
io.Video.Output(tooltip="The sliced video."),
],
)
@classmethod
def execute(cls, video: Input.Video, start_frame: int, frame_count: int) -> io.NodeOutput:
return io.NodeOutput(video.sliced(start_frame, frame_count))
class LoadVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -206,6 +229,7 @@ class VideoExtension(ComfyExtension):
SaveVideo,
CreateVideo,
GetVideoComponents,
VideoSlice,
LoadVideo,
]

View File

@ -0,0 +1,150 @@
import pytest
import torch
import tempfile
import os
import av
from fractions import Fraction
from comfy_api.input_impl.video_types import (
VideoFromFile,
VideoFromComponents,
SliceOp,
)
from comfy_api.util.video_types import VideoComponents
def create_test_video(width=4, height=4, frames=10, fps=30):
"""Helper to create a temporary video file."""
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
with av.open(tmp.name, mode="w") as container:
stream = container.add_stream("h264", rate=fps)
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
for i in range(frames):
frame_data = torch.ones(height, width, 3, dtype=torch.uint8) * (i * 25)
frame = av.VideoFrame.from_ndarray(frame_data.numpy(), format="rgb24")
frame = frame.reformat(format="yuv420p")
packet = stream.encode(frame)
container.mux(packet)
packet = stream.encode(None)
container.mux(packet)
return tmp.name
@pytest.fixture
def video_file_10_frames():
file_path = create_test_video(frames=10)
yield file_path
os.unlink(file_path)
@pytest.fixture
def video_components_10_frames():
images = torch.rand(10, 4, 4, 3)
return VideoComponents(images=images, frame_rate=Fraction(30))
class TestSliceOp:
def test_apply_slices_correctly(self, video_components_10_frames):
op = SliceOp(start_frame=2, frame_count=3)
result = op.apply(video_components_10_frames)
assert result.images.shape[0] == 3
assert torch.equal(result.images, video_components_10_frames.images[2:5])
def test_compute_frame_count(self):
op = SliceOp(start_frame=2, frame_count=5)
assert op.compute_frame_count(10) == 5
def test_compute_frame_count_clamps(self):
op = SliceOp(start_frame=8, frame_count=5)
assert op.compute_frame_count(10) == 2
class TestVideoSliced:
def test_sliced_returns_new_instance(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
assert video is not sliced
assert len(video._operations) == 0
assert len(sliced._operations) == 1
def test_get_components_applies_operations(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
components = sliced.get_components()
assert components.images.shape[0] == 3
assert torch.equal(components.images, video_components_10_frames.images[2:5])
def test_get_frame_count(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
assert sliced.get_frame_count() == 3
def test_get_duration(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(0, 3)
assert sliced.get_duration() == pytest.approx(0.1)
def test_chained_slices_compose(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 6).sliced(1, 3)
components = sliced.get_components()
assert components.images.shape[0] == 3
assert torch.equal(components.images, video_components_10_frames.images[3:6])
def test_operations_list_is_immutable(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced1 = video.sliced(0, 5)
sliced2 = sliced1.sliced(1, 2)
assert len(video._operations) == 0
assert len(sliced1._operations) == 1
assert len(sliced2._operations) == 2
def test_from_file(self, video_file_10_frames):
video = VideoFromFile(video_file_10_frames)
sliced = video.sliced(2, 3)
components = sliced.get_components()
assert components.images.shape[0] == 3
assert sliced.get_frame_count() == 3
def test_save_sliced_video(self, video_components_10_frames, tmp_path):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
output_path = str(tmp_path / "sliced_output.mp4")
sliced.save_to(output_path)
saved_video = VideoFromFile(output_path)
assert saved_video.get_frame_count() == 3
def test_materialization_clears_ops(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
assert len(sliced._operations) == 1
sliced.get_components()
assert len(sliced._operations) == 0
def test_second_get_components_uses_cache(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
first = sliced.get_components()
second = sliced.get_components()
assert first.images.shape == second.images.shape
assert torch.equal(first.images, second.images)