Skip to content
This repository has been archived by the owner on Mar 28, 2022. It is now read-only.

Commit

Permalink
Add segmentation data type (#66)
Browse files Browse the repository at this point in the history
* semantic map wip

* additional data types wip

* more data types wip

* merge

* add tests to segmentation data type

* remove accidental changes

* remove accidental changelog change

* fix for py27

* add more tests

* Update tests/test_data_types.py

* correctly handle array input

* add ability to pass 3-channel colormap as input

* fix default label check

* remove array input for segmentation, clarify 3-channel input

* bump to v0.1.0
  • Loading branch information
agermanidis authored Jun 12, 2019
1 parent 3dd4988 commit 3b4d1da
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## v0.1.0

- Add `segmentation` data type.

## v0.0.75

- Use PNG as default serialization format for images.
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ Pillow>=4.3.0
gevent>=1.4.0
wget>=3.2
six>=1.12.0
colorcet>=2.0.1
2 changes: 1 addition & 1 deletion runway/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.75'
__version__ = '0.1.0'
127 changes: 126 additions & 1 deletion runway/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from io import BytesIO as IO
import numpy as np
from PIL import Image
from .utils import is_url, extract_tarball, try_cast_np_scalar, download_file
from .utils import is_url, extract_tarball, try_cast_np_scalar, download_file, get_color_palette
from .exceptions import MissingArgumentError, InvalidArgumentError

class BaseType(object):
Expand Down Expand Up @@ -476,3 +476,128 @@ def to_dict(self):
if self.is_directory: ret['isDirectory'] = self.is_directory
if self.extension: ret['extension'] = self.extension
return ret


class segmentation(BaseType):
    """A data type that represents a pixel-level segmentation of an image.

    Each pixel is annotated with a label id from 0-255, each corresponding to a
    different object class.

    When used as an input data type, `segmentation` accepts a 1-channel
    base64-encoded PNG image, where each pixel takes the value of one of the
    ids defined in `label_to_id`, or a 3-channel base64-encoded PNG colormap
    image, where each pixel takes the value of one of the colors defined in
    `label_to_color`.

    When used as an output data type, it serializes the given PIL image or
    numpy array as a base64-encoded PNG data URI. The value is expected to be
    a 1-channel image where each pixel takes the value of one of the ids
    defined in `label_to_id`.

    .. code-block:: python

        import runway
        from runway.data_types import segmentation, image

        inputs = {"segmentation_map": segmentation(label_to_id={"background": 0, "person": 1})}
        outputs = {"image": image()}

        @runway.command("synthesize_pose", inputs=inputs, outputs=outputs)
        def synthesize_human_pose(model, args):
            result = model.convert(args["segmentation_map"])
            return { "image": result }

    :param description: A description of this variable and how its used in the model,
        defaults to None
    :type description: string, optional
    :param label_to_id: A mapping from labels to pixel values from 0-255 corresponding to those labels
    :type label_to_id: dict
    :param default_label: The default label to use when a pixel value not in `label_to_id` is encountered.
        Defaults to the first key of `label_to_id`.
    :type default_label: string, optional
    :param label_to_color: A mapping from label names to colors to represent those labels.
        Labels without an entry receive a color from the ``glasbey_bw`` palette.
    :type label_to_color: dict, optional
    :param min_width: The minimum width of the segmentation image, defaults to None
    :type min_width: int, optional
    :param min_height: The minimum height of the segmentation image, defaults to None
    :type min_height: int, optional
    :param max_width: The maximum width of the segmentation image, defaults to None
    :type max_width: int, optional
    :param max_height: The maximum height of the segmentation image, defaults to None
    :type max_height: int, optional
    :param width: The width of the segmentation image, defaults to None.
    :type width: int, optional
    :param height: The height of the segmentation image, defaults to None
    :type height: int, optional
    :raises MissingArgumentError: if `label_to_id` is not provided
    :raises InvalidArgumentError: if `label_to_id` is not a non-empty dict, or
        if `default_label` is not one of its keys
    """
    def __init__(self, description=None, label_to_id=None, label_to_color=None, default_label=None, min_width=None, min_height=None, max_width=None, max_height=None, width=None, height=None):
        super(segmentation, self).__init__('segmentation', description=description)
        if label_to_id is None:
            raise MissingArgumentError('label_to_id')
        if not isinstance(label_to_id, dict) or not label_to_id:
            msg = 'label_to_id argument has invalid type'
            raise InvalidArgumentError(msg)
        if default_label is not None and default_label not in label_to_id:
            msg = 'default_label {} is not in label map'.format(default_label)
            raise InvalidArgumentError(msg)
        self.label_to_id = label_to_id
        # complete_colors() reads self.label_to_id, so it must run after the
        # assignment above.
        self.label_to_color = self.complete_colors(label_to_color or {})
        # Compare against None explicitly so that a valid but falsy label
        # (e.g. an empty string) is not silently replaced by the first key.
        if default_label is None:
            default_label = list(self.label_to_id.keys())[0]
        self.default_label = default_label
        self.width = width
        self.height = height
        self.min_width = min_width
        self.min_height = min_height
        self.max_width = max_width
        self.max_height = max_height

    def complete_colors(self, seed_colors):
        """Return a label -> color mapping covering every label in `label_to_id`.

        Colors given in `seed_colors` are kept; every other label gets the
        ``glasbey_bw`` palette entry indexed by its label id.
        """
        colors = {}
        palette = get_color_palette('glasbey_bw')
        for label, label_id in self.label_to_id.items():
            if label in seed_colors:
                colors[label] = seed_colors[label]
            else:
                colors[label] = palette[label_id]
        return colors

    def colormap_to_segmentation(self, img):
        """Convert a colormap image into a 1-channel ('L' mode) label-id image.

        Only the first three channels of `img` are compared against the colors
        in `label_to_color`. Pixels that match none of them keep the initial
        value 0.
        """
        cmap = np.array(img)[:, :, :3]
        seg = np.zeros(cmap.shape[:2], dtype=np.uint8)
        for label, color in self.label_to_color.items():
            label_id = self.label_to_id[label]
            # Boolean mask of pixels whose channels all equal this label's color.
            seg[(cmap == color).all(axis=2)] = label_id
        return Image.fromarray(seg, 'L')

    def deserialize(self, value):
        """Decode a base64 (optionally data-URI prefixed) image into a PIL image.

        Images whose mode starts with 'RGB' are treated as colormaps and
        converted to label ids; any other image is returned as opened.

        :raises InvalidArgumentError: if `value` cannot be decoded or opened
            as an image
        """
        try:
            # Drop an optional "data:image/png;base64," style prefix; when no
            # comma is present find() returns -1 and the whole string is kept.
            image = value[value.find(",")+1:]
            # b64decode replaces base64.decodestring, which was removed in
            # Python 3.9 and would make every call fail here.
            image = base64.b64decode(image.encode('utf8'))
            buffer = IO(image)
            img = Image.open(buffer)
            if img.mode.startswith('RGB'):
                return self.colormap_to_segmentation(img)
            else:
                return img
        except Exception:
            # Any failure (wrong input type, bad base64, unreadable image) is
            # reported uniformly; KeyboardInterrupt/SystemExit still propagate.
            msg = 'unable to parse expected base64-encoded image'
            raise InvalidArgumentError(msg)

    def serialize(self, value):
        """Encode a PIL image or numpy array as a base64 PNG data URI.

        :raises InvalidArgumentError: if `value` is neither a numpy array nor
            a PIL image
        """
        if isinstance(value, np.ndarray):
            im_pil = Image.fromarray(value)
        elif isinstance(value, Image.Image):
            im_pil = value
        else:
            raise InvalidArgumentError(self.name or self.type, 'value is not a PIL or numpy image')
        buffer = IO()
        im_pil.save(buffer, format='PNG')
        return 'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode('utf8')

    def to_dict(self):
        """Return the dictionary description of this data type.

        Size constraints are only included when they are set to a truthy value.
        """
        ret = super(segmentation, self).to_dict()
        ret['labels'] = list(self.label_to_id.keys())
        ret['defaultLabel'] = self.default_label
        ret['labelToId'] = self.label_to_id
        ret['labelToColor'] = self.label_to_color
        if self.min_width: ret['minWidth'] = self.min_width
        if self.max_width: ret['maxWidth'] = self.max_width
        if self.min_height: ret['minHeight'] = self.min_height
        if self.max_height: ret['maxHeight'] = self.max_height
        if self.width: ret['width'] = self.width
        if self.height: ret['height'] = self.height
        return ret
8 changes: 7 additions & 1 deletion runway/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
import gzip
import datetime
import colorcet
if sys.version_info[0] < 3:
from cStringIO import StringIO as IO
else:
Expand Down Expand Up @@ -119,8 +120,13 @@ def timestamp_millis():
return int(offset.total_seconds() * 1000)


def get_color_palette(name):
    """Look up the ``colorcet`` attribute called `name` and return it as a list of
    ``[r, g, b]`` integer triples.

    The first three items of each palette entry are multiplied by 255 and
    truncated with ``int()``.
    """
    # NOTE(review): presumably each entry holds float channels in [0, 1] - confirm
    # against the installed colorcet version.
    rgb_triples = []
    for entry in getattr(colorcet, name):
        red, green, blue = entry[0], entry[1], entry[2]
        rgb_triples.append([int(red * 255), int(green * 255), int(blue * 255)])
    return rgb_triples


def argspec(fn):
    """Return the argument specification of the callable `fn`.

    Uses ``inspect.getargspec`` on Python 2 and ``inspect.getfullargspec`` on
    Python 3, where the former is deprecated or unavailable.
    """
    if sys.version_info[0] < 3:
        return inspect.getargspec(fn)
    return inspect.getfullargspec(fn)
48 changes: 48 additions & 0 deletions tests/test_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,3 +400,51 @@ def test_image_serialize_invalid_type():

with pytest.raises(InvalidArgumentError):
image().serialize('data:image/jpeg;base64,')

# SEGMENTATION ------------------------------------------------------------------------
def test_segmentation_to_dict():
    """to_dict() exposes the label maps, fills in missing colors and reports the size."""
    seg = segmentation(
        label_to_id={"background": 0, "person": 1},
        label_to_color={'background': [0, 0, 0]},
        width=512,
        height=512
    )
    obj = seg.to_dict()
    assert obj['type'] == 'segmentation'
    assert obj['labelToId'] == {"background": 0, "person": 1}
    # "person" has no seed color, so it is filled in from the glasbey_bw
    # palette entry at index 1 (its label id).
    assert obj['labelToColor'] == {"background": [0, 0, 0], "person": [140, 59, 255]}
    assert obj['width'] == 512
    assert obj['height'] == 512
    assert obj['description'] is None

def test_segmentation_serialize_and_deserialize():
    """A PIL image survives a serialize -> deserialize round trip as a PIL image."""
    label_map = {"background": 0, "person": 1}
    tests_dir = os.path.dirname(os.path.realpath(__file__))
    fixture_path = os.path.join(tests_dir, 'test_segmentation.png')
    source_image = Image.open(fixture_path)
    encoded = segmentation(label_to_id=label_map).serialize(source_image)
    decoded = segmentation(label_to_id=label_map).deserialize(encoded)
    assert isinstance(decoded, Image.Image)

def test_segmentation_no_label_to_id():
    """Constructing a segmentation without label_to_id raises MissingArgumentError."""
    # Callable form of pytest.raises: invokes segmentation() with no arguments.
    pytest.raises(MissingArgumentError, segmentation)

def test_segmentation_invalid_label_to_id():
    """An empty dict and a non-dict label_to_id are both rejected."""
    for bad_mapping in ({}, []):
        with pytest.raises(InvalidArgumentError):
            segmentation(label_to_id=bad_mapping)

def test_segmentation_invalid_default_label():
    """A default_label that is not a key of label_to_id is rejected."""
    label_map = {"background": 0, "person": 1}
    unknown_label = 'building'
    with pytest.raises(InvalidArgumentError):
        segmentation(label_to_id=label_map, default_label=unknown_label)

def test_segmentation_serialize_invalid_type():
    """serialize() rejects values that are neither PIL images nor numpy arrays."""
    unsupported_values = (True, [], 'data:image/jpeg;base64,')
    for unsupported in unsupported_values:
        with pytest.raises(InvalidArgumentError):
            segmentation(label_to_id={"background": 0, "person": 1}).serialize(unsupported)

def test_segmentation_deserialize_invalid_type():
    """deserialize() rejects a non-string value and a data URI with no image payload."""
    undecodable_values = (True, 'data:image/jpeg;base64,')
    for undecodable in undecodable_values:
        with pytest.raises(InvalidArgumentError):
            segmentation(label_to_id={"background": 0, "person": 1}).deserialize(undecodable)
Binary file added tests/test_segmentation.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 3b4d1da

Please sign in to comment.