mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-28 15:56:18 +08:00
Compare commits
1 Commits
v0.11.0
...
cb/video-s
| Author | SHA1 | Date | |
|---|---|---|---|
| e4f3d335dc |
@ -1,10 +1,12 @@
|
||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||
from .video_types import VideoInput
|
||||
from .video_types import VideoInput, VideoOp, SliceOp
|
||||
|
||||
__all__ = [
|
||||
"ImageInput",
|
||||
"AudioInput",
|
||||
"VideoInput",
|
||||
"VideoOp",
|
||||
"SliceOp",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
]
|
||||
|
||||
@ -1,11 +1,48 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from fractions import Fraction
|
||||
from typing import Optional, Union, IO
|
||||
import copy
|
||||
import io
|
||||
import av
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
class VideoOp(ABC):
    """Base class for lazy video operations.

    An operation is recorded on a video first and applied to its
    ``VideoComponents`` later, when the video is materialized.
    """

    @abstractmethod
    def apply(self, components: VideoComponents) -> VideoComponents:
        """Return the components produced by applying this operation."""
        pass

    @abstractmethod
    def compute_frame_count(self, input_frame_count: int) -> int:
        """Return the number of frames this operation yields for an input of
        ``input_frame_count`` frames, without touching any frame data."""
        pass


@dataclass(frozen=True)
class SliceOp(VideoOp):
    """Extract a range of frames from the video.

    ``start_frame`` is the 0-based index of the first frame to keep and
    ``frame_count`` the number of frames to keep. Both are clamped to the
    available range, so an out-of-range or negative request produces a
    shorter (possibly empty) slice instead of an error.
    """
    start_frame: int
    frame_count: int

    def _bounds(self, total: int) -> tuple[int, int]:
        """Return the clamped half-open ``(start, end)`` frame range."""
        start = max(0, min(self.start_frame, total))
        # ``max(start, ...)`` keeps the range non-negative when frame_count < 0,
        # so apply() and compute_frame_count() always agree.
        end = max(start, min(start + self.frame_count, total))
        return start, end

    @staticmethod
    def _slice_audio(audio, frame_rate, start: int, end: int, total: int):
        """Trim ``audio`` to the time span covered by frames ``[start, end)``.

        Assumes ``audio`` is a mapping with a ``'waveform'`` whose last axis is
        samples and a ``'sample_rate'`` -- TODO confirm against the audio type.
        Anything that does not look like that is passed through unchanged.
        """
        if not audio or (start == 0 and end >= total):
            return audio
        waveform = audio.get('waveform')
        sample_rate = audio.get('sample_rate')
        if waveform is None or not sample_rate:
            return audio
        rate = Fraction(frame_rate)
        if rate <= 0:
            return audio
        sample_rate = int(sample_rate)
        first = int(start * sample_rate / rate)
        # When the slice reaches the last frame keep whatever audio remains;
        # otherwise stop at the (rounded-up) sample where frame ``end`` begins.
        last = None if end >= total else -((-end * sample_rate) // rate)
        trimmed = waveform[..., first:last]
        if trimmed.shape[-1] == 0:
            return None
        return {**audio, 'waveform': trimmed}

    def apply(self, components: VideoComponents) -> VideoComponents:
        """Return new components holding only the selected frames and the
        audio that plays during them; frame rate and metadata are kept."""
        total = components.images.shape[0]
        start, end = self._bounds(total)
        return VideoComponents(
            images=components.images[start:end],
            audio=self._slice_audio(components.audio, components.frame_rate, start, end, total),
            frame_rate=components.frame_rate,
            metadata=getattr(components, 'metadata', None),
        )

    def compute_frame_count(self, input_frame_count: int) -> int:
        """Return how many frames ``apply`` would keep out of
        ``input_frame_count``; never negative."""
        start, end = self._bounds(input_frame_count)
        return end - start
|
||||
|
||||
|
||||
class VideoInput(ABC):
|
||||
"""
|
||||
Abstract base class for video input types.
|
||||
@ -21,6 +58,12 @@ class VideoInput(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def sliced(self, start_frame: int, frame_count: int) -> "VideoInput":
    """Return a shallow copy of this video with one more slice queued.

    The receiver is left untouched: the copy gets its own operations list
    made of the receiver's pending operations (none if it has no
    ``_operations`` attribute) followed by a new ``SliceOp``.
    """
    pending = getattr(self, '_operations', []) + [SliceOp(start_frame, frame_count)]
    clone = copy.copy(self)
    clone._operations = pending
    return clone
|
||||
|
||||
@abstractmethod
|
||||
def save_to(
|
||||
self,
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from .video_types import VideoFromFile, VideoFromComponents
|
||||
from .._input import SliceOp
|
||||
|
||||
__all__ = [
|
||||
# Implementations
|
||||
"VideoFromFile",
|
||||
"VideoFromComponents",
|
||||
"SliceOp",
|
||||
]
|
||||
|
||||
@ -3,7 +3,7 @@ from av.container import InputContainer
|
||||
from av.subtitles.stream import SubtitleStream
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from .._input import AudioInput, VideoInput
|
||||
from .._input import AudioInput, VideoInput, VideoOp
|
||||
import av
|
||||
import io
|
||||
import json
|
||||
@ -63,6 +63,8 @@ class VideoFromFile(VideoInput):
|
||||
containing the file contents.
|
||||
"""
|
||||
self.__file = file
|
||||
self._operations: list[VideoOp] = []
|
||||
self.__materialized: Optional[VideoFromComponents] = None
|
||||
|
||||
def get_stream_source(self) -> str | io.BytesIO:
|
||||
"""
|
||||
@ -161,6 +163,10 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
if frame_count == 0:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
|
||||
|
||||
# Apply operations to get final frame count
|
||||
for op in self._operations:
|
||||
frame_count = op.compute_frame_count(frame_count)
|
||||
return frame_count
|
||||
|
||||
def get_frame_rate(self) -> Fraction:
|
||||
@ -239,10 +245,18 @@ class VideoFromFile(VideoInput):
|
||||
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||
|
||||
def get_components(self) -> VideoComponents:
    """Decode the video, apply any pending operations and cache the result.

    The first call opens the file with ``av.open`` and decodes it through
    ``get_components_internal``. The outcome, with every queued operation
    applied, is kept in ``__materialized`` and the queue is emptied.

    Operations queued after a result was cached (for example on a copy made
    by ``sliced()`` from an already materialized video) are applied on top
    of the cached result rather than being ignored.
    """
    if self.__materialized is not None:
        components = self.__materialized.get_components()
    else:
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            components = self.get_components_internal(container)

    if self._operations or self.__materialized is None:
        for op in self._operations:
            components = op.apply(components)
        # Rebind the cache instead of mutating it: a shallow copy of this
        # video may still reference the previous cache object.
        self.__materialized = VideoFromComponents(components)
        self._operations = []
    return components
|
||||
|
||||
def save_to(
|
||||
@ -317,14 +331,27 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
def __init__(self, components: VideoComponents):
    """Wrap already decoded ``components``; no operations are pending yet."""
    self._operations: list[VideoOp] = []
    self.__components = components
|
||||
|
||||
def get_components(self) -> VideoComponents:
    """Return the video's images, audio and frame rate.

    Pending operations are applied in order first; their result replaces
    the stored components and the pending list is emptied, so later calls
    reuse it. The returned object is a new ``VideoComponents`` built from
    the stored images, audio and frame rate.
    """
    pending = self._operations
    if pending:
        current = self.__components
        for operation in pending:
            current = operation.apply(current)
        self.__components = current
        self._operations = []

    stored = self.__components
    return VideoComponents(
        images=stored.images,
        audio=stored.audio,
        frame_rate=stored.frame_rate
    )
|
||||
|
||||
def get_frame_count(self) -> int:
    """Return the frame count after all pending operations.

    Starts from the first dimension of the stored images and lets each
    pending operation translate the count; no operation is applied to the
    frame data itself.
    """
    remaining = int(self.__components.images.shape[0])
    for operation in self._operations:
        remaining = operation.compute_frame_count(remaining)
    return remaining
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
@ -332,6 +359,9 @@ class VideoFromComponents(VideoInput):
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
# Materialize ops before saving
|
||||
components = self.get_components()
|
||||
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
@ -345,22 +375,22 @@ class VideoFromComponents(VideoInput):
|
||||
for key, value in metadata.items():
|
||||
output.metadata[key] = json.dumps(value)
|
||||
|
||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||
frame_rate = Fraction(round(components.frame_rate * 1000), 1000)
|
||||
# Create a video stream
|
||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||
video_stream.width = self.__components.images.shape[2]
|
||||
video_stream.height = self.__components.images.shape[1]
|
||||
video_stream.width = components.images.shape[2]
|
||||
video_stream.height = components.images.shape[1]
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
|
||||
# Create an audio stream
|
||||
audio_sample_rate = 1
|
||||
audio_stream: Optional[av.AudioStream] = None
|
||||
if self.__components.audio:
|
||||
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||
if components.audio:
|
||||
audio_sample_rate = int(components.audio['sample_rate'])
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
for i, frame in enumerate(components.images):
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||
@ -371,9 +401,9 @@ class VideoFromComponents(VideoInput):
|
||||
packet = video_stream.encode(None)
|
||||
output.mux(packet)
|
||||
|
||||
if audio_stream and self.__components.audio:
|
||||
waveform = self.__components.audio['waveform']
|
||||
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
|
||||
if audio_stream and components.audio:
|
||||
waveform = components.audio['waveform']
|
||||
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * components.images.shape[0])]
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
|
||||
frame.sample_rate = audio_sample_rate
|
||||
frame.pts = 0
|
||||
|
||||
@ -159,6 +159,29 @@ class GetVideoComponents(io.ComfyNode):
|
||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
|
||||
|
||||
|
||||
class VideoSlice(io.ComfyNode):
    """Node that extracts a range of frames from a video input."""

    @classmethod
    def define_schema(cls):
        """Describe the node: a video plus start/count integers in, a video out."""
        node_inputs = [
            io.Video.Input("video", tooltip="The video to slice."),
            io.Int.Input("start_frame", default=0, min=0, tooltip="The frame index to start from (0-indexed)."),
            io.Int.Input("frame_count", default=1, min=1, tooltip="Number of frames to extract."),
        ]
        node_outputs = [
            io.Video.Output(tooltip="The sliced video."),
        ]
        return io.Schema(
            node_id="VideoSlice",
            display_name="Video Slice",
            category="image/video",
            description="Extract a range of frames from a video.",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, video: Input.Video, start_frame: int, frame_count: int) -> io.NodeOutput:
        """Return the input video with the requested slice queued on it."""
        trimmed = video.sliced(start_frame, frame_count)
        return io.NodeOutput(trimmed)
|
||||
|
||||
|
||||
class LoadVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -206,6 +229,7 @@ class VideoExtension(ComfyExtension):
|
||||
SaveVideo,
|
||||
CreateVideo,
|
||||
GetVideoComponents,
|
||||
VideoSlice,
|
||||
LoadVideo,
|
||||
]
|
||||
|
||||
|
||||
150
tests-unit/comfy_api_test/video_slice_test.py
Normal file
150
tests-unit/comfy_api_test/video_slice_test.py
Normal file
@ -0,0 +1,150 @@
|
||||
import pytest
|
||||
import torch
|
||||
import tempfile
|
||||
import os
|
||||
import av
|
||||
from fractions import Fraction
|
||||
from comfy_api.input_impl.video_types import (
|
||||
VideoFromFile,
|
||||
VideoFromComponents,
|
||||
SliceOp,
|
||||
)
|
||||
from comfy_api.util.video_types import VideoComponents
|
||||
|
||||
|
||||
def create_test_video(width=4, height=4, frames=10, fps=30):
    """Helper to create a temporary h264 ``.mp4`` file and return its path.

    Frame ``i`` is a solid colour with every channel set to ``i * 25``.
    The caller owns the returned file and must delete it.
    """
    # Reserve a path and close the descriptor straight away: the encoder
    # reopens the file by name, and a lingering handle would leak (and blocks
    # the second open on platforms with exclusive file locking).
    fd, path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)

    try:
        with av.open(path, mode="w") as container:
            stream = container.add_stream("h264", rate=fps)
            stream.width = width
            stream.height = height
            stream.pix_fmt = "yuv420p"

            for i in range(frames):
                frame_data = torch.ones(height, width, 3, dtype=torch.uint8) * (i * 25)
                frame = av.VideoFrame.from_ndarray(frame_data.numpy(), format="rgb24")
                frame = frame.reformat(format="yuv420p")
                packet = stream.encode(frame)
                container.mux(packet)

            # Flush frames still buffered inside the encoder.
            packet = stream.encode(None)
            container.mux(packet)
    except Exception:
        # Do not leave a half-written file behind when encoding fails.
        os.unlink(path)
        raise

    return path
|
||||
|
||||
|
||||
@pytest.fixture
def video_file_10_frames():
    """Yield the path of a temporary 10-frame video, deleting it afterwards."""
    video_path = create_test_video(frames=10)
    yield video_path
    os.unlink(video_path)
|
||||
|
||||
|
||||
@pytest.fixture
def video_components_10_frames():
    """Return components with 10 random 4x4 RGB frames at 30 fps and no audio."""
    return VideoComponents(images=torch.rand(10, 4, 4, 3), frame_rate=Fraction(30))
|
||||
|
||||
|
||||
class TestSliceOp:
    """Unit tests for ``SliceOp`` on its own."""

    def test_apply_slices_correctly(self, video_components_10_frames):
        """Frames 2..4 of the input are returned, unchanged."""
        source = video_components_10_frames
        result = SliceOp(start_frame=2, frame_count=3).apply(source)

        assert result.images.shape[0] == 3
        assert torch.equal(result.images, source.images[2:5])

    def test_compute_frame_count(self):
        """A slice that fits entirely reports the requested count."""
        assert SliceOp(start_frame=2, frame_count=5).compute_frame_count(10) == 5

    def test_compute_frame_count_clamps(self):
        """A slice running past the end reports only the frames that exist."""
        assert SliceOp(start_frame=8, frame_count=5).compute_frame_count(10) == 2
|
||||
|
||||
|
||||
class TestVideoSliced:
    """Tests for ``sliced()`` on component-backed and file-backed videos."""

    def test_sliced_returns_new_instance(self, video_components_10_frames):
        """Slicing yields a distinct object and leaves the source's queue empty."""
        original = VideoFromComponents(video_components_10_frames)
        derived = original.sliced(2, 3)

        assert original is not derived
        assert len(original._operations) == 0
        assert len(derived._operations) == 1

    def test_get_components_applies_operations(self, video_components_10_frames):
        """Materializing a sliced video returns exactly the selected frames."""
        source = video_components_10_frames
        derived = VideoFromComponents(source).sliced(2, 3)

        result = derived.get_components()

        assert result.images.shape[0] == 3
        assert torch.equal(result.images, source.images[2:5])

    def test_get_frame_count(self, video_components_10_frames):
        """The frame count reflects the queued slice."""
        derived = VideoFromComponents(video_components_10_frames).sliced(2, 3)

        assert derived.get_frame_count() == 3

    def test_get_duration(self, video_components_10_frames):
        """Three frames at 30 fps last a tenth of a second."""
        derived = VideoFromComponents(video_components_10_frames).sliced(0, 3)

        assert derived.get_duration() == pytest.approx(0.1)

    def test_chained_slices_compose(self, video_components_10_frames):
        """A second slice is relative to the result of the first."""
        source = video_components_10_frames
        derived = VideoFromComponents(source).sliced(2, 6).sliced(1, 3)

        result = derived.get_components()

        assert result.images.shape[0] == 3
        assert torch.equal(result.images, source.images[3:6])

    def test_operations_list_is_immutable(self, video_components_10_frames):
        """Each slice gets its own queue; earlier videos are not modified."""
        original = VideoFromComponents(video_components_10_frames)
        first = original.sliced(0, 5)
        second = first.sliced(1, 2)

        assert len(original._operations) == 0
        assert len(first._operations) == 1
        assert len(second._operations) == 2

    def test_from_file(self, video_file_10_frames):
        """Slicing works for a video backed by a file on disk."""
        derived = VideoFromFile(video_file_10_frames).sliced(2, 3)

        result = derived.get_components()

        assert result.images.shape[0] == 3
        assert derived.get_frame_count() == 3

    def test_save_sliced_video(self, video_components_10_frames, tmp_path):
        """A saved sliced video contains only the selected frames."""
        derived = VideoFromComponents(video_components_10_frames).sliced(2, 3)

        target = str(tmp_path / "sliced_output.mp4")
        derived.save_to(target)

        reloaded = VideoFromFile(target)
        assert reloaded.get_frame_count() == 3

    def test_materialization_clears_ops(self, video_components_10_frames):
        """Getting the components empties the queue of pending operations."""
        derived = VideoFromComponents(video_components_10_frames).sliced(2, 3)

        assert len(derived._operations) == 1
        derived.get_components()
        assert len(derived._operations) == 0

    def test_second_get_components_uses_cache(self, video_components_10_frames):
        """Two consecutive materializations return the same frames."""
        derived = VideoFromComponents(video_components_10_frames).sliced(2, 3)

        first = derived.get_components()
        second = derived.get_components()

        assert first.images.shape == second.images.shape
        assert torch.equal(first.images, second.images)
|
||||
Reference in New Issue
Block a user