Skip to content

Commit

Permalink
Merge pull request #2211 from AdeelH/raster_stats_get_chip
Browse files Browse the repository at this point in the history
Compute `RasterStats` from transformed `RasterSource`
  • Loading branch information
AdeelH authored Aug 6, 2024
2 parents d5e3b1c + 906ff30 commit b263422
Show file tree
Hide file tree
Showing 25 changed files with 153 additions and 123 deletions.
2 changes: 1 addition & 1 deletion rastervision_core/rastervision/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def register_plugin(registry):
registry.set_plugin_version('rastervision.core', 13)
registry.set_plugin_version('rastervision.core', 14)
from rastervision.core.cli import predict, predict_scene
registry.add_plugin_command(predict)
registry.add_plugin_command(predict_scene)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,6 @@ def get_chip(self, window: Box,
chip = chip[..., self.channel_order]

for transformer in self.raster_transformers:
chip = transformer.transform(chip, self.channel_order)
chip = transformer.transform(chip)

return chip
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ def validate_temporal(self) -> Self:
def build(self, tmp_dir: str | None = None,
use_transformers: bool = True) -> MultiRasterSource:
if use_transformers:
raster_transformers = [t.build() for t in self.transformers]
raster_transformers = [
t.build(channel_order=self.channel_order)
for t in self.transformers
]
else:
raster_transformers = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def get_chip(self, window: 'Box',
chip = chip[..., self.channel_order]

for transformer in self.raster_transformers:
chip = transformer.transform(chip, self.channel_order)
chip = transformer.transform(chip)

return chip

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def get_chip(self,
bands = self.bands_to_read[bands]
chip = self._get_chip(window, out_shape=out_shape, bands=bands)
for transformer in self.raster_transformers:
chip = transformer.transform(chip, self.channel_order)
chip = transformer.transform(chip)
return chip

def __getitem__(self, key: Any) -> 'np.ndarray':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,15 @@ class RasterioSourceConfig(RasterSourceConfig):
False,
description='Stream assets as needed rather than downloading them.')

def build(self, tmp_dir, use_transformers=True):
raster_transformers = ([rt.build() for rt in self.transformers]
if use_transformers else [])
def build(self, tmp_dir: str | None,
use_transformers: bool = True) -> RasterioSource:
if use_transformers:
raster_transformers = [
t.build(channel_order=self.channel_order)
for t in self.transformers
]
else:
raster_transformers = []
bbox = Box(*self.bbox) if self.bbox is not None else None
return RasterioSource(
uris=self.uris,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def get_chip(self, window: Box,
chip = np.stack(sub_chips)

for transformer in self.raster_transformers:
chip = transformer.transform(chip, self.channel_order)
chip = transformer.transform(chip)

return chip

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def get_chip(self,
chip = self._get_chip(
window, bands=bands, time=time, out_shape=out_shape)
for transformer in self.raster_transformers:
chip = transformer.transform(chip, bands)
chip = transformer.transform(chip)
return chip

def __getitem__(self, key: Any) -> 'np.ndarray':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,13 @@ class XarraySourceConfig(RasterSourceConfig):
def build(self, tmp_dir: str | None = None,
use_transformers: bool = True) -> XarraySource:
item_or_item_collection = self.stac.build()
raster_transformers = ([rt.build() for rt in self.transformers]
if use_transformers else [])
if use_transformers:
raster_transformers = [
t.build(channel_order=self.channel_order)
for t in self.transformers
]
else:
raster_transformers = []
raster_source = XarraySource.from_stac(
item_or_item_collection,
raster_transformers=raster_transformers,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@ def __init__(self, to_dtype: str):
def __repr__(self):
return repr_with_args(self, to_dtype=str(self.to_dtype))

def transform(self, chip: np.ndarray,
channel_order: list | None = None) -> np.ndarray:
"""Cast chip to self.to_dtype.
def transform(self, chip: np.ndarray) -> np.ndarray:
"""Cast chip to dtype ``self.to_dtype``.
Args:
chip: ndarray of shape [height, width, channels]
chip: Array of shape (..., H, W, C).
Returns:
[height, width, channels] numpy array
Array of shape (..., H, W, C)
"""
return chip.astype(self.to_dtype)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from rastervision.pipeline.config import register_config, Field
from rastervision.core.data.raster_transformer.raster_transformer_config import ( # noqa
RasterTransformerConfig)
from rastervision.core.data.raster_transformer.cast_transformer import ( # noqa
from rastervision.core.data.raster_transformer.cast_transformer import (
CastTransformer)


Expand All @@ -14,5 +14,5 @@ class CastTransformerConfig(RasterTransformerConfig):
description='dtype to cast raster to. Must be a valid Numpy dtype '
'e.g. "uint8", "float32", etc.')

def build(self):
def build(self, channel_order: list[int] | None = None):
return CastTransformer(to_dtype=self.to_dtype)
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
class MinMaxTransformer(RasterTransformer):
"""Transforms chips by scaling values in each channel to span 0-255."""

def transform(self,
chip: np.ndarray,
channel_order: list[int] | None = None) -> np.ndarray:
def transform(self, chip: np.ndarray) -> np.ndarray:
c = chip.shape[-1]
pixels = chip.reshape(-1, c)
channel_mins = pixels.min(axis=0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
class MinMaxTransformerConfig(RasterTransformerConfig):
"""Configure a :class:`.MinMaxTransformer`."""

def build(self):
def build(self, channel_order: list[int] | None = None):
return MinMaxTransformer()
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,22 @@ class NanTransformer(RasterTransformer):
"""Removes NaN values from float raster."""

def __init__(self, to_value: float = 0.0):
"""Construct a new NanTransformer.
"""Constructor.
Args:
to_value: (float) NaN values are replaced
with this
to_value: NaN values are replaced with this.
"""
self.to_value = to_value

def transform(self, chip, channel_order=None):
"""Transform a chip.
Removes NaN values.
def transform(self, chip):
"""Removes NaN values.
Args:
chip: ndarray of shape [height, width, channels] This is assumed to already
have the channel_order applied to it if channel_order is set. In other
words, channels should be equal to len(channel_order).
chip: Array of shape (..., H, W, C).
Returns:
[height, width, channels] numpy array
Array of shape (..., H, W, C)
"""
chip[np.isnan(chip)] = self.to_value
nan_mask = np.isnan(chip)
chip[nan_mask] = self.to_value
return chip
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ class NanTransformerConfig(RasterTransformerConfig):
to_value: float | None = Field(
0.0, description=('Turn all NaN values into this value.'))

def build(self):
def build(self, channel_order: list[int] | None = None):
return NanTransformer(to_value=self.to_value)
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,12 @@ class RasterTransformer(ABC):
"""Transforms raw chips to be input to a neural network."""

@abstractmethod
def transform(self, chip: 'np.ndarray',
channel_order=None) -> 'np.ndarray':
def transform(self, chip: 'np.ndarray') -> 'np.ndarray':
"""Transform a chip of a raster source.
Args:
chip: ndarray of shape [height, width, channels] This is assumed to already
have the channel_order applied to it if channel_order is set. In other
words, channels should be equal to len(channel_order).
channel_order: list of indices of channels that were extracted from the
raw imagery.
chip: Array of shape (..., H, W, C).
Returns:
(np.ndarray): Array of shape (..., H, W, C)
Array of shape (..., H, W, C)
"""
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from rastervision.pipeline.config import Config, register_config

if TYPE_CHECKING:
from rastervision.core.data import SceneConfig
from rastervision.core.data import RasterTransformer, SceneConfig
from rastervision.core.rv_pipeline import RVPipelineConfig


Expand All @@ -17,3 +17,7 @@ def update(self,

def update_root(self, root_dir: str):
pass

def build(self,
channel_order: list[int] | None = None) -> 'RasterTransformer':
raise NotImplementedError()
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,21 @@ class ReclassTransformer(RasterTransformer):
"""Maps class IDs in a label raster to other values."""

def __init__(self, mapping: dict[int, int]):
"""Construct a new ReclassTransformer.
"""Constructor.
Args:
mapping: (dict) Remapping dictionary
mapping: Remapping dictionary, value_from-->value_to.
"""
self.mapping = mapping

def transform(self,
chip: 'np.ndarray',
channel_order: list[int] | None = None):
"""Transform a chip.
Reclassify a label raster using the given mapping.
def transform(self, chip: 'np.ndarray'):
"""Reclassify a label raster using the given mapping.
Args:
chip: ndarray of shape [height, width, channels] This is assumed to already
have the channel_order applied to it if channel_order is set. In other
words, channels should be equal to len(channel_order).
channel_order: list of indices of channels that were extracted from the
raw imagery.
chip: Array of shape (..., H, W, C).
Returns:
[height, width, channels] numpy array
Array of shape (..., H, W, C)
"""
masks = []
for (value_from, value_to) in self.mapping.items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ class ReclassTransformerConfig(RasterTransformerConfig):
mapping: dict[int, int] = Field(
..., description=('The reclassification mapping.'))

def build(self) -> ReclassTransformer:
def build(self,
channel_order: list[int] | None = None) -> ReclassTransformer:
return ReclassTransformer(mapping=self.mapping)
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,14 @@ def __init__(self, class_config: 'ClassConfig'):
],
dtype=np.uint8)

def transform(self,
chip: np.ndarray,
channel_order: list[int] | None = None) -> np.ndarray:
def transform(self, chip: np.ndarray) -> np.ndarray:
"""Transform RGB array to array of class IDs or vice versa.
Args:
chip (np.ndarray): Numpy array of shape (H, W, 3).
channel_order (list[int] | None): List of indices of
channels that were extracted from the raw imagery.
Defaults to None.
chip: Numpy array of shape (H, W, 3).
Returns:
np.ndarray: An array of class IDs.
An array of class IDs of shape (H, W, 1).
"""
return self.rgb_to_class(chip)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ class RGBClassTransformerConfig(RasterTransformerConfig):
description=('The class config defining the mapping between '
'classes and colors.'))

def build(self) -> RGBClassTransformer:
def build(self,
channel_order: list[int] | None = None) -> RGBClassTransformer:
return RGBClassTransformer(class_config=self.class_config)
Loading

0 comments on commit b263422

Please sign in to comment.