Skip to content

Commit

Permalink
Merge pull request #41 from hhoppe:main
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 588342735
  • Loading branch information
The mediapy Authors committed Dec 6, 2023
2 parents 0bbddd9 + ea4e934 commit d7948a5
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 35 deletions.
54 changes: 29 additions & 25 deletions mediapy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
from __future__ import annotations

__docformat__ = 'google'
__version__ = '1.1.9'
__version__ = '1.2.0'
__version_info__ = tuple(int(num) for num in __version__.split('.'))

import base64
Expand All @@ -116,7 +116,6 @@
import itertools
import math
import numbers
import os
import pathlib
import re
import shlex
Expand All @@ -129,12 +128,15 @@
import urllib.request

import IPython.display
import matplotlib
import matplotlib.pyplot
import numpy as np
import numpy.typing as npt
import PIL.Image
import PIL.ImageOps

if typing.TYPE_CHECKING:
import os # pylint: disable=g-bad-import-order

if not hasattr(PIL.Image, 'Resampling'): # Allow Pillow<9.0.
PIL.Image.Resampling = PIL.Image

Expand Down Expand Up @@ -187,7 +189,7 @@

_IPYTHON_HTML_SIZE_LIMIT = 20_000_000
_T = typing.TypeVar('_T')
_Path = typing.Union[str, os.PathLike]
_Path = typing.Union[str, 'os.PathLike[str]']

_IMAGE_COMPARISON_HTML = """\
<script
Expand Down Expand Up @@ -504,10 +506,10 @@ def generate_image(image_index: int) -> _NDArray:
"""Returns a video frame image."""
image = color_ramp(shape, dtype=dtype)
yx = np.moveaxis(np.indices(shape), 0, -1)
center = (shape[0] * 0.6, shape[1] * (image_index + 0.5) / num_images)
center = shape[0] * 0.6, shape[1] * (image_index + 0.5) / num_images
radius_squared = (min(shape) * 0.1) ** 2
inside = np.sum((yx - center) ** 2, axis=-1) < radius_squared
white_circle_color = (1.0, 1.0, 1.0)
white_circle_color = 1.0, 1.0, 1.0
if np.issubdtype(dtype, np.unsignedinteger):
white_circle_color = to_type([white_circle_color], dtype)[0]
image[inside] = white_circle_color
Expand Down Expand Up @@ -837,7 +839,7 @@ def to_rgb(
a = (a.astype('float') - vmin) / (vmax - vmin + np.finfo(float).eps)
if isinstance(cmap, str):
if hasattr(matplotlib, 'colormaps'):
rgb_from_scalar = matplotlib.colormaps[cmap] # Newer version.
rgb_from_scalar: Any = matplotlib.colormaps[cmap] # Newer version.
else:
rgb_from_scalar = matplotlib.pyplot.cm.get_cmap(cmap)
else:
Expand Down Expand Up @@ -1234,19 +1236,20 @@ def _get_video_metadata(path: _Path) -> VideoMetadata:
fps = 10
else:
raise RuntimeError(f'Unable to parse video framerate in line {line}')
if (
(match := re.fullmatch(r'\s*rotate\s*:\s*(\d+)', line)) or
(match := re.fullmatch(r'\s*.*rotation of -?(\d+)\s*.*\sdegrees', line))
):
if match := re.fullmatch(r'\s*rotate\s*:\s*(\d+)', line):
rotation = int(match.group(1))
if match := re.fullmatch(r'.*rotation of (-?\d+).*\sdegrees', line):
rotation = int(match.group(1))
if not num_images:
raise RuntimeError(f'Unable to find frames in video: {err}')
if not width:
raise RuntimeError(f'Unable to parse video header: {err}')
# By default, ffmpeg enables "-autorotate"; we just fix the dimensions.
if rotation in (90, 270):
if rotation in (90, 270, -90, -270):
width, height = height, width
shape = (height, width)
assert height is not None and width is not None
shape = height, width
assert fps is not None
return VideoMetadata(num_images, shape, fps, bps)


Expand Down Expand Up @@ -1388,7 +1391,8 @@ def read(self) -> _NDArray | None:
array with 3 color channels, except for format 'gray' which is 2D.
"""
assert self._proc, 'Error: reading from an already closed context.'
assert (stdout := self._proc.stdout) is not None
stdout = self._proc.stdout
assert stdout is not None
data = stdout.read(self._num_bytes_per_image)
if not data: # Due to either end-of-file or subprocess error.
self.close() # Raises exception if subprocess had error.
Expand Down Expand Up @@ -1427,7 +1431,7 @@ def close(self) -> None:
class VideoWriter(_VideoIO):
"""Context to write a compressed video.
>>> shape = (480, 640)
>>> shape = 480, 640
>>> with VideoWriter('/tmp/v.mp4', shape, fps=60) as writer:
... for image in moving_circle(shape, num_images=60):
... writer.add_image(image)
Expand Down Expand Up @@ -1644,7 +1648,8 @@ def add_image(self, image: _NDArray) -> None:
if self.input_format == 'yuv': # Convert from per-pixel YUV to planar YUV.
image = np.moveaxis(image, 2, 0)
data = image.tobytes()
assert (stdin := self._proc.stdin) is not None
stdin = self._proc.stdin
assert stdin is not None
if stdin.write(data) != len(data):
self._proc.wait()
stderr = self._proc.stderr
Expand All @@ -1655,8 +1660,9 @@ def add_image(self, image: _NDArray) -> None:
def close(self) -> None:
"""Finishes writing the video. (Called automatically at end of context.)"""
if self._popen:
assert self._proc and self._proc.stdin and self._proc.stderr
assert (stdin := self._proc.stdin) is not None
assert self._proc, 'Error: closing an already closed context.'
stdin = self._proc.stdin
assert stdin is not None
stdin.close()
if self._proc.wait():
stderr = self._proc.stderr
Expand Down Expand Up @@ -1910,9 +1916,9 @@ def show_videos(
'Cannot have both a video dictionary and a titles parameter.'
)
list_titles = list(videos.keys())
list_videos: list[Iterable[_NDArray]] = list(videos.values())
list_videos = list(videos.values())
else:
list_videos = list(videos)
list_videos = list(typing.cast('Iterable[_NDArray]', videos))
list_titles = [None] * len(list_videos) if titles is None else list(titles)
if len(list_videos) != len(list_titles):
raise ValueError(
Expand All @@ -1926,10 +1932,8 @@ def show_videos(
for video, title in zip(list_videos, list_titles):
metadata: VideoMetadata | None = getattr(video, 'metadata', None)
first_image, video = _peek_first(video)
w, h = _get_width_height(
width, height, first_image.shape[:2] # type: ignore[arg-type]
)
if downsample and (w < first_image.shape[1] or h < first_image.shape[0]): # pytype: disable=attribute-error
w, h = _get_width_height(width, height, first_image.shape[:2])
if downsample and (w < first_image.shape[1] or h < first_image.shape[0]):
# Not resize_video() because each image may have different depth and type.
video = [resize_image(image, (h, w)) for image in video]
first_image = video[0]
Expand All @@ -1942,7 +1946,7 @@ def show_videos(
with _open(path, mode='wb') as f:
f.write(data)
if codec == 'gif':
pixelated = h > first_image.shape[0] or w > first_image.shape[1] # pytype: disable=attribute-error
pixelated = h > first_image.shape[0] or w > first_image.shape[1]
html_string = html_from_compressed_image(
data, w, h, title=title, fmt='gif', pixelated=pixelated, **kwargs
)
Expand Down
16 changes: 9 additions & 7 deletions mediapy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_to_uint8(self):
)

def test_color_ramp_float(self):
shape = (2, 3)
shape = 2, 3
image = media.color_ramp(shape=shape)
self.assert_all_equal(image.shape[:2], shape)
self.assert_all_close(
Expand All @@ -242,7 +242,7 @@ def test_color_ramp_float(self):
)

def test_color_ramp_uint8(self):
shape = (1, 3)
shape = 1, 3
image = media.color_ramp(shape=shape, dtype=np.uint8)
self.assert_all_equal(image.shape[:2], shape)
expected = [[
Expand Down Expand Up @@ -565,7 +565,7 @@ def test_compare_images(self):

@parameterized.parameters(False, True)
def test_video_non_streaming_write_read_roundtrip(self, use_generator):
shape = (240, 320)
shape = 240, 320
num_images = 10
fps = 40
qp = 20
Expand All @@ -577,14 +577,15 @@ def test_video_non_streaming_write_read_roundtrip(self, use_generator):
tmp_path = pathlib.Path(directory_name) / 'test.mp4'
media.write_video(tmp_path, video, fps=fps, qp=qp)
new_video = media.read_video(tmp_path)
assert new_video.metadata
self.assertEqual(new_video.metadata.num_images, num_images)
self.assertEqual(new_video.metadata.shape, shape)
self.assertEqual(new_video.metadata.fps, fps)
self.assertGreater(new_video.metadata.bps, 1_000)
self._check_similar(original_video, new_video, 3.0)

def test_video_streaming_write_read_roundtrip(self):
shape = (62, 744)
shape = 62, 744
num_images = 20
fps = 120
bps = 400_000
Expand All @@ -610,7 +611,7 @@ def test_video_streaming_write_read_roundtrip(self):
self._check_similar(images[index], new_image, 7.0, f'index={index}')

def test_video_streaming_read_write(self):
shape = (400, 400)
shape = 400, 400
num_images = 4
fps = 25
bps = 40_000_000
Expand All @@ -632,14 +633,15 @@ def test_video_streaming_read_write(self):
writer.add_image(image)

new_video = media.read_video(path2)
assert new_video.metadata
self.assertEqual(new_video.metadata.num_images, num_images)
self.assertEqual(new_video.metadata.shape, shape)
self.assertEqual(new_video.metadata.fps, fps)
self.assertGreater(new_video.metadata.bps, 1_000)
self._check_similar(video, new_video, 3.0)

def test_video_read_write_10bit(self):
shape = (256, 256)
shape = 256, 256
num_images = 4
fps = 60
bps = 40_000_000
Expand Down Expand Up @@ -690,7 +692,7 @@ def test_compress_decompress_video_roundtrip(self):
self._check_similar(video, new_video, max_rms=8.0)

def test_html_from_compressed_video(self):
shape = (240, 320)
shape = 240, 320
video = media.moving_circle(shape, 10)
text = media.html_from_compressed_video(
media.compress_video(video), shape[1], shape[0]
Expand Down
1 change: 0 additions & 1 deletion pdoc_files/make.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

#!/usr/bin/env python3
"""Create HTML documentation from the source code using `pdoc`."""
# Note: Invoke this from the parent directory as "python3 pdoc_files/make.py".

Expand Down
11 changes: 9 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ dev = [
"pylint>=2.6.0",
"pytest",
"pytest-xdist",
"pytype",
]

[build-system]
Expand Down Expand Up @@ -75,8 +76,8 @@ extend-exclude = "\\.ipynb"
disable = [
"unspecified-encoding", "line-too-long", "too-many-lines",
"too-few-public-methods", "too-many-locals", "too-many-instance-attributes",
"too-many-branches", "too-many-statements", "using-constant-test",
"wrong-import-order", "use-dict-literal",
"too-many-branches", "too-many-statements", "too-many-arguments",
"using-constant-test", "wrong-import-order", "use-dict-literal",
]
reports = false
score = false
Expand All @@ -88,3 +89,9 @@ good-names-rgxs = "^[a-z][a-z0-9]?|[A-Z]([A-Z_]*[A-Z])?$"
[tool.pylint.format]
indent-string = " "
expected-line-ending-format = "LF"

[tool.pytype]
keep_going = true
strict_none_binding = true
use_enum_overlay = true
use_fiddle_overlay = true

0 comments on commit d7948a5

Please sign in to comment.