Skip to content

Commit

Permalink
Add cv2.vidplay()
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikWin committed Oct 24, 2024
1 parent d9b7d7f commit 15bea4e
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 9 deletions.
61 changes: 61 additions & 0 deletions snake-pit/test_cv2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re

import cv2 as ocv_cv2
import vidformer.cv2 as vf_cv2
Expand Down Expand Up @@ -30,6 +31,28 @@ def test_constants():
assert ocv_cv2.LINE_AA == vf_cv2.LINE_AA


def test_cap_all_frames():
"""Make sure VideoCapture can read all frames of a video correctly."""
import vidformer.cv2 as cv2

cap = cv2.VideoCapture(VID_PATH)
assert cap.isOpened()

fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

while True:
ret, frame = cap.read()
if not ret:
break
assert frame.shape[0] == height
assert frame.shape[1] == width
assert frame.shape[2] == 3

cap.release()


def rw(cv2):
cap = cv2.VideoCapture(VID_PATH)
assert cap.isOpened()
Expand Down Expand Up @@ -103,6 +126,44 @@ def test_numpy():
assert frame_np.shape[2] == 3


def test_vidplay():
import vidformer as vf
import vidformer.cv2 as cv2

cap = cv2.VideoCapture(VID_PATH)
assert cap.isOpened()

fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

assert width == 1280
assert height == 720

out = cv2.VideoWriter(None, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))

count = 0
while True:
ret, frame = cap.read()
if not ret or count > 100:
break

out.write(frame)
count += 1

cap.release()
out.release()

# test vidplay on a Spec
spec = out.spec()
link = cv2.vidplay(spec, method="link")
assert re.match(r"http://localhost:\d+/\w+-\w+-\w+-\w+-\w+/stream.m3u8", link)

# test vidplay on a VideoWriter
link = cv2.vidplay(out, method="link")
assert re.match(r"http://localhost:\d+/\w+-\w+-\w+-\w+-\w+/stream.m3u8", link)


def rectangle(cv2):
cap = cv2.VideoCapture(VID_PATH)
assert cap.isOpened()
Expand Down
27 changes: 25 additions & 2 deletions vidformer-py/vidformer/cv2/vf_cv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def release(self):
class VideoWriter:
def __init__(self, path, fourcc, fps, size):
assert isinstance(fourcc, VideoWriter_fourcc)
if path is not None and not isinstance(path, str):
raise Exception("path must be a string or None")
self._path = path
self._fourcc = fourcc
self._fps = fps
Expand All @@ -164,11 +166,14 @@ def write(self, frame):
self._frames.append(frame._f)

def release(self):
spec = self.vf_spec()
if self._path is None:
return

spec = self.spec()
server = _server()
spec.save(server, self._path)

def vf_spec(self):
def spec(self) -> vf.Spec:
fmt = {
"width": self._size[0],
"height": self._size[1],
Expand Down Expand Up @@ -233,6 +238,24 @@ def imwrite(path, img, *args):
raise Exception("Unsupported image format")


def vidplay(video, *args, **kwargs):
"""
Play a vidformer video specification.
Args:
video: one of [vidformer.Spec, vidformer.Source, vidformer.cv2.VideoWriter]
"""

if isinstance(video, vf.Spec):
return video.play(_server(), *args, **kwargs)
elif isinstance(video, vf.Source):
return video.play(_server(), *args, **kwargs)
elif isinstance(video, VideoWriter):
return video.spec().play(_server(), *args, **kwargs)
else:
raise Exception("Unsupported video type to vidplay")


def rectangle(img, pt1, pt2, color, thickness=None, lineType=None, shift=None):
"""
cv.rectangle( img, pt1, pt2, color[, thickness[, lineType[, shift]]] )
Expand Down
29 changes: 22 additions & 7 deletions vidformer-py/vidformer/vf.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,28 @@ def __init__(self, domain: list[Fraction], render, fmt: dict):
self._fmt = fmt

def __repr__(self):
lines = []
for i, t in enumerate(self._domain):
frame_expr = self._render(t, i)
lines.append(
f"{t.numerator}/{t.denominator} => {frame_expr}",
)
return "\n".join(lines)
if len(self._domain) <= 20:
lines = []
for i, t in enumerate(self._domain):
frame_expr = self._render(t, i)
lines.append(
f"{t.numerator}/{t.denominator} => {frame_expr}",
)
return "\n".join(lines)
else:
lines = []
for i, t in enumerate(self._domain[:10]):
frame_expr = self._render(t, i)
lines.append(
f"{t.numerator}/{t.denominator} => {frame_expr}",
)
lines.append("...")
for i, t in enumerate(self._domain[-10:]):
frame_expr = self._render(t, i)
lines.append(
f"{t.numerator}/{t.denominator} => {frame_expr}",
)
return "\n".join(lines)

def _sources(self):
s = set()
Expand Down

0 comments on commit 15bea4e

Please sign in to comment.