add get_frame_count and get_frame_rate methods to VideoInput class (#10851)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from fractions import Fraction
|
||||
from typing import Optional, Union, IO
|
||||
import io
|
||||
import av
|
||||
@@ -72,6 +73,33 @@ class VideoInput(ABC):
|
||||
frame_count = components.images.shape[0]
|
||||
return float(frame_count / components.frame_rate)
|
||||
|
||||
def get_frame_count(self) -> int:
|
||||
"""
|
||||
Returns the number of frames in the video.
|
||||
|
||||
Default implementation uses :meth:`get_components`, which may require
|
||||
loading all frames into memory. File-based implementations should
|
||||
override this method and use container/stream metadata instead.
|
||||
|
||||
Returns:
|
||||
Total number of frames as an integer.
|
||||
"""
|
||||
return int(self.get_components().images.shape[0])
|
||||
|
||||
def get_frame_rate(self) -> Fraction:
|
||||
"""
|
||||
Returns the frame rate of the video.
|
||||
|
||||
Default implementation materializes the video into memory via
|
||||
`get_components()`. Subclasses that can inspect the underlying
|
||||
container (e.g. `VideoFromFile`) should override this with a more
|
||||
efficient implementation.
|
||||
|
||||
Returns:
|
||||
Frame rate as a Fraction.
|
||||
"""
|
||||
return self.get_components().frame_rate
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
|
||||
@@ -121,6 +121,71 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
|
||||
def get_frame_count(self) -> int:
|
||||
"""
|
||||
Returns the number of frames in the video without materializing them as
|
||||
torch tensors.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# 1. Prefer the frames field if available
|
||||
if video_stream.frames and video_stream.frames > 0:
|
||||
return int(video_stream.frames)
|
||||
|
||||
# 2. Try to estimate from duration and average_rate using only metadata
|
||||
if container.duration is not None and video_stream.average_rate:
|
||||
duration_seconds = float(container.duration / av.time_base)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
if (
|
||||
getattr(video_stream, "duration", None) is not None
|
||||
and getattr(video_stream, "time_base", None) is not None
|
||||
and video_stream.average_rate
|
||||
):
|
||||
duration_seconds = float(video_stream.duration * video_stream.time_base)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
# 3. Last resort: decode frames and count them (streaming)
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
|
||||
if frame_count == 0:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
|
||||
return frame_count
|
||||
|
||||
def get_frame_rate(self) -> Fraction:
|
||||
"""
|
||||
Returns the average frame rate of the video using container metadata
|
||||
without decoding all frames.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
|
||||
if video_stream.average_rate:
|
||||
return Fraction(video_stream.average_rate)
|
||||
|
||||
# Fallback: estimate from frames + duration if available
|
||||
if video_stream.frames and container.duration:
|
||||
duration_seconds = float(container.duration / av.time_base)
|
||||
if duration_seconds > 0:
|
||||
return Fraction(video_stream.frames / duration_seconds).limit_denominator()
|
||||
|
||||
# Last resort: match get_components_internal default
|
||||
return Fraction(1)
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
@@ -238,6 +303,13 @@ class VideoFromFile(VideoInput):
|
||||
packet.stream = stream_map[packet.stream]
|
||||
output_container.mux(packet)
|
||||
|
||||
def _get_first_video_stream(self, container: InputContainer):
|
||||
video_stream = next((s for s in container.streams if s.type == "video"), None)
|
||||
if video_stream is None:
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
return video_stream
|
||||
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
"""
|
||||
Class representing video input from tensors.
|
||||
|
||||
Reference in New Issue
Block a user